Compare commits

...

3 Commits

Author SHA1 Message Date
Pedro Cuenca
e808fd1677 Skip test_load_pipeline_from_git on mps.
Not compatible with float16.
2022-11-07 12:43:47 +01:00
Pedro Cuenca
995b865b72 Merge remote-tracking branch 'origin/main' into fix-mps-crash 2022-11-07 12:28:25 +01:00
Pedro Cuenca
c44f1b9ac9 Tests: fix mps crashes. 2022-11-07 11:33:58 +01:00
2 changed files with 5 additions and 2 deletions

View File

@@ -456,6 +456,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
@@ -507,6 +508,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
@@ -558,6 +560,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))

View File

@@ -41,7 +41,7 @@ from diffusers import (
from diffusers.pipeline_utils import DiffusionPipeline
from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
from parameterized import parameterized
from PIL import Image
from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -124,7 +124,7 @@ class CustomPipelineTests(unittest.TestCase):
assert output_str == "This is a local test"
@slow
@unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
@require_torch_gpu
def test_load_pipeline_from_git(self):
clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"