Compare commits

...

1 Commits

Author SHA1 Message Date
Dhruv Nair
9262dab7e7 update 2023-11-09 12:32:13 +00:00
2 changed files with 21 additions and 13 deletions

View File

@@ -309,6 +309,17 @@ class LoraLoaderMixinTests(unittest.TestCase):
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# run lora xformers attention
attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
attn_processors = {
k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim)
for k, v in attn_processors.items()
}
attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
sd_pipe.unet.set_attn_processor(attn_processors)
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
@unittest.skipIf(not torch.cuda.is_available(), reason="xformers requires cuda")
def test_stable_diffusion_attn_processors(self):
# disable_full_determinism()
@@ -341,17 +352,6 @@ class LoraLoaderMixinTests(unittest.TestCase):
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# run lora xformers attention
attn_processors, _ = create_unet_lora_layers(sd_pipe.unet)
attn_processors = {
k: LoRAXFormersAttnProcessor(hidden_size=v.hidden_size, cross_attention_dim=v.cross_attention_dim)
for k, v in attn_processors.items()
}
attn_processors = {k: v.to("cuda") for k, v in attn_processors.items()}
sd_pipe.unet.set_attn_processor(attn_processors)
image = sd_pipe(**inputs).images
assert image.shape == (1, 64, 64, 3)
# enable_full_determinism()
def test_stable_diffusion_lora(self):
@@ -605,7 +605,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
orig_image_slice, orig_image_slice_two, atol=1e-3
), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
-@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+@unittest.skipIf(
+    torch_device != "cuda" or not is_xformers_available(),
+    reason="This test is supposed to run on GPU with xformers",
+)
def test_lora_unet_attn_processors_with_xformers(self):
with tempfile.TemporaryDirectory() as tmpdirname:
self.create_lora_weight_file(tmpdirname)
@@ -642,7 +645,10 @@ class LoraLoaderMixinTests(unittest.TestCase):
if isinstance(module, Attention):
self.assertIsInstance(module.processor, XFormersAttnProcessor)
-@unittest.skipIf(torch_device != "cuda", "This test is supposed to run on GPU")
+@unittest.skipIf(
+    torch_device != "cuda" or not is_xformers_available(),
+    reason="This test is supposed to run on GPU with xformers",
+)
def test_lora_save_load_with_xformers(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionPipeline(**pipeline_components)

View File

@@ -975,6 +975,7 @@ class PeftLoraLoaderMixinTests:
_ = pipe(**inputs, generator=torch.manual_seed(0)).images
@require_peft_backend
class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
pipeline_class = StableDiffusionPipeline
scheduler_cls = DDIMScheduler
@@ -1197,6 +1198,7 @@ class StableDiffusionLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
self.assertTrue(np.allclose(expected_slice_scale, predicted_slice, atol=1e-3, rtol=1e-3))
@require_peft_backend
class StableDiffusionXLLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
has_two_text_encoders = True
pipeline_class = StableDiffusionXLPipeline