mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-14 16:34:27 +08:00
Compare commits
5 Commits
custom-blo
...
better-cop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7a7436584 | ||
|
|
fc1e246424 | ||
|
|
790fcbb195 | ||
|
|
146db2c231 | ||
|
|
5eb1f07a75 |
@@ -843,11 +843,11 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
|||||||
|
|
||||||
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
|
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers`, `text_encoder_2_lora_layers`."
|
||||||
)
|
)
|
||||||
|
|
||||||
if unet_lora_layers:
|
if unet_lora_layers:
|
||||||
state_dict.update(cls.pack_weights(unet_lora_layers, "unet"))
|
state_dict.update(cls.pack_weights(unet_lora_layers, cls.unet_name))
|
||||||
|
|
||||||
if text_encoder_lora_layers:
|
if text_encoder_lora_layers:
|
||||||
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
state_dict.update(cls.pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||||
@@ -1210,10 +1210,11 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.save_lora_weights with unet->transformer
|
||||||
def save_lora_weights(
|
def save_lora_weights(
|
||||||
cls,
|
cls,
|
||||||
save_directory: Union[str, os.PathLike],
|
save_directory: Union[str, os.PathLike],
|
||||||
transformer_lora_layers: Dict[str, torch.nn.Module] = None,
|
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||||
is_main_process: bool = True,
|
is_main_process: bool = True,
|
||||||
@@ -1262,7 +1263,6 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
if text_encoder_2_lora_layers:
|
if text_encoder_2_lora_layers:
|
||||||
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
state_dict.update(cls.pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||||
|
|
||||||
# Save the model
|
|
||||||
cls.write_lora_layers(
|
cls.write_lora_layers(
|
||||||
state_dict=state_dict,
|
state_dict=state_dict,
|
||||||
save_directory=save_directory,
|
save_directory=save_directory,
|
||||||
@@ -1272,6 +1272,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.fuse_lora with unet->transformer
|
||||||
def fuse_lora(
|
def fuse_lora(
|
||||||
self,
|
self,
|
||||||
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
components: List[str] = ["transformer", "text_encoder", "text_encoder_2"],
|
||||||
@@ -1315,6 +1316,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin.unfuse_lora with unet->transformer
|
||||||
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
|
def unfuse_lora(self, components: List[str] = ["transformer", "text_encoder", "text_encoder_2"], **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Reverses the effect of
|
Reverses the effect of
|
||||||
@@ -1328,7 +1330,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||||
LoRA parameters then it won't have any effect.
|
LoRA parameters then it won't have any effect.
|
||||||
@@ -2833,6 +2835,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||||
def fuse_lora(
|
def fuse_lora(
|
||||||
self,
|
self,
|
||||||
components: List[str] = ["transformer"],
|
components: List[str] = ["transformer"],
|
||||||
@@ -2876,6 +2879,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
|||||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Reverses the effect of
|
Reverses the effect of
|
||||||
@@ -3136,6 +3140,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||||
def fuse_lora(
|
def fuse_lora(
|
||||||
self,
|
self,
|
||||||
components: List[str] = ["transformer"],
|
components: List[str] = ["transformer"],
|
||||||
@@ -3179,6 +3184,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
|||||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Reverses the effect of
|
Reverses the effect of
|
||||||
@@ -3439,6 +3445,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||||
def fuse_lora(
|
def fuse_lora(
|
||||||
self,
|
self,
|
||||||
components: List[str] = ["transformer"],
|
components: List[str] = ["transformer"],
|
||||||
@@ -3482,6 +3489,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
|||||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Reverses the effect of
|
Reverses the effect of
|
||||||
@@ -3745,6 +3753,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|||||||
safe_serialization=safe_serialization,
|
safe_serialization=safe_serialization,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||||
def fuse_lora(
|
def fuse_lora(
|
||||||
self,
|
self,
|
||||||
components: List[str] = ["transformer"],
|
components: List[str] = ["transformer"],
|
||||||
@@ -3788,6 +3797,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
|||||||
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
components=components, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||||
r"""
|
r"""
|
||||||
Reverses the effect of
|
Reverses the effect of
|
||||||
|
|||||||
Reference in New Issue
Block a user