Compare commits

...

5 Commits

Author SHA1 Message Date
Sayak Paul
c7a7436584 Merge branch 'main' into better-copy-lora-pipelines 2025-03-08 19:20:34 +05:30
Sayak Paul
fc1e246424 Merge branch 'main' into better-copy-lora-pipelines 2025-03-07 12:50:19 +05:30
sayakpaul
790fcbb195 better 2025-03-07 08:14:00 +05:30
sayakpaul
146db2c231 better 2025-03-07 08:11:02 +05:30
sayakpaul
5eb1f07a75 more sanity of mind with copied from ... 2025-03-07 08:04:28 +05:30

View File

@@ -843,11 +843,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
)
if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -1210,10 +1210,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
@@ -1262,7 +1263,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
# Save the model
cls.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
@@ -1272,6 +1272,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora(
self,
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1315,6 +1316,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r"""
Reverses the effect of
@@ -1328,7 +1330,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect.
@@ -2833,6 +2835,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
@@ -2876,6 +2879,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
@@ -3136,6 +3140,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
@@ -3179,6 +3184,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
@@ -3439,6 +3445,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
@@ -3482,6 +3489,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of
@@ -3745,6 +3753,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
@@ -3788,6 +3797,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
Reverses the effect of