Compare commits

...

5 Commits

Author SHA1 Message Date
Sayak Paul
c7a7436584 Merge branch 'main' into better-copy-lora-pipelines 2025-03-08 19:20:34 +05:30
Sayak Paul
fc1e246424 Merge branch 'main' into better-copy-lora-pipelines 2025-03-07 12:50:19 +05:30
sayakpaul
790fcbb195 better 2025-03-07 08:14:00 +05:30
sayakpaul
146db2c231 better 2025-03-07 08:11:02 +05:30
sayakpaul
5eb1f07a75 more sanity of mind with copied from ... 2025-03-07 08:04:28 +05:30

View File

@@ -843,11 +843,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers): if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError( raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`." "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
) )
if unet_lora_layers: if unet_lora_layers:
state_dict.update(cls.pack_weights(unet_lora_layers, "unet")) state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
if text_encoder_lora_layers: if text_encoder_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder")) state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
@@ -1210,10 +1210,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
) )
@classmethod @classmethod
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
def save_lora_weights( def save_lora_weights(
cls, cls,
save_directory: Union[str, os.PathLike], save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, torch.nn.Module] = None, transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True, is_main_process: bool = True,
@@ -1262,7 +1263,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
if text_encoder_2_lora_layers: if text_encoder_2_lora_layers:
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
# Save the model
cls.write_lora_layers( cls.write_lora_layers(
state_dict=state_dict, state_dict=state_dict,
save_directory=save_directory, save_directory=save_directory,
@@ -1272,6 +1272,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
@@ -1315,6 +1316,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
@@ -1328,7 +1330,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
Args: Args:
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from. components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters. unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
unfuse_text_encoder (`bool`, defaults to `True`): unfuse_text_encoder (`bool`, defaults to `True`):
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
LoRA parameters then it won't have any effect. LoRA parameters then it won't have any effect.
@@ -2833,6 +2835,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
@@ -2876,6 +2879,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
@@ -3136,6 +3140,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
@@ -3179,6 +3184,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
@@ -3439,6 +3445,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
@@ -3482,6 +3489,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of
@@ -3745,6 +3753,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
safe_serialization=safe_serialization, safe_serialization=safe_serialization,
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora( def fuse_lora(
self, self,
components: List[str] = ["transformer"], components: List[str] = ["transformer"],
@@ -3788,6 +3797,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
) )
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs): def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r""" r"""
Reverses the effect of Reverses the effect of