Compare commits


15 Commits

Author     SHA1        Message                            Date
DN6        e1f502fac1  update                             2025-09-11 08:52:01 +05:30
DN6        9150ab02f6  update                             2025-09-11 08:21:57 +05:30
DN6        93f71d95a2  update                             2025-09-11 08:20:00 +05:30
DN6        de4ba0a977  update                             2025-09-11 08:14:57 +05:30
DN6        40f12d2aea  update                             2025-09-11 08:07:04 +05:30
DN6        1e0856616a  update                             2025-09-10 18:12:00 +05:30
DN6        fa926e78f5  update                             2025-09-10 17:40:01 +05:30
DN6        a8c5801e26  update                             2025-09-10 17:32:44 +05:30
DN6        2743c9ee3b  update                             2025-09-10 17:25:58 +05:30
sayakpaul  2c47a2ffd4  Revert "up" (reverts 772c32e433)   2025-09-10 11:24:07 +05:30
sayakpaul  772c32e433  up                                 2025-09-10 10:34:04 +05:30
sayakpaul  4256de9fea  up                                 2025-09-08 14:09:21 +05:30
sayakpaul  6c0c72de9c  up                                 2025-09-08 12:15:01 +05:30
sayakpaul  c8afd1c8b4  up                                 2025-09-08 12:05:53 +05:30
sayakpaul  02fd92e38e  cache non lora pipeline outputs.   2025-09-08 11:39:11 +05:30
3 changed files with 74 additions and 39 deletions

tests/lora/test_lora_layers_flux.py

@@ -53,7 +53,7 @@ from .utils import PeftLoraLoaderMixinTests, check_if_lora_correctly_set # noqa
 @require_peft_backend
-class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class FluxLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = FluxPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler()
     scheduler_kwargs = {}
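
Editor's note (not part of the diff): the base-class swap above is load-bearing. unittest.TestCase defines its own no-op setUp, so with unittest.TestCase first in the MRO the new PeftLoraLoaderMixinTests.setUp (added in the mixin diff below) would be shadowed and the output cache would never be populated. A minimal sketch with hypothetical class names:

import unittest

class Mixin:
    def setUp(self):
        print("mixin setUp runs")

class WrongOrder(unittest.TestCase, Mixin):
    pass

class RightOrder(Mixin, unittest.TestCase):
    pass

# With TestCase first, attribute lookup finds TestCase.setUp (a no-op) and the
# mixin hook is silently skipped.
print(WrongOrder.setUp is unittest.TestCase.setUp)  # True
# With the mixin first, its setUp wins and can still chain via super().setUp().
print(RightOrder.setUp is Mixin.setUp)  # True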
@@ -123,7 +123,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        output_no_lora = self.get_base_pipeline_output(FlowMatchEulerDiscreteScheduler)
         self.assertTrue(output_no_lora.shape == self.output_shape)

         pipe.transformer.add_adapter(denoiser_lora_config)
@@ -171,7 +171,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        output_no_lora = self.get_base_pipeline_output(FlowMatchEulerDiscreteScheduler)
         self.assertTrue(output_no_lora.shape == self.output_shape)

         # Modify the config to have a layer which won't be present in the second LoRA we will load.
@@ -220,7 +220,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0)).images
+        output_no_lora = self.get_base_pipeline_output(FlowMatchEulerDiscreteScheduler)
         self.assertTrue(output_no_lora.shape == self.output_shape)

         # Modify the config to have a layer which won't be present in the first LoRA we will load.
@@ -280,7 +280,7 @@ class FluxLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         pass

-class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class FluxControlLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = FluxControlPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler()
     scheduler_kwargs = {}
@@ -331,6 +331,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         noise = floats_tensor((batch_size, num_channels) + sizes)
         input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

+        np.random.seed(0)
         pipeline_inputs = {
             "prompt": "A painting of a squirrel eating a burger",
             "control_image": Image.fromarray(np.random.randint(0, 255, size=(32, 32, 3), dtype="uint8")),
@@ -356,7 +357,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         logger = logging.get_logger("diffusers.loaders.lora_pipeline")
         logger.setLevel(logging.INFO)

-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        original_output = self.get_base_pipeline_output(FlowMatchEulerDiscreteScheduler)

         for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]:
             norm_state_dict = {}
@@ -642,7 +643,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
         pipe.set_progress_bar_config(disable=None)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)

-        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        original_output = self.get_base_pipeline_output(FlowMatchEulerDiscreteScheduler)

         out_features, in_features = pipe.transformer.x_embedder.weight.shape
         rank = 4

tests/lora/test_lora_layers_sana.py

@@ -29,7 +29,7 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
 @require_peft_backend
-class SanaLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
+class SanaLoRATests(PeftLoraLoaderMixinTests, unittest.TestCase):
     pipeline_class = SanaPipeline
     scheduler_cls = FlowMatchEulerDiscreteScheduler(shift=7.0)
     scheduler_kwargs = {}

tests/lora/utils.py

@@ -129,6 +129,46 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]

+    def setUp(self):
+        self._cache_base_pipeline_output()
+        super().setUp()
+
+    def _cache_base_pipeline_output(self):
+        # Get or create the cache on the class (not instance)
+        if not hasattr(type(self), "cached_base_pipe_outs"):
+            setattr(type(self), "cached_base_pipe_outs", {})
+        cached_base_pipe_outs = type(self).cached_base_pipe_outs
+
+        all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
+        if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
+            return
+
+        for scheduler_cls in self.scheduler_classes:
+            if scheduler_cls.__name__ in cached_base_pipe_outs:
+                continue
+
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+
+            # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+            # explicitly.
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
+
+        # Update the class attribute
+        setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
+
+    def get_base_pipeline_output(self, scheduler_cls):
+        """
+        Returns the cached base pipeline output for the given scheduler.
+        Cache is populated during setUp, so this just retrieves the value.
+        """
+        return self.cached_base_pipe_outs[scheduler_cls.__name__]
+
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
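
Editor's note (not part of the diff): unittest builds a fresh test-class instance for every test method, so instance attributes do not survive from one test to the next; storing cached_base_pipe_outs on type(self) keeps the expensive base outputs alive across all tests of a subclass while giving each concrete subclass its own cache. A self-contained sketch of the pattern (hypothetical names, not the diffusers helpers):

import unittest

CALLS = {"n": 0}

def expensive_base_output(name):
    CALLS["n"] += 1  # counts how often the expensive path actually runs
    return f"output-{name}"

class CachingMixin:
    scheduler_names = ["SchedulerA", "SchedulerB"]

    def setUp(self):
        cls = type(self)
        if not hasattr(cls, "cached_outs"):
            cls.cached_outs = {}  # lives on the subclass, shared by all its instances
        for name in self.scheduler_names:
            if name not in cls.cached_outs:
                cls.cached_outs[name] = expensive_base_output(name)
        super().setUp()

class DemoTests(CachingMixin, unittest.TestCase):
    def test_one(self):
        self.assertEqual(self.cached_outs["SchedulerA"], "output-SchedulerA")

    def test_two(self):
        # Two tests ran two setUps, yet the expensive path ran once per scheduler.
        self.assertEqual(CALLS["n"], 2)

if __name__ == "__main__":
    unittest.main()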
@@ -320,13 +360,7 @@ class PeftLoraLoaderMixinTests:
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-
-            _, _, inputs = self.get_dummy_inputs()
-            output_no_lora = pipe(**inputs)[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

     def test_simple_inference_with_text_lora(self):
@@ -341,7 +375,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -424,7 +458,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -480,7 +514,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -518,7 +552,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -550,7 +584,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -585,7 +619,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -636,7 +670,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -687,7 +721,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -730,7 +764,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -771,7 +805,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -815,7 +849,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -853,7 +887,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -932,7 +966,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1061,7 +1095,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1118,7 +1152,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1281,7 +1315,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1375,7 +1409,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1619,7 +1653,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1700,7 +1734,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1755,7 +1789,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_dora_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1887,7 +1921,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            original_out = self.get_base_pipeline_output(scheduler_cls)

             no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
             logger = logging.get_logger("diffusers.loaders.peft")
@@ -1933,7 +1967,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2287,7 +2321,7 @@ class PeftLoraLoaderMixinTests:
             pipe = self.pipeline_class(**components).to(torch_device)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)
             self.assertTrue(output_no_lora.shape == self.output_shape)

             pipe, _ = self.add_adapters_to_pipeline(
@@ -2337,7 +2371,7 @@ class PeftLoraLoaderMixinTests:
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)

-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.get_base_pipeline_output(scheduler_cls)

             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)