Compare commits

...

3 Commits

Author         SHA1         Message                                                                  Date
Pedro Cuenca   e808fd1677   Skip test_load_pipeline_from_git on mps. Not compatible with float16.   2022-11-07 12:43:47 +01:00
Pedro Cuenca   995b865b72   Merge remote-tracking branch 'origin/main' into fix-mps-crash            2022-11-07 12:28:25 +01:00
Pedro Cuenca   c44f1b9ac9   Tests: fix mps crashes.                                                  2022-11-07 11:33:58 +01:00
2 changed files with 5 additions and 2 deletions
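
These commits swap the ad-hoc @unittest.skipIf(torch_device == "cpu", ...) guard for the shared @require_torch_gpu decorator, and add that decorator to three parameterized UNet integration tests. Per the first commit message, test_load_pipeline_from_git is not compatible with float16 on mps, so checking only for "cpu" let it run (and crash) there. As a minimal sketch, assuming require_torch_gpu in diffusers.utils.testing_utils is defined roughly along these lines (the exact skip message and availability checks are assumptions):

    import unittest

    # torch_device is resolved once at import time and is one of
    # "cuda", "mps", or "cpu". On Apple Silicon it is "mps", so the old
    # `skipIf(torch_device == "cpu", ...)` guard did not skip there.
    torch_device = "cuda"  # placeholder; the real value is detected from torch

    def require_torch_gpu(test_case):
        # Skip unless a CUDA device is available; this covers both cpu and mps.
        return unittest.skipUnless(torch_device == "cuda", "test requires CUDA")(test_case)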


@@ -456,6 +456,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
         latents = self.get_latents(seed)
@@ -507,6 +508,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
         latents = self.get_latents(seed)
@@ -558,6 +560,7 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
             # fmt: on
         ]
     )
+    @require_torch_gpu
     def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
         model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
         latents = self.get_latents(seed, shape=(4, 9, 64, 64))


@@ -41,7 +41,7 @@ from diffusers import (
 from diffusers.pipeline_utils import DiffusionPipeline
 from diffusers.schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
 from diffusers.utils import CONFIG_NAME, WEIGHTS_NAME, floats_tensor, slow, torch_device
-from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir
+from diffusers.utils.testing_utils import CaptureLogger, get_tests_dir, require_torch_gpu
 from parameterized import parameterized
 from PIL import Image
 from transformers import CLIPFeatureExtractor, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPTokenizer
@@ -124,7 +124,7 @@ class CustomPipelineTests(unittest.TestCase):
         assert output_str == "This is a local test"
 
     @slow
-    @unittest.skipIf(torch_device == "cpu", "Stable diffusion is supposed to run on GPU")
+    @require_torch_gpu
     def test_load_pipeline_from_git(self):
         clip_model_id = "laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
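
In the parameterized tests above, @require_torch_gpu sits between the parameterized.expand(...) call and the method, so the skip wraps the underlying test function before the individual cases are generated and every generated case skips off-GPU. A hypothetical usage sketch of that stacking (the class name and case values are illustrative, not from the diff):

    import unittest
    from parameterized import parameterized
    from diffusers.utils.testing_utils import require_torch_gpu

    class ExampleIntegrationTests(unittest.TestCase):
        @parameterized.expand([(33, 4), (47, 0.55)])  # illustrative (seed, timestep) cases
        @require_torch_gpu  # applied first, so each expanded case inherits the skip
        def test_example(self, seed, timestep):
            self.assertIsNotNone(seed)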