@@ -129,6 +129,48 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
+    def setUp(self):
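+        # Populating the cache here keeps the expensive no-LoRA runs to once per test class.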
+        self._cache_base_pipeline_output()
+        super().setUp()
+
+    def _cache_base_pipeline_output(self):
+        # Get or create the cache on the class (not instance)
+        if not hasattr(type(self), "cached_base_pipe_outs"):
+            setattr(type(self), "cached_base_pipe_outs", {})
+
+        cached_base_pipe_outs = type(self).cached_base_pipe_outs
+
+        all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
+        if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
+            return
+
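+        # Compute and cache the baseline (no-LoRA) output for every scheduler still missing.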
+        for scheduler_cls in self.scheduler_classes:
+            if scheduler_cls.__name__ in cached_base_pipe_outs:
+                continue
+
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+
+            # Always fetch the inputs without the `generator` and pass the `generator`
+            # explicitly so that the cached output is deterministic.
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
+
+        # Update the class attribute
+        setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
+
+    def get_base_pipeline_output(self, scheduler_cls):
+        """
+        Returns the cached base pipeline output for the given scheduler.
+        Cache is populated during `setUp`, so this just retrieves the value.
+        """
+        return self.cached_base_pipe_outs[scheduler_cls.__name__]
+
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
@@ -320,13 +360,7 @@ class PeftLoraLoaderMixinTests:
     def test_simple_inference(self):
         """
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-
-            _, _, inputs = self.get_dummy_inputs()
-            output_no_lora = pipe(**inputs)[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
@@ -341,7 +375,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -424,7 +458,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -480,7 +514,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -518,7 +552,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -550,7 +584,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -585,7 +619,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -636,7 +670,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -687,7 +721,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -730,7 +764,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -771,7 +805,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -815,7 +849,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -853,7 +887,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -932,7 +966,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1061,7 +1095,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1118,7 +1152,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1281,7 +1315,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1375,7 +1409,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1619,7 +1653,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1700,7 +1734,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1755,7 +1789,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_dora_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1887,7 +1921,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
 
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
-            original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            original_out = self.get_base_pipeline_output(scheduler_cls)
 
             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
@@ -1933,7 +1967,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2287,7 +2321,7 @@ class PeftLoraLoaderMixinTests:
             pipe = self.pipeline_class(**components).to(torch_device)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(
@@ -2337,7 +2371,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)