Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-11 23:14:37 +08:00)

Comparing branches `pipeline-s...remove-lor` (5 commits)
| SHA1 |
|---|
| 84ee2fd6b7 |
| 5667f6020c |
| 3470e976e1 |
| a5eaf4a699 |
| 37617f9db8 |
@@ -1487,7 +1487,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
    Load LoRA layers into [`FluxTransformer2DModel`],
    [`CLIPTextModel`](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel).

-   Specific to [`StableDiffusion3Pipeline`].
+   Specific to [`FluxPipeline`].
    """

    _lora_loadable_modules = ["transformer", "text_encoder"]
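Because `FluxLoraLoaderMixin` lists `transformer` and `text_encoder` in `_lora_loadable_modules`, any pipeline that inherits it exposes `load_lora_weights` directly. A minimal sketch, assuming illustrative model and LoRA repo ids:

```py
import torch
from diffusers import FluxPipeline

# Illustrative ids; substitute the checkpoint and LoRA you actually use.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("some-user/some-flux-lora", adapter_name="my_lora")
```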
@@ -1628,30 +1628,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
        **kwargs,
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
        `self.text_encoder`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is
        loaded.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
        dict is loaded into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
            kwargs (`dict`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")
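Since everything in `kwargs` is forwarded to `self.lora_state_dict`, download-related options such as `weight_name` can be passed straight through `load_lora_weights`. A hedged sketch; the repo id and file name are illustrative and `pipe` is the pipeline from the previous sketch:

```py
pipe.load_lora_weights(
    "some-user/some-flux-lora",
    weight_name="pytorch_lora_weights.safetensors",  # forwarded to lora_state_dict
    adapter_name="style",
    low_cpu_mem_usage=True,  # requires a recent peft version
)
```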
@@ -3651,44 +3628,17 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):

    @classmethod
    @validate_hf_hub_args
    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.lora_state_dict
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        **kwargs,
    ):
        r"""
        Return state dict for lora weights and the network alphas.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted on the Hub.
                    - A path to a *directory* containing the model weights.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository.
            weight_name (`str`, *optional*, defaults to None):
                Name of the serialized state dict file.
            use_safetensors (`bool`, *optional*):
                Whether to use safetensors for loading.
            return_lora_metadata (`bool`, *optional*, defaults to False):
                When enabled, additionally return the LoRA adapter metadata.

        See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
        """
-       # Load the main state dict first which has the LoRA layers
+       # Load the main state dict first which has the LoRA layers for either of
+       # transformer and text encoder or both.
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
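A short sketch of using `lora_state_dict` on its own to inspect a checkpoint before loading it. The repo id is illustrative, and the call is shown on the mixin class itself; in practice it is usually reached through the pipeline class that inherits it:

```py
from diffusers.loaders.lora_pipeline import KandinskyLoraLoaderMixin

# Works the same with a Hub repo id, a local directory, or an in-memory state dict.
state_dict = KandinskyLoraLoaderMixin.lora_state_dict("some-user/some-kandinsky-lora")
print(sorted(state_dict)[:5])  # peek at a few LoRA parameter keys

# With return_lora_metadata=True the adapter metadata is returned alongside the weights.
state_dict, metadata = KandinskyLoraLoaderMixin.lora_state_dict(
    "some-user/some-kandinsky-lora", return_lora_metadata=True
)
```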
@@ -3731,6 +3681,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        out = (state_dict, metadata) if return_lora_metadata else state_dict
        return out

+   # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -3739,26 +3690,13 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        **kwargs,
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model.
            hotswap (`bool`, *optional*):
                Whether to substitute an existing (LoRA) adapter with the newly loaded adapter in-place.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            kwargs (`dict`, *optional*):
                See [`~loaders.KandinskyLoraLoaderMixin.lora_state_dict`].

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
-       if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+       if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )
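The `hotswap` flag documented above replaces the weights of an already-loaded adapter in place rather than creating a new one, which can avoid re-creating the LoRA modules. A hedged sketch, with illustrative repo ids and adapter name:

```py
# First load creates the LoRA modules under the adapter name "style".
pipe.load_lora_weights("some-user/kandinsky-lora-a", adapter_name="style")
# ... run inference ...

# Later, swap in different weights for the same adapter without rebuilding the modules.
pipe.load_lora_weights("some-user/kandinsky-lora-b", adapter_name="style", hotswap=True)
```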
@@ -3775,7 +3713,6 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

-       # Load LoRA into transformer
        self.load_lora_into_transformer(
            state_dict,
            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
@@ -3787,6 +3724,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        )

    @classmethod
+   # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer
    def load_lora_into_transformer(
        cls,
        state_dict,
@@ -3798,23 +3736,9 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        metadata=None,
    ):
        """
        Load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters.
            transformer (`Kandinsky5Transformer3DModel`):
                The transformer model to load the LoRA layers into.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.KandinskyLoraLoaderMixin.load_lora_weights`].
            metadata (`dict`):
                Optional LoRA adapter metadata.

        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
        """
-       if low_cpu_mem_usage and not is_peft_version(">=", "0.13.1"):
+       if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )
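For completeness, a sketch of calling this lower-level classmethod directly; in normal use `load_lora_weights` does this for you. The repo id and adapter name are illustrative, and `pipe` is assumed to be a Kandinsky pipeline exposing a `transformer` attribute:

```py
from diffusers.loaders.lora_pipeline import KandinskyLoraLoaderMixin

# Fetch the LoRA state dict, then inject it into the transformer only.
state_dict = KandinskyLoraLoaderMixin.lora_state_dict("some-user/some-kandinsky-lora")
KandinskyLoraLoaderMixin.load_lora_into_transformer(
    state_dict,
    transformer=pipe.transformer,
    adapter_name="style",
)
```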
@@ -3832,6 +3756,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        )

    @classmethod
+   # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
@@ -3840,24 +3765,10 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
-       transformer_lora_adapter_metadata=None,
+       transformer_lora_adapter_metadata: Optional[dict] = None,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer and text encoders.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process.
            save_function (`Callable`):
                The function to use to save the state dictionary.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way.
            transformer_lora_adapter_metadata:
                LoRA adapter metadata associated with the transformer.

        See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
        """
        lora_layers = {}
        lora_metadata = {}
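A sketch of persisting trained LoRA layers with this method; the directory and file names are illustrative, and `transformer_lora_layers` is assumed to be the LoRA state dict produced by your training loop:

```py
from diffusers.loaders.lora_pipeline import KandinskyLoraLoaderMixin

KandinskyLoraLoaderMixin.save_lora_weights(
    save_directory="./my-kandinsky-lora",
    transformer_lora_layers=transformer_lora_layers,  # dict of LoRA tensors from training
    weight_name="pytorch_lora_weights.safetensors",
    safe_serialization=True,
)
```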
@@ -3867,7 +3778,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
            lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata

        if not lora_layers:
-           raise ValueError("You must pass at least one of `transformer_lora_layers`")
+           raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
@@ -3879,6 +3790,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
            safe_serialization=safe_serialization,
        )

+   # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
    def fuse_lora(
        self,
        components: List[str] = ["transformer"],
@@ -3888,25 +3800,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
        **kwargs,
    ):
        r"""
        Fuses the LoRA parameters into the original parameters of the corresponding blocks.

        Args:
            components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
            lora_scale (`float`, defaults to 1.0):
                Controls how much to influence the outputs with the LoRA parameters.
            safe_fusing (`bool`, defaults to `False`):
                Whether to check fused weights for NaN values before fusing.
            adapter_names (`List[str]`, *optional*):
                Adapter names to be used for fusing.

        Example:
        ```py
        from diffusers import Kandinsky5T2VPipeline

        pipeline = Kandinsky5T2VPipeline.from_pretrained("ai-forever/Kandinsky-5.0-T2V")
        pipeline.load_lora_weights("path/to/lora.safetensors")
        pipeline.fuse_lora(lora_scale=0.7)
        ```

        See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
        """
        super().fuse_lora(
            components=components,
@@ -3916,12 +3810,10 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
            **kwargs,
        )

    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
        r"""
        Reverses the effect of [`pipe.fuse_lora()`].

        Args:
            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.

        See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
        """
        super().unfuse_lora(components=components, **kwargs)
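To close the loop, a minimal sketch of the fuse/unfuse round trip on a Kandinsky pipeline; the LoRA path is illustrative and `pipe` is assumed to be an already-constructed pipeline:

```py
pipe.load_lora_weights("path/to/lora.safetensors")
pipe.fuse_lora(lora_scale=1.0)                # merge the LoRA deltas into the transformer weights
# ... run inference without per-step LoRA overhead ...
pipe.unfuse_lora(components=["transformer"])  # restore the original, unfused weights
```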