Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 23:44:30 +08:00)

Compare commits: custom-blo ... folderize- (31 commits)

Commits in this compare: 8d8621ec72, 0c37895440, 9bebdf225d, c05114d5ec, a57a5ab4c0, 4b1c7dc81a, 1590325a60, e4dd7c5333, d6430c79a3, 1597ae6ac9, 11a23d11fe, 6b8b225aca, 27d2401e59, 1ddfe14220, 0e8d1d25eb, 546446ae21, ea3f0b8d68, f0ea9ff2e2, 1b7c286974, 6138cc1720, ea0ce4bfab, f2aa2f91dc, 4faac73219, d870e3c9a6, 178b884673, 2da3cb4a8c, ea3ba4f431, 21b2566933, a71334b861, eb47a67d50, 8267677a24
@@ -22,11 +22,11 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]

 ## IPAdapterMixin

-[[autodoc]] loaders.ip_adapter.IPAdapterMixin
+[[autodoc]] loaders.ip_adapter.ip_adapter.IPAdapterMixin

 ## SD3IPAdapterMixin

-[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
+[[autodoc]] loaders.ip_adapter.ip_adapter.SD3IPAdapterMixin
   - all
   - is_ip_adapter_active
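Only the `[[autodoc]]` targets change in this hunk: the IP-Adapter mixins are now documented from the nested `loaders.ip_adapter.ip_adapter` module. A minimal import sketch of what the new path means for readers (hedged; the top-level re-export relies on the `__init__.py` changes shown further down in this compare):

```python
# Hedged sketch (module paths taken from the [[autodoc]] targets above; only valid on this branch):
# the classes themselves are unchanged, their defining module simply moved one level deeper.
from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin, SD3IPAdapterMixin

# The top-level lazy re-export is expected to keep working for end users.
from diffusers.loaders import IPAdapterMixin as ReExportedIPAdapterMixin
```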
@@ -39,58 +39,66 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

 ## StableDiffusionLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin

 ## StableDiffusionXLLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin

 ## SD3LoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.SD3LoraLoaderMixin

 ## FluxLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin

 ## CogVideoXLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin

 ## Mochi1LoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.Mochi1LoraLoaderMixin

 ## AuraFlowLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.AuraFlowLoraLoaderMixin

 ## LTXVideoLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.LTXVideoLoraLoaderMixin

 ## SanaLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.SanaLoraLoaderMixin

 ## HunyuanVideoLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.HunyuanVideoLoraLoaderMixin

 ## Lumina2LoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.Lumina2LoraLoaderMixin

-## CogView4LoraLoaderMixin
-
-[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
-
 ## WanLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin
+
+## CogView4LoraLoaderMixin
+
+[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
+
+## CogView4LoraLoaderMixin
+
+[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
+
+## WanLoraLoaderMixin
+
+[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin

 ## AmusedLoraLoaderMixin

-[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.AmusedLoraLoaderMixin

 ## HiDreamImageLoraLoaderMixin

@@ -98,4 +106,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

 ## LoraBaseMixin

-[[autodoc]] loaders.lora_base.LoraBaseMixin
+[[autodoc]] loaders.lora.lora_base.LoraBaseMixin
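These path changes are internal to the package; the pipeline-level methods the LoRA mixins provide (for example `load_lora_weights`) keep the same call signature. A short usage sketch, with illustrative model and LoRA repo ids:

```python
import torch
from diffusers import DiffusionPipeline

# StableDiffusionXLLoraLoaderMixin is mixed into the SDXL pipelines and provides load_lora_weights().
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# The repo id and weight name below are placeholders for an actual SDXL LoRA checkpoint.
pipe.load_lora_weights("some-user/some-sdxl-lora", weight_name="pytorch_lora_weights.safetensors")

image = pipe("a watercolor painting of a lighthouse", num_inference_steps=30).images[0]
```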
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.

 # SD3Transformer2D

-This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
+This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and [SD3Transformer2DModel], check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.

 The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.

@@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse

 ## SD3Transformer2DLoadersMixin

-[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
+[[autodoc]] loaders.ip_adapter.transformer_sd3.SD3Transformer2DLoadersMixin
   - all
   - _load_ip_adapter_weights
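For context, `SD3Transformer2DLoadersMixin._load_ip_adapter_weights` is the model-level half of the pipeline-level `SD3IPAdapterMixin.load_ip_adapter` shown later in this diff. A hedged end-to-end usage sketch (model ids and the reference-image URL are illustrative, not taken from this PR):

```python
import torch
from diffusers import StableDiffusion3Pipeline
from diffusers.utils import load_image

# Model ids below are illustrative; the loader methods are the ones defined in ip_adapter.py further down.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter", weight_name="ip-adapter.safetensors")
pipe.set_ip_adapter_scale(0.6)

# Argument names follow the usual IP-Adapter pipeline call; the URL is a placeholder.
reference = load_image("https://example.com/reference.png")
image = pipe("a cat in a spacesuit", ip_adapter_image=reference).images[0]
```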
@@ -54,14 +54,14 @@ if is_transformers_available():
 _import_structure = {}

 if is_torch_available():
-    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
-    _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
-    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
+    _import_structure["ip_adapter.transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
+    _import_structure["ip_adapter.transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
+    _import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"]
+    _import_structure["unet.unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
-        _import_structure["single_file"] = ["FromSingleFileMixin"]
-        _import_structure["lora_pipeline"] = [
+        _import_structure["single_file.single_file"] = ["FromSingleFileMixin"]
+        _import_structure["lora.lora_pipeline"] = [
             "AmusedLoraLoaderMixin",
             "StableDiffusionLoraLoaderMixin",
             "SD3LoraLoaderMixin",

@@ -80,7 +80,7 @@ if is_torch_available():
             "HiDreamImageLoraLoaderMixin",
         ]
         _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
-        _import_structure["ip_adapter"] = [
+        _import_structure["ip_adapter.ip_adapter"] = [
             "IPAdapterMixin",
             "FluxIPAdapterMixin",
             "SD3IPAdapterMixin",

@@ -91,19 +91,14 @@ _import_structure["peft"] = ["PeftAdapterMixin"]

 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from .single_file_model import FromOriginalModelMixin
-        from .transformer_flux import FluxTransformer2DLoadersMixin
-        from .transformer_sd3 import SD3Transformer2DLoadersMixin
+        from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin
+        from .single_file import FromOriginalModelMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers

         if is_transformers_available():
-            from .ip_adapter import (
-                FluxIPAdapterMixin,
-                IPAdapterMixin,
-                SD3IPAdapterMixin,
-            )
-            from .lora_pipeline import (
+            from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
+            from .lora import (
                 AmusedLoraLoaderMixin,
                 AuraFlowLoraLoaderMixin,
                 CogVideoXLoraLoaderMixin,

@@ -111,6 +106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
                 FluxLoraLoaderMixin,
                 HiDreamImageLoraLoaderMixin,
                 HunyuanVideoLoraLoaderMixin,
+                LoraBaseMixin,
                 LoraLoaderMixin,
                 LTXVideoLoraLoaderMixin,
                 Lumina2LoraLoaderMixin,
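The keys of `_import_structure` are module paths relative to `diffusers.loaders`, consumed by the `_LazyModule` machinery at the bottom of this `__init__.py`, which is why every key now carries the new subpackage prefix (`ip_adapter.`, `lora.`, `single_file.`, `unet.`). A small check of the resulting public surface (a sketch, assuming a checkout that includes this branch):

```python
# Hedged sketch: attribute access on the lazy loaders module should resolve
# each name from the nested submodule named by the dotted _import_structure key.
import diffusers.loaders as loaders

ip_adapter_mixin = loaders.IPAdapterMixin                  # resolved from "ip_adapter.ip_adapter"
sd_lora_mixin = loaders.StableDiffusionLoraLoaderMixin     # resolved from "lora.lora_pipeline"

# On this branch these are expected to print the new nested module paths.
print(ip_adapter_mixin.__module__)
print(sd_lora_mixin.__module__)
```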
@@ -12,868 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from pathlib import Path
-from typing import Dict, List, Optional, Union
-
-import torch
-import torch.nn.functional as F
-from huggingface_hub.utils import validate_hf_hub_args
-from safetensors import safe_open
-
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
-from ..utils import (
-    USE_PEFT_BACKEND,
-    _get_detailed_type,
-    _get_model_file,
-    _is_valid_type,
-    is_accelerate_available,
-    is_torch_version,
-    is_transformers_available,
-    logging,
-)
-from .unet_loader_utils import _maybe_expand_lora_scales
-
-
-if is_transformers_available():
-    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
-
-    from ..models.attention_processor import (
-        AttnProcessor,
-        AttnProcessor2_0,
-        FluxAttnProcessor2_0,
-        FluxIPAdapterJointAttnProcessor2_0,
-        IPAdapterAttnProcessor,
-        IPAdapterAttnProcessor2_0,
-        IPAdapterXFormersAttnProcessor,
-        JointAttnProcessor2_0,
-        SD3IPAdapterJointAttnProcessor2_0,
-    )
-
-
-logger = logging.get_logger(__name__)
-
-
-class IPAdapterMixin:
-    """Mixin for handling IP Adapters."""
+from ..utils import deprecate
+from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
+
+
+class IPAdapterMixin(IPAdapterMixin):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin` instead."
+        deprecate("diffusers.loaders.ip_adapter.IPAdapterMixin", "0.36", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class FluxIPAdapterMixin(FluxIPAdapterMixin):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `FluxIPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import FluxIPAdapterMixin` instead."
+        deprecate("diffusers.loaders.ip_adapter.FluxIPAdapterMixin", "0.36", deprecation_message)
+        super().__init__(*args, **kwargs)
+
+
+class SD3IPAdapterMixin(SD3IPAdapterMixin):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `SD3IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import SD3IPAdapterMixin` instead."
+        deprecate("diffusers.loaders.ip_adapter.SD3IPAdapterMixin", "0.36", deprecation_message)
+        super().__init__(*args, **kwargs)
[The remaining deleted lines of this hunk are the previous in-place implementations of `IPAdapterMixin` (`load_ip_adapter`, `set_ip_adapter_scale`, `unload_ip_adapter`), `FluxIPAdapterMixin`, and `SD3IPAdapterMixin`. They are removed from this module and relocated, with relative imports adjusted for the deeper package level, to the new file `src/diffusers/loaders/ip_adapter/ip_adapter.py` shown below.]
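The net effect of this hunk is a thin backwards-compatibility layer: the real implementations move out, and the legacy `diffusers.loaders.ip_adapter` module keeps subclasses that warn when they are actually used. A self-contained sketch of that pattern with toy stand-ins (not the diffusers code itself):

```python
import warnings


def deprecate(name, removal_version, message):
    # Stand-in for diffusers.utils.deprecate, which the real shim calls.
    warnings.warn(f"{name}: {message} (removal planned for {removal_version})", FutureWarning)


class _RelocatedIPAdapterMixin:
    # Stands in for the class now defined in loaders/ip_adapter/ip_adapter.py.
    def load_ip_adapter(self, *args, **kwargs):
        ...


class IPAdapterMixin(_RelocatedIPAdapterMixin):
    # Legacy alias kept at the old import path; warns once the mixin is actually constructed.
    def __init__(self, *args, **kwargs):
        deprecate(
            "diffusers.loaders.ip_adapter.IPAdapterMixin",
            "0.36",
            "import it from diffusers.loaders.ip_adapter.ip_adapter instead",
        )
        super().__init__(*args, **kwargs)


IPAdapterMixin()  # emits a FutureWarning, then behaves exactly like the relocated class
```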
src/diffusers/loaders/ip_adapter/__init__.py (new file, 9 lines)

@@ -0,0 +1,9 @@
+from ...utils.import_utils import is_torch_available, is_transformers_available
+
+
+if is_torch_available():
+    from .transformer_flux import FluxTransformer2DLoadersMixin
+    from .transformer_sd3 import SD3Transformer2DLoadersMixin
+
+if is_transformers_available():
+    from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
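Because this `__init__.py` re-exports the relocated classes, the package path and the fully qualified module path should resolve to the same objects. A quick import-surface check, assuming a checkout that includes this branch:

```python
# Hedged check of the re-exports added above.
from diffusers.loaders.ip_adapter import IPAdapterMixin, SD3Transformer2DLoadersMixin
from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin as IPAdapterMixinDirect

assert IPAdapterMixin is IPAdapterMixinDirect
```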
src/diffusers/loaders/ip_adapter/ip_adapter.py (new file, 879 lines)

@@ -0,0 +1,879 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+from typing import Dict, List, Optional, Union
+
+import torch
+import torch.nn.functional as F
+from huggingface_hub.utils import validate_hf_hub_args
+from safetensors import safe_open
+
+from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
+from ...utils import (
+    USE_PEFT_BACKEND,
+    _get_detailed_type,
+    _get_model_file,
+    _is_valid_type,
+    is_accelerate_available,
+    is_torch_version,
+    is_transformers_available,
+    logging,
+)
+from ..unet.unet_loader_utils import _maybe_expand_lora_scales
+
+
+if is_transformers_available():
+    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
+
+    from ...models.attention_processor import (
+        AttnProcessor,
+        AttnProcessor2_0,
+        FluxAttnProcessor2_0,
+        FluxIPAdapterJointAttnProcessor2_0,
+        IPAdapterAttnProcessor,
+        IPAdapterAttnProcessor2_0,
+        IPAdapterXFormersAttnProcessor,
+        JointAttnProcessor2_0,
+        SD3IPAdapterJointAttnProcessor2_0,
+    )
+
+
+logger = logging.get_logger(__name__)
+
+
+class IPAdapterMixin:
+    """Mixin for handling IP Adapters."""
+
+    @validate_hf_hub_args
+    def load_ip_adapter(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
+        subfolder: Union[str, List[str]],
+        weight_name: Union[str, List[str]],
+        image_encoder_folder: Optional[str] = "image_encoder",
+        **kwargs,
+    ):
+        """
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+            subfolder (`str` or `List[str]`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
+            weight_name (`str` or `List[str]`):
+                The name of the weight file to load. If a list is passed, it should have the same length as
+                `subfolder`.
+            image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||||
|
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||||
|
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||||
|
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||||
|
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||||
|
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||||
|
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||||
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||||
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||||
|
is not used.
|
||||||
|
force_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||||
|
won't be downloaded from the Hub.
|
||||||
|
token (`str` or *bool*, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||||
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||||
|
allowed by Git.
|
||||||
|
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||||
|
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||||
|
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||||
|
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||||
|
argument to `True` will raise an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# handle the list inputs for multiple IP Adapters
|
||||||
|
if not isinstance(weight_name, list):
|
||||||
|
weight_name = [weight_name]
|
||||||
|
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||||
|
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||||
|
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||||
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||||
|
|
||||||
|
if not isinstance(subfolder, list):
|
||||||
|
subfolder = [subfolder]
|
||||||
|
if len(subfolder) == 1:
|
||||||
|
subfolder = subfolder * len(weight_name)
|
||||||
|
|
||||||
|
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||||
|
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||||
|
|
||||||
|
if len(weight_name) != len(subfolder):
|
||||||
|
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||||
|
|
||||||
|
# Load the main state dict first.
|
||||||
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", None)
|
||||||
|
token = kwargs.pop("token", None)
|
||||||
|
revision = kwargs.pop("revision", None)
|
||||||
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage and not is_accelerate_available():
|
||||||
|
low_cpu_mem_usage = False
|
||||||
|
logger.warning(
|
||||||
|
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||||
|
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||||
|
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||||
|
" install accelerate\n```\n."
|
||||||
|
)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||||
|
" `low_cpu_mem_usage=False`."
|
||||||
|
)
|
||||||
|
|
||||||
|
user_agent = {
|
||||||
|
"file_type": "attn_procs_weights",
|
||||||
|
"framework": "pytorch",
|
||||||
|
}
|
||||||
|
state_dicts = []
|
||||||
|
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||||
|
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||||
|
):
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
model_file = _get_model_file(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
weights_name=weight_name,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
if weight_name.endswith(".safetensors"):
|
||||||
|
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||||
|
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
if key.startswith("image_proj."):
|
||||||
|
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||||
|
elif key.startswith("ip_adapter."):
|
||||||
|
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict(model_file)
|
||||||
|
else:
|
||||||
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
keys = list(state_dict.keys())
|
||||||
|
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||||
|
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||||
|
|
||||||
|
state_dicts.append(state_dict)
|
||||||
|
|
||||||
|
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||||
|
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||||
|
if image_encoder_folder is not None:
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||||
|
if image_encoder_folder.count("/") == 0:
|
||||||
|
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||||
|
else:
|
||||||
|
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||||
|
|
||||||
|
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
subfolder=image_encoder_subfolder,
|
||||||
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
torch_dtype=self.dtype,
|
||||||
|
).to(self.device)
|
||||||
|
self.register_modules(image_encoder=image_encoder)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||||
|
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
# create feature extractor if it has not been registered to the pipeline yet
|
||||||
|
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||||
|
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||||
|
default_clip_size = 224
|
||||||
|
clip_image_size = (
|
||||||
|
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||||
|
)
|
||||||
|
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||||
|
self.register_modules(feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
# load ip-adapter into unet
|
||||||
|
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||||
|
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||||
|
|
||||||
|
extra_loras = unet._load_ip_adapter_loras(state_dicts)
|
||||||
|
if extra_loras != {}:
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
logger.warning("PEFT backend is required to load these weights.")
|
||||||
|
else:
|
||||||
|
# apply the IP Adapter Face ID LoRA weights
|
||||||
|
peft_config = getattr(unet, "peft_config", {})
|
||||||
|
for k, lora in extra_loras.items():
|
||||||
|
if f"faceid_{k}" not in peft_config:
|
||||||
|
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
|
||||||
|
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
|
||||||
|
|
||||||
|
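# Usage sketch (annotation, not part of this file): load_ip_adapter above broadcasts a
# single repo id and `subfolder` across a list of `weight_name`s, so several adapters can
# be loaded in one call. The repo id and file names below are assumptions for illustration.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",  # one repo id, reused for every weight file in the list
    subfolder="sdxl_models",  # one subfolder, also broadcast
    weight_name=[
        "ip-adapter-plus_sdxl_vit-h.safetensors",
        "ip-adapter-plus-face_sdxl_vit-h.safetensors",
    ],
)
pipeline.set_ip_adapter_scale([0.7, 0.3])  # one scale (or scale config) per adapter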
def set_ip_adapter_scale(self, scale):
|
||||||
|
"""
|
||||||
|
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||||
|
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
# To use original IP-Adapter
|
||||||
|
scale = 1.0
|
||||||
|
pipeline.set_ip_adapter_scale(scale)
|
||||||
|
|
||||||
|
# To use style block only
|
||||||
|
scale = {
|
||||||
|
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||||
|
}
|
||||||
|
pipeline.set_ip_adapter_scale(scale)
|
||||||
|
|
||||||
|
# To use style+layout blocks
|
||||||
|
scale = {
|
||||||
|
"down": {"block_2": [0.0, 1.0]},
|
||||||
|
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||||
|
}
|
||||||
|
pipeline.set_ip_adapter_scale(scale)
|
||||||
|
|
||||||
|
# To use style and layout from 2 reference images
|
||||||
|
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
|
||||||
|
pipeline.set_ip_adapter_scale(scales)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||||
|
if not isinstance(scale, list):
|
||||||
|
scale = [scale]
|
||||||
|
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||||
|
|
||||||
|
for attn_name, attn_processor in unet.attn_processors.items():
|
||||||
|
if isinstance(
|
||||||
|
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||||
|
):
|
||||||
|
if len(scale_configs) != len(attn_processor.scale):
|
||||||
|
raise ValueError(
|
||||||
|
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
||||||
|
)
|
||||||
|
elif len(scale_configs) == 1:
|
||||||
|
scale_configs = scale_configs * len(attn_processor.scale)
|
||||||
|
for i, scale_config in enumerate(scale_configs):
|
||||||
|
if isinstance(scale_config, dict):
|
||||||
|
for k, s in scale_config.items():
|
||||||
|
if attn_name.startswith(k):
|
||||||
|
attn_processor.scale[i] = s
|
||||||
|
else:
|
||||||
|
attn_processor.scale[i] = scale_config
|
||||||
|
|
||||||
|
def unload_ip_adapter(self):
|
||||||
|
"""
|
||||||
|
Unloads the IP Adapter weights
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||||
|
>>> pipeline.unload_ip_adapter()
|
||||||
|
>>> ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# remove CLIP image encoder
|
||||||
|
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||||
|
self.image_encoder = None
|
||||||
|
self.register_to_config(image_encoder=[None, None])
|
||||||
|
|
||||||
|
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||||
|
# the feature_extractor later
|
||||||
|
if not hasattr(self, "safety_checker"):
|
||||||
|
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||||
|
self.feature_extractor = None
|
||||||
|
self.register_to_config(feature_extractor=[None, None])
|
||||||
|
|
||||||
|
# remove hidden encoder
|
||||||
|
self.unet.encoder_hid_proj = None
|
||||||
|
self.unet.config.encoder_hid_dim_type = None
|
||||||
|
|
||||||
|
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
|
||||||
|
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
|
||||||
|
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
|
||||||
|
self.unet.text_encoder_hid_proj = None
|
||||||
|
self.unet.config.encoder_hid_dim_type = "text_proj"
|
||||||
|
|
||||||
|
# restore original Unet attention processors layers
|
||||||
|
attn_procs = {}
|
||||||
|
for name, value in self.unet.attn_processors.items():
|
||||||
|
attn_processor_class = (
|
||||||
|
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||||
|
)
|
||||||
|
attn_procs[name] = (
|
||||||
|
attn_processor_class
|
||||||
|
if isinstance(
|
||||||
|
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||||
|
)
|
||||||
|
else value.__class__()
|
||||||
|
)
|
||||||
|
self.unet.set_attn_processor(attn_procs)
|
||||||
|
|
||||||
|
|
||||||
|
class FluxIPAdapterMixin:
|
||||||
|
"""Mixin for handling Flux IP Adapters."""
|
||||||
|
|
||||||
|
@validate_hf_hub_args
|
||||||
|
def load_ip_adapter(
|
||||||
|
self,
|
||||||
|
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||||
|
weight_name: Union[str, List[str]],
|
||||||
|
subfolder: Optional[Union[str, List[str]]] = "",
|
||||||
|
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||||
|
image_encoder_subfolder: Optional[str] = "",
|
||||||
|
image_encoder_dtype: torch.dtype = torch.float16,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||||
|
the Hub.
|
||||||
|
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||||
|
with [`ModelMixin.save_pretrained`].
|
||||||
|
- A [torch state
|
||||||
|
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||||
|
subfolder (`str` or `List[str]`):
|
||||||
|
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||||
|
list is passed, it should have the same length as `weight_name`.
|
||||||
|
weight_name (`str` or `List[str]`):
The name of the weight file to load. If a list is passed, it should have the same length as
`subfolder`.
|
||||||
|
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||||
|
Can be either:
|
||||||
|
|
||||||
|
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
|
||||||
|
hosted on the Hub.
|
||||||
|
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||||
|
with [`ModelMixin.save_pretrained`].
|
||||||
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||||
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||||
|
is not used.
|
||||||
|
force_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||||
|
won't be downloaded from the Hub.
|
||||||
|
token (`str` or *bool*, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||||
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||||
|
allowed by Git.
|
||||||
|
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||||
|
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||||
|
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||||
|
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||||
|
argument to `True` will raise an error.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# handle the list inputs for multiple IP Adapters
|
||||||
|
if not isinstance(weight_name, list):
|
||||||
|
weight_name = [weight_name]
|
||||||
|
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||||
|
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||||
|
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||||
|
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||||
|
|
||||||
|
if not isinstance(subfolder, list):
|
||||||
|
subfolder = [subfolder]
|
||||||
|
if len(subfolder) == 1:
|
||||||
|
subfolder = subfolder * len(weight_name)
|
||||||
|
|
||||||
|
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||||
|
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||||
|
|
||||||
|
if len(weight_name) != len(subfolder):
|
||||||
|
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||||
|
|
||||||
|
# Load the main state dict first.
|
||||||
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", None)
|
||||||
|
token = kwargs.pop("token", None)
|
||||||
|
revision = kwargs.pop("revision", None)
|
||||||
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage and not is_accelerate_available():
|
||||||
|
low_cpu_mem_usage = False
|
||||||
|
logger.warning(
|
||||||
|
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||||
|
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||||
|
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||||
|
" install accelerate\n```\n."
|
||||||
|
)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||||
|
" `low_cpu_mem_usage=False`."
|
||||||
|
)
|
||||||
|
|
||||||
|
user_agent = {
|
||||||
|
"file_type": "attn_procs_weights",
|
||||||
|
"framework": "pytorch",
|
||||||
|
}
|
||||||
|
state_dicts = []
|
||||||
|
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||||
|
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||||
|
):
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
model_file = _get_model_file(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
weights_name=weight_name,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
if weight_name.endswith(".safetensors"):
|
||||||
|
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||||
|
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||||
|
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
|
||||||
|
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
|
||||||
|
for key in f.keys():
|
||||||
|
if any(key.startswith(prefix) for prefix in image_proj_keys):
|
||||||
|
diffusers_name = ".".join(key.split(".")[1:])
|
||||||
|
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
|
||||||
|
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
|
||||||
|
diffusers_name = (
|
||||||
|
".".join(key.split(".")[1:])
|
||||||
|
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
|
||||||
|
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
|
||||||
|
.replace("processor.", "")
|
||||||
|
)
|
||||||
|
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict(model_file)
|
||||||
|
else:
|
||||||
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
keys = list(state_dict.keys())
|
||||||
|
if keys != ["image_proj", "ip_adapter"]:
|
||||||
|
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||||
|
|
||||||
|
state_dicts.append(state_dict)
|
||||||
|
|
||||||
|
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||||
|
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||||
|
if image_encoder_pretrained_model_name_or_path is not None:
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
|
||||||
|
image_encoder = (
|
||||||
|
CLIPVisionModelWithProjection.from_pretrained(
|
||||||
|
image_encoder_pretrained_model_name_or_path,
|
||||||
|
subfolder=image_encoder_subfolder,
|
||||||
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
dtype=image_encoder_dtype,
|
||||||
|
)
|
||||||
|
.to(self.device)
|
||||||
|
.eval()
|
||||||
|
)
|
||||||
|
self.register_modules(image_encoder=image_encoder)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||||
|
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
# create feature extractor if it has not been registered to the pipeline yet
|
||||||
|
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||||
|
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||||
|
default_clip_size = 224
|
||||||
|
clip_image_size = (
|
||||||
|
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||||
|
)
|
||||||
|
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||||
|
self.register_modules(feature_extractor=feature_extractor)
|
||||||
|
|
||||||
|
# load ip-adapter into transformer
|
||||||
|
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||||
|
|
||||||
|
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||||
|
"""
|
||||||
|
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||||
|
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||||
|
|
||||||
|
A `float` is converted to a list and repeated for the number of blocks and the number of IP adapters. A `List[float]`
must match the number of blocks and is repeated for each IP adapter. A `List[List[float]]` must match the
number of IP adapters, and each inner list must match the number of blocks.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
# To use original IP-Adapter
|
||||||
|
scale = 1.0
|
||||||
|
pipeline.set_ip_adapter_scale(scale)
|
||||||
|
|
||||||
|
|
||||||
|
def LinearStrengthModel(start, finish, size):
|
||||||
|
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
|
||||||
|
|
||||||
|
|
||||||
|
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
|
||||||
|
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
scale_type = Union[int, float]
|
||||||
|
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||||
|
num_layers = self.transformer.config.num_layers
|
||||||
|
|
||||||
|
# Single value for all layers of all IP-Adapters
|
||||||
|
if isinstance(scale, scale_type):
|
||||||
|
scale = [scale for _ in range(num_ip_adapters)]
|
||||||
|
# List of per-layer scales for a single IP-Adapter
|
||||||
|
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||||
|
scale = [scale]
|
||||||
|
# Invalid scale type
|
||||||
|
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||||
|
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||||
|
|
||||||
|
if len(scale) != num_ip_adapters:
|
||||||
|
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
|
||||||
|
|
||||||
|
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
|
||||||
|
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Scalars are transformed to lists with length num_layers
|
||||||
|
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
|
||||||
|
|
||||||
|
# Set scales. zip over scale_configs prevents going into single transformer layers
|
||||||
|
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
|
||||||
|
attn_processor.scale = scale
|
||||||
|
|
||||||
|
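# Usage sketch (annotation, not part of this file): the `List[List[float]]` form gives one
# per-layer list per loaded IP-Adapter. Assumes two adapters are loaded and that
# `transformer.config.num_layers` is 19, as in the docstring example above.
num_layers = pipeline.transformer.config.num_layers
style_scales = [0.9] * num_layers  # first adapter: strong influence on every block
layout_scales = [0.6] * (num_layers // 2) + [0.0] * (num_layers - num_layers // 2)  # second adapter: early blocks only
pipeline.set_ip_adapter_scale([style_scales, layout_scales])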
def unload_ip_adapter(self):
|
||||||
|
"""
|
||||||
|
Unloads the IP Adapter weights
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||||
|
>>> pipeline.unload_ip_adapter()
|
||||||
|
>>> ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# remove CLIP image encoder
|
||||||
|
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||||
|
self.image_encoder = None
|
||||||
|
self.register_to_config(image_encoder=[None, None])
|
||||||
|
|
||||||
|
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||||
|
# the feature_extractor later
|
||||||
|
if not hasattr(self, "safety_checker"):
|
||||||
|
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||||
|
self.feature_extractor = None
|
||||||
|
self.register_to_config(feature_extractor=[None, None])
|
||||||
|
|
||||||
|
# remove hidden encoder
|
||||||
|
self.transformer.encoder_hid_proj = None
|
||||||
|
self.transformer.config.encoder_hid_dim_type = None
|
||||||
|
|
||||||
|
# restore original Transformer attention processors layers
|
||||||
|
attn_procs = {}
|
||||||
|
for name, value in self.transformer.attn_processors.items():
|
||||||
|
attn_processor_class = FluxAttnProcessor2_0()
|
||||||
|
attn_procs[name] = (
|
||||||
|
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
|
||||||
|
)
|
||||||
|
self.transformer.set_attn_processor(attn_procs)
|
||||||
|
|
||||||
|
|
||||||
|
class SD3IPAdapterMixin:
|
||||||
|
"""Mixin for handling StableDiffusion 3 IP Adapters."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_ip_adapter_active(self) -> bool:
|
||||||
|
"""Checks if IP-Adapter is loaded and scale > 0.
|
||||||
|
|
||||||
|
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
|
||||||
|
the image context is irrelevant.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
|
||||||
|
"""
|
||||||
|
scales = [
|
||||||
|
attn_proc.scale
|
||||||
|
for attn_proc in self.transformer.attn_processors.values()
|
||||||
|
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
|
||||||
|
]
|
||||||
|
|
||||||
|
return len(scales) > 0 and any(scale > 0 for scale in scales)
|
||||||
|
|
||||||
|
@validate_hf_hub_args
|
||||||
|
def load_ip_adapter(
|
||||||
|
self,
|
||||||
|
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||||
|
weight_name: str = "ip-adapter.safetensors",
|
||||||
|
subfolder: Optional[str] = None,
|
||||||
|
image_encoder_folder: Optional[str] = "image_encoder",
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
Parameters:
|
||||||
|
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||||
|
Can be either:
|
||||||
|
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||||
|
the Hub.
|
||||||
|
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||||
|
with [`ModelMixin.save_pretrained`].
|
||||||
|
- A [torch state
|
||||||
|
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||||
|
weight_name (`str`, defaults to "ip-adapter.safetensors"):
|
||||||
|
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||||
|
`subfolder`.
|
||||||
|
subfolder (`str`, *optional*):
|
||||||
|
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||||
|
list is passed, it should have the same length as `weight_name`.
|
||||||
|
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||||
|
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||||
|
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||||
|
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||||
|
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||||
|
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||||
|
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||||
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||||
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||||
|
is not used.
|
||||||
|
force_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||||
|
won't be downloaded from the Hub.
|
||||||
|
token (`str` or *bool*, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||||
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||||
|
allowed by Git.
|
||||||
|
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||||
|
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||||
|
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||||
|
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||||
|
argument to `True` will raise an error.
|
||||||
|
"""
|
||||||
|
# Load the main state dict first
|
||||||
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", None)
|
||||||
|
token = kwargs.pop("token", None)
|
||||||
|
revision = kwargs.pop("revision", None)
|
||||||
|
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage and not is_accelerate_available():
|
||||||
|
low_cpu_mem_usage = False
|
||||||
|
logger.warning(
|
||||||
|
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||||
|
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||||
|
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||||
|
" install accelerate\n```\n."
|
||||||
|
)
|
||||||
|
|
||||||
|
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||||
|
" `low_cpu_mem_usage=False`."
|
||||||
|
)
|
||||||
|
|
||||||
|
user_agent = {
|
||||||
|
"file_type": "attn_procs_weights",
|
||||||
|
"framework": "pytorch",
|
||||||
|
}
|
||||||
|
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
model_file = _get_model_file(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
weights_name=weight_name,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
if weight_name.endswith(".safetensors"):
|
||||||
|
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||||
|
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
if key.startswith("image_proj."):
|
||||||
|
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||||
|
elif key.startswith("ip_adapter."):
|
||||||
|
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||||
|
else:
|
||||||
|
state_dict = load_state_dict(model_file)
|
||||||
|
else:
|
||||||
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
keys = list(state_dict.keys())
|
||||||
|
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||||
|
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||||
|
|
||||||
|
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
|
||||||
|
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||||
|
if image_encoder_folder is not None:
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||||
|
if image_encoder_folder.count("/") == 0:
|
||||||
|
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||||
|
else:
|
||||||
|
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||||
|
|
||||||
|
# Commons args for loading image encoder and image processor
|
||||||
|
kwargs = {
|
||||||
|
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||||
|
"cache_dir": cache_dir,
|
||||||
|
"local_files_only": local_files_only,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.register_modules(
|
||||||
|
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
||||||
|
image_encoder=SiglipVisionModel.from_pretrained(
|
||||||
|
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||||
|
).to(self.device),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||||
|
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load IP-Adapter into transformer
|
||||||
|
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||||
|
|
||||||
|
def set_ip_adapter_scale(self, scale: float) -> None:
|
||||||
|
"""
|
||||||
|
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
|
||||||
|
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
|
||||||
|
the model to produce more diverse images, but they may not be as aligned with the image prompt.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||||
|
>>> pipeline.set_ip_adapter_scale(0.6)
|
||||||
|
>>> ...
|
||||||
|
```
|
||||||
|
|
||||||
|
Args:
|
||||||
|
scale (float):
|
||||||
|
IP-Adapter scale to be set.
|
||||||
|
|
||||||
|
"""
|
||||||
|
for attn_processor in self.transformer.attn_processors.values():
|
||||||
|
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
|
||||||
|
attn_processor.scale = scale
|
||||||
|
|
||||||
|
def unload_ip_adapter(self) -> None:
|
||||||
|
"""
|
||||||
|
Unloads the IP Adapter weights.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||||
|
>>> pipeline.unload_ip_adapter()
|
||||||
|
>>> ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
# Remove image encoder
|
||||||
|
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||||
|
self.image_encoder = None
|
||||||
|
self.register_to_config(image_encoder=None)
|
||||||
|
|
||||||
|
# Remove feature extractor
|
||||||
|
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||||
|
self.feature_extractor = None
|
||||||
|
self.register_to_config(feature_extractor=None)
|
||||||
|
|
||||||
|
# Remove image projection
|
||||||
|
self.transformer.image_proj = None
|
||||||
|
|
||||||
|
# Restore original attention processors layers
|
||||||
|
attn_procs = {
|
||||||
|
name: (
|
||||||
|
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
|
||||||
|
)
|
||||||
|
for name, value in self.transformer.attn_processors.items()
|
||||||
|
}
|
||||||
|
self.transformer.set_attn_processor(attn_procs)
|
||||||
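# Usage sketch (annotation, not part of this file): `is_ip_adapter_active` is True only
# while an SD3 IP-Adapter is loaded with a non-zero scale, which makes it a cheap guard
# around image-prompt-specific code paths. Assumes `pipeline` already has the adapter loaded.
pipeline.set_ip_adapter_scale(0.0)
assert not pipeline.is_ip_adapter_active  # scale 0 disables image conditioning

pipeline.set_ip_adapter_scale(0.5)
assert pipeline.is_ip_adapter_active

pipeline.unload_ip_adapter()  # drops image_proj and restores JointAttnProcessor2_0
assert not pipeline.is_ip_adapter_active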
src/diffusers/loaders/ip_adapter/transformer_flux.py (new file)
@@ -0,0 +1,168 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext

from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ...utils import is_accelerate_available, is_torch_version, logging


logger = logging.get_logger(__name__)


class FluxTransformer2DLoadersMixin:
    """
    Load layers into a [`FluxTransformer2DModel`].
    """

    def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        updated_state_dict = {}
        image_projection = None
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

        if "proj.weight" in state_dict:
            # IP-Adapter
            num_image_text_embeds = 4
            if state_dict["proj.weight"].shape[0] == 65536:
                num_image_text_embeds = 16
            clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
            cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds

            with init_context():
                image_projection = ImageProjection(
                    cross_attention_dim=cross_attention_dim,
                    image_embed_dim=clip_embeddings_dim,
                    num_image_text_embeds=num_image_text_embeds,
                )

            for key, value in state_dict.items():
                diffusers_name = key.replace("proj", "image_embeds")
                updated_state_dict[diffusers_name] = value

        if not low_cpu_mem_usage:
            image_projection.load_state_dict(updated_state_dict, strict=True)
        else:
            device_map = {"": self.device}
            load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)

        return image_projection

    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
        from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0

        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        # set ip-adapter cross-attention processors & load state_dict
        attn_procs = {}
        key_id = 0
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
        for name in self.attn_processors.keys():
            if name.startswith("single_transformer_blocks"):
                attn_processor_class = self.attn_processors[name].__class__
                attn_procs[name] = attn_processor_class()
            else:
                cross_attention_dim = self.config.joint_attention_dim
                hidden_size = self.inner_dim
                attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
                num_image_text_embeds = []
                for state_dict in state_dicts:
                    if "proj.weight" in state_dict["image_proj"]:
                        num_image_text_embed = 4
                        if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
                            num_image_text_embed = 16
                        # IP-Adapter
                        num_image_text_embeds += [num_image_text_embed]

                with init_context():
                    attn_procs[name] = attn_processor_class(
                        hidden_size=hidden_size,
                        cross_attention_dim=cross_attention_dim,
                        scale=1.0,
                        num_tokens=num_image_text_embeds,
                        dtype=self.dtype,
                        device=self.device,
                    )

                value_dict = {}
                for i, state_dict in enumerate(state_dicts):
                    value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
                    value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
                    value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
                    value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})

                if not low_cpu_mem_usage:
                    attn_procs[name].load_state_dict(value_dict)
                else:
                    device_map = {"": self.device}
                    dtype = self.dtype
                    load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)

                key_id += 1

        return attn_procs

    def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
        if not isinstance(state_dicts, list):
            state_dicts = [state_dicts]

        self.encoder_hid_proj = None

        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
        self.set_attn_processor(attn_procs)

        image_projection_layers = []
        for state_dict in state_dicts:
            image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
                state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
            )
            image_projection_layers.append(image_projection_layer)

        self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
        self.config.encoder_hid_dim_type = "ip_image_proj"
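For reference, a sketch of the nested state dict that `FluxTransformer2DLoadersMixin._load_ip_adapter_weights` consumes: an `image_proj` group for the projection layer and an `ip_adapter` group keyed by double-stream block index. The tensor shapes below are placeholders chosen to be self-consistent with the shape checks above, not values from a real checkpoint.

```python
# Sketch of the expected input structure (placeholder shapes, not a real checkpoint).
import torch

state_dict = {
    "image_proj": {
        # 4 image tokens unless proj.weight.shape[0] == 65536, in which case 16 are used
        "proj.weight": torch.zeros(4 * 4096, 768),  # (num_tokens * cross_attention_dim, clip_embeddings_dim)
        "proj.bias": torch.zeros(4 * 4096),
    },
    "ip_adapter": {
        # one group of to_k_ip / to_v_ip weights per double-stream block, keyed by index
        "0.to_k_ip.weight": torch.zeros(3072, 4096),
        "0.to_k_ip.bias": torch.zeros(3072),
        "0.to_v_ip.weight": torch.zeros(3072, 4096),
        "0.to_v_ip.bias": torch.zeros(3072),
        # ... "1.to_k_ip.weight", and so on for the remaining blocks
    },
}
# flux_transformer._load_ip_adapter_weights(state_dict)
```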
src/diffusers/loaders/ip_adapter/transformer_sd3.py (new file)
@@ -0,0 +1,170 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||||
|
from ...models.embeddings import IPAdapterTimeImageProjection
|
||||||
|
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||||
|
from ...utils import is_accelerate_available, is_torch_version, logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SD3Transformer2DLoadersMixin:
    """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""

    def _convert_ip_adapter_attn_to_diffusers(
        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
    ) -> Dict:
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        # IP-Adapter cross attention parameters
        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
        timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]

        # Dict where key is transformer layer index, value is attention processor's state dict
        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
        for key, weights in state_dict.items():
            idx, name = key.split(".", maxsplit=1)
            layer_state_dict[int(idx)][name] = weights

        # Create IP-Adapter attention processor & load state_dict
        attn_procs = {}
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
        for idx, name in enumerate(self.attn_processors.keys()):
            with init_context():
                attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
                    hidden_size=hidden_size,
                    ip_hidden_states_dim=ip_hidden_states_dim,
                    head_dim=self.config.attention_head_dim,
                    timesteps_emb_dim=timesteps_emb_dim,
                )

            if not low_cpu_mem_usage:
                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
            else:
                device_map = {"": self.device}
                load_model_dict_into_meta(
                    attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
                )

        return attn_procs

    def _convert_ip_adapter_image_proj_to_diffusers(
        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
    ) -> IPAdapterTimeImageProjection:
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

        # Convert to diffusers
        updated_state_dict = {}
        for key, value in state_dict.items():
            # InstantX/SD3.5-Large-IP-Adapter
            if key.startswith("layers."):
                idx = key.split(".")[1]
                key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
                key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
                key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
                key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
                key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
                key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
                key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
                key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
                key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
            updated_state_dict[key] = value

        # Image projection parameters
        embed_dim = updated_state_dict["proj_in.weight"].shape[1]
        output_dim = updated_state_dict["proj_out.weight"].shape[0]
        hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
        heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
        num_queries = updated_state_dict["latents"].shape[1]
        timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]

        # Image projection
        with init_context():
            image_proj = IPAdapterTimeImageProjection(
                embed_dim=embed_dim,
                output_dim=output_dim,
                hidden_dim=hidden_dim,
                heads=heads,
                num_queries=num_queries,
                timestep_in_dim=timestep_in_dim,
            )

        if not low_cpu_mem_usage:
            image_proj.load_state_dict(updated_state_dict, strict=True)
        else:
            device_map = {"": self.device}
            load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)

        return image_proj

    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
        """Sets IP-Adapter attention processors, image projection, and loads state_dict.

        Args:
            state_dict (`Dict`):
                State dict with keys "ip_adapter", which contains parameters for attention processors, and
                "image_proj", which contains parameters for the image projection net.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
        """

        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
        self.set_attn_processor(attn_procs)

        self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
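For orientation, a minimal usage sketch of the mixin above. This is hypothetical and not part of the diff: the checkpoint name and the pre-extracted `ip_adapter_state_dict` are illustrative, and `_load_ip_adapter_weights` is normally driven by the pipeline-level `SD3IPAdapterMixin.load_ip_adapter` rather than called directly.

```python
# Hypothetical sketch: drive the mixin's entry point on an SD3.5 transformer directly.
import torch
from diffusers import SD3Transformer2DModel

transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", subfolder="transformer", torch_dtype=torch.float16
)
# `ip_adapter_state_dict` must expose the "ip_adapter" and "image_proj" sub-dicts described above,
# e.g. as extracted from an InstantX SD3.5 IP-Adapter checkpoint by the pipeline loader.
transformer._load_ip_adapter_weights(ip_adapter_state_dict, low_cpu_mem_usage=True)
```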
src/diffusers/loaders/lora/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
from ...utils import is_peft_available, is_torch_available, is_transformers_available


if is_torch_available():
    from .lora_base import LoraBaseMixin

    if is_transformers_available():
        from .lora_pipeline import (
            AmusedLoraLoaderMixin,
            AuraFlowLoraLoaderMixin,
            CogVideoXLoraLoaderMixin,
            CogView4LoraLoaderMixin,
            FluxLoraLoaderMixin,
            HiDreamImageLoraLoaderMixin,
            HunyuanVideoLoraLoaderMixin,
            LoraLoaderMixin,
            LTXVideoLoraLoaderMixin,
            Lumina2LoraLoaderMixin,
            Mochi1LoraLoaderMixin,
            SanaLoraLoaderMixin,
            SD3LoraLoaderMixin,
            StableDiffusionLoraLoaderMixin,
            StableDiffusionXLLoraLoaderMixin,
            WanLoraLoaderMixin,
        )
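With the folderization above, the same mixins should resolve from the new `loaders/lora/` package. A minimal sanity-check sketch, assuming the re-exports land exactly as written in this `__init__.py`:

```python
# Sketch only: confirm the folderized package re-exports the public LoRA mixins.
from diffusers.loaders.lora import LoraBaseMixin, StableDiffusionXLLoraLoaderMixin

print(issubclass(StableDiffusionXLLoraLoaderMixin, LoraBaseMixin))  # expected: True
```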
src/diffusers/loaders/lora/lora_base.py (new file, 935 lines)
@@ -0,0 +1,935 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from huggingface_hub import model_info
|
||||||
|
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||||
|
|
||||||
|
from ...models.modeling_utils import ModelMixin, load_state_dict
|
||||||
|
from ...utils import (
|
||||||
|
USE_PEFT_BACKEND,
|
||||||
|
_get_model_file,
|
||||||
|
convert_state_dict_to_diffusers,
|
||||||
|
convert_state_dict_to_peft,
|
||||||
|
delete_adapter_layers,
|
||||||
|
deprecate,
|
||||||
|
get_adapter_name,
|
||||||
|
get_peft_kwargs,
|
||||||
|
is_accelerate_available,
|
||||||
|
is_peft_available,
|
||||||
|
is_peft_version,
|
||||||
|
is_transformers_available,
|
||||||
|
is_transformers_version,
|
||||||
|
logging,
|
||||||
|
recurse_remove_peft_layers,
|
||||||
|
scale_lora_layers,
|
||||||
|
set_adapter_layers,
|
||||||
|
set_weights_and_activate_adapters,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_transformers_available():
|
||||||
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
|
from ...models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||||
|
|
||||||
|
if is_peft_available():
|
||||||
|
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||||
|
|
||||||
|
if is_accelerate_available():
|
||||||
|
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||||
|
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||||
|
|
||||||
|
|
||||||
|
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||||
|
"""
|
||||||
|
Fuses LoRAs for the text encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_encoder (`torch.nn.Module`):
|
||||||
|
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||||
|
attribute.
|
||||||
|
lora_scale (`float`, defaults to 1.0):
|
||||||
|
Controls how much to influence the outputs with the LoRA parameters.
|
||||||
|
safe_fusing (`bool`, defaults to `False`):
|
||||||
|
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||||
|
adapter_names (`List[str]` or `str`):
|
||||||
|
The names of the adapters to use.
|
||||||
|
"""
|
||||||
|
merge_kwargs = {"safe_merge": safe_fusing}
|
||||||
|
|
||||||
|
for module in text_encoder.modules():
|
||||||
|
if isinstance(module, BaseTunerLayer):
|
||||||
|
if lora_scale != 1.0:
|
||||||
|
module.scale_layer(lora_scale)
|
||||||
|
|
||||||
|
# For BC with previous PEFT versions, we need to check the signature
|
||||||
|
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||||
|
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||||
|
if "adapter_names" in supported_merge_kwargs:
|
||||||
|
merge_kwargs["adapter_names"] = adapter_names
|
||||||
|
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||||
|
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||||
|
)
|
||||||
|
|
||||||
|
module.merge(**merge_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def unfuse_text_encoder_lora(text_encoder):
|
||||||
|
"""
|
||||||
|
Unfuses LoRAs for the text encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_encoder (`torch.nn.Module`):
|
||||||
|
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||||
|
attribute.
|
||||||
|
"""
|
||||||
|
for module in text_encoder.modules():
|
||||||
|
if isinstance(module, BaseTunerLayer):
|
||||||
|
module.unmerge()
|
||||||
|
|
||||||
|
|
||||||
|
def set_adapters_for_text_encoder(
|
||||||
|
adapter_names: Union[List[str], str],
|
||||||
|
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||||
|
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Sets the adapter layers for the text encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
adapter_names (`List[str]` or `str`):
|
||||||
|
The names of the adapters to use.
|
||||||
|
text_encoder (`torch.nn.Module`, *optional*):
|
||||||
|
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||||
|
attribute.
|
||||||
|
text_encoder_weights (`List[float]`, *optional*):
|
||||||
|
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||||
|
"""
|
||||||
|
if text_encoder is None:
|
||||||
|
raise ValueError(
|
||||||
|
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
def process_weights(adapter_names, weights):
|
||||||
|
# Expand weights into a list, one entry per adapter
|
||||||
|
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
||||||
|
if not isinstance(weights, list):
|
||||||
|
weights = [weights] * len(adapter_names)
|
||||||
|
|
||||||
|
if len(adapter_names) != len(weights):
|
||||||
|
raise ValueError(
|
||||||
|
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set None values to default of 1.0
|
||||||
|
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
||||||
|
weights = [w if w is not None else 1.0 for w in weights]
|
||||||
|
|
||||||
|
return weights
|
||||||
|
|
||||||
|
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||||
|
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
||||||
|
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
||||||
|
|
||||||
|
|
||||||
|
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||||
|
"""
|
||||||
|
Disables the LoRA layers for the text encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_encoder (`torch.nn.Module`, *optional*):
|
||||||
|
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||||
|
attribute.
|
||||||
|
"""
|
||||||
|
if text_encoder is None:
|
||||||
|
raise ValueError("Text Encoder not found.")
|
||||||
|
set_adapter_layers(text_encoder, enabled=False)
|
||||||
|
|
||||||
|
|
||||||
|
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||||
|
"""
|
||||||
|
Enables the LoRA layers for the text encoder.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text_encoder (`torch.nn.Module`, *optional*):
|
||||||
|
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||||
|
attribute.
|
||||||
|
"""
|
||||||
|
if text_encoder is None:
|
||||||
|
raise ValueError("Text Encoder not found.")
|
||||||
|
set_adapter_layers(text_encoder, enabled=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _remove_text_encoder_monkey_patch(text_encoder):
|
||||||
|
recurse_remove_peft_layers(text_encoder)
|
||||||
|
if getattr(text_encoder, "peft_config", None) is not None:
|
||||||
|
del text_encoder.peft_config
|
||||||
|
text_encoder._hf_peft_config_loaded = None
|
||||||
|
|
||||||
|
|
||||||
|
def _fetch_state_dict(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
weight_name,
|
||||||
|
use_safetensors,
|
||||||
|
local_files_only,
|
||||||
|
cache_dir,
|
||||||
|
force_download,
|
||||||
|
proxies,
|
||||||
|
token,
|
||||||
|
revision,
|
||||||
|
subfolder,
|
||||||
|
user_agent,
|
||||||
|
allow_pickle,
|
||||||
|
):
|
||||||
|
model_file = None
|
||||||
|
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||||
|
# Let's first try to load .safetensors weights
|
||||||
|
if (use_safetensors and weight_name is None) or (
|
||||||
|
weight_name is not None and weight_name.endswith(".safetensors")
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
# Here we're relaxing the loading check to enable more Inference API
|
||||||
|
# friendliness where sometimes, it's not at all possible to automatically
|
||||||
|
# determine `weight_name`.
|
||||||
|
if weight_name is None:
|
||||||
|
weight_name = _best_guess_weight_name(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
file_extension=".safetensors",
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
model_file = _get_model_file(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||||
|
except (IOError, safetensors.SafetensorError) as e:
|
||||||
|
if not allow_pickle:
|
||||||
|
raise e
|
||||||
|
# try loading non-safetensors weights
|
||||||
|
model_file = None
|
||||||
|
pass
|
||||||
|
|
||||||
|
if model_file is None:
|
||||||
|
if weight_name is None:
|
||||||
|
weight_name = _best_guess_weight_name(
|
||||||
|
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||||
|
)
|
||||||
|
model_file = _get_model_file(
|
||||||
|
pretrained_model_name_or_path_or_dict,
|
||||||
|
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=revision,
|
||||||
|
subfolder=subfolder,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
state_dict = load_state_dict(model_file)
|
||||||
|
else:
|
||||||
|
state_dict = pretrained_model_name_or_path_or_dict
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _best_guess_weight_name(
|
||||||
|
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||||
|
):
|
||||||
|
if local_files_only or HF_HUB_OFFLINE:
|
||||||
|
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||||
|
|
||||||
|
targeted_files = []
|
||||||
|
|
||||||
|
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||||
|
return
|
||||||
|
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||||
|
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
||||||
|
else:
|
||||||
|
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||||
|
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||||
|
if len(targeted_files) == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||||
|
# "optimizer" does not correspond to a LoRA checkpoint
|
||||||
|
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||||
|
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||||
|
targeted_files = list(
|
||||||
|
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||||
|
)
|
||||||
|
|
||||||
|
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||||
|
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||||
|
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||||
|
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||||
|
|
||||||
|
if len(targeted_files) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||||
|
)
|
||||||
|
weight_name = targeted_files[0]
|
||||||
|
return weight_name
|
||||||
|
|
||||||
|
|
||||||
|
def _load_lora_into_text_encoder(
|
||||||
|
state_dict,
|
||||||
|
network_alphas,
|
||||||
|
text_encoder,
|
||||||
|
prefix=None,
|
||||||
|
lora_scale=1.0,
|
||||||
|
text_encoder_name="text_encoder",
|
||||||
|
adapter_name=None,
|
||||||
|
_pipeline=None,
|
||||||
|
low_cpu_mem_usage=False,
|
||||||
|
hotswap: bool = False,
|
||||||
|
):
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError("PEFT backend is required for this method.")
|
||||||
|
|
||||||
|
peft_kwargs = {}
|
||||||
|
if low_cpu_mem_usage:
|
||||||
|
if not is_peft_version(">=", "0.13.1"):
|
||||||
|
raise ValueError(
|
||||||
|
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||||
|
)
|
||||||
|
if not is_transformers_version(">", "4.45.2"):
|
||||||
|
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||||
|
# https://github.com/huggingface/transformers/pull/33725/
|
||||||
|
raise ValueError(
|
||||||
|
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||||
|
)
|
||||||
|
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||||
|
|
||||||
|
from peft import LoraConfig
|
||||||
|
|
||||||
|
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||||
|
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||||
|
# their prefixes.
|
||||||
|
prefix = text_encoder_name if prefix is None else prefix
|
||||||
|
|
||||||
|
# Safe prefix to check with.
|
||||||
|
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
||||||
|
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
||||||
|
|
||||||
|
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||||
|
if prefix is not None:
|
||||||
|
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||||
|
|
||||||
|
if len(state_dict) > 0:
|
||||||
|
logger.info(f"Loading {prefix}.")
|
||||||
|
rank = {}
|
||||||
|
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||||
|
|
||||||
|
# convert state dict
|
||||||
|
state_dict = convert_state_dict_to_peft(state_dict)
|
||||||
|
|
||||||
|
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||||
|
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||||
|
rank_key = f"{name}.{module}.lora_B.weight"
|
||||||
|
if rank_key not in state_dict:
|
||||||
|
continue
|
||||||
|
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||||
|
|
||||||
|
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||||
|
for module in ("fc1", "fc2"):
|
||||||
|
rank_key = f"{name}.{module}.lora_B.weight"
|
||||||
|
if rank_key not in state_dict:
|
||||||
|
continue
|
||||||
|
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||||
|
|
||||||
|
if network_alphas is not None:
|
||||||
|
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||||
|
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||||
|
|
||||||
|
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||||
|
|
||||||
|
if "use_dora" in lora_config_kwargs:
|
||||||
|
if lora_config_kwargs["use_dora"]:
|
||||||
|
if is_peft_version("<", "0.9.0"):
|
||||||
|
raise ValueError(
|
||||||
|
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if is_peft_version("<", "0.9.0"):
|
||||||
|
lora_config_kwargs.pop("use_dora")
|
||||||
|
|
||||||
|
if "lora_bias" in lora_config_kwargs:
|
||||||
|
if lora_config_kwargs["lora_bias"]:
|
||||||
|
if is_peft_version("<=", "0.13.2"):
|
||||||
|
raise ValueError(
|
||||||
|
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if is_peft_version("<=", "0.13.2"):
|
||||||
|
lora_config_kwargs.pop("lora_bias")
|
||||||
|
|
||||||
|
lora_config = LoraConfig(**lora_config_kwargs)
|
||||||
|
|
||||||
|
# adapter_name
|
||||||
|
if adapter_name is None:
|
||||||
|
adapter_name = get_adapter_name(text_encoder)
|
||||||
|
|
||||||
|
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||||
|
|
||||||
|
# inject LoRA layers and load the state dict
|
||||||
|
# in transformers we automatically check whether the adapter name is already in use or not
|
||||||
|
text_encoder.load_adapter(
|
||||||
|
adapter_name=adapter_name,
|
||||||
|
adapter_state_dict=state_dict,
|
||||||
|
peft_config=lora_config,
|
||||||
|
**peft_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# scale LoRA layers with `lora_scale`
|
||||||
|
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||||
|
|
||||||
|
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||||
|
|
||||||
|
# Offload back.
|
||||||
|
if is_model_cpu_offload:
|
||||||
|
_pipeline.enable_model_cpu_offload()
|
||||||
|
elif is_sequential_cpu_offload:
|
||||||
|
_pipeline.enable_sequential_cpu_offload()
|
||||||
|
# Unsafe code />
|
||||||
|
|
||||||
|
if prefix is not None and not state_dict:
|
||||||
|
logger.warning(
|
||||||
|
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||||
|
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||||
|
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||||
|
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||||
|
"https://github.com/huggingface/diffusers/issues/new"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _func_optionally_disable_offloading(_pipeline):
|
||||||
|
is_model_cpu_offload = False
|
||||||
|
is_sequential_cpu_offload = False
|
||||||
|
|
||||||
|
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||||
|
for _, component in _pipeline.components.items():
|
||||||
|
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||||
|
if not is_model_cpu_offload:
|
||||||
|
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||||
|
if not is_sequential_cpu_offload:
|
||||||
|
is_sequential_cpu_offload = (
|
||||||
|
isinstance(component._hf_hook, AlignDevicesHook)
|
||||||
|
or hasattr(component._hf_hook, "hooks")
|
||||||
|
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||||
|
)
|
||||||
|
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||||
|
|
||||||
|
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraBaseMixin:
|
||||||
|
"""Utility class for handling LoRAs."""
|
||||||
|
|
||||||
|
_lora_loadable_modules = []
|
||||||
|
num_fused_loras = 0
|
||||||
|
|
||||||
|
def load_lora_weights(self, **kwargs):
|
||||||
|
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def save_lora_weights(cls, **kwargs):
|
||||||
|
raise NotImplementedError("`save_lora_weights()` not implemented.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def lora_state_dict(cls, **kwargs):
|
||||||
|
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _optionally_disable_offloading(cls, _pipeline):
|
||||||
|
"""
|
||||||
|
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
_pipeline (`DiffusionPipeline`):
|
||||||
|
The pipeline to disable offloading for.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
tuple:
|
||||||
|
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||||
|
"""
|
||||||
|
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _fetch_state_dict(cls, *args, **kwargs):
|
||||||
|
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
||||||
|
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
||||||
|
return _fetch_state_dict(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _best_guess_weight_name(cls, *args, **kwargs):
|
||||||
|
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
||||||
|
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||||
|
return _best_guess_weight_name(*args, **kwargs)
|
||||||
|
|
||||||
|
def unload_lora_weights(self):
|
||||||
|
"""
|
||||||
|
Unloads the LoRA parameters.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
||||||
|
>>> pipeline.unload_lora_weights()
|
||||||
|
>>> ...
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError("PEFT backend is required for this method.")
|
||||||
|
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component, None)
|
||||||
|
if model is not None:
|
||||||
|
if issubclass(model.__class__, ModelMixin):
|
||||||
|
model.unload_lora()
|
||||||
|
elif issubclass(model.__class__, PreTrainedModel):
|
||||||
|
_remove_text_encoder_monkey_patch(model)
|
||||||
|
|
||||||
|
def fuse_lora(
|
||||||
|
self,
|
||||||
|
components: List[str] = [],
|
||||||
|
lora_scale: float = 1.0,
|
||||||
|
safe_fusing: bool = False,
|
||||||
|
adapter_names: Optional[List[str]] = None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
r"""
|
||||||
|
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
This is an experimental API.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||||
|
lora_scale (`float`, defaults to 1.0):
|
||||||
|
Controls how much to influence the outputs with the LoRA parameters.
|
||||||
|
safe_fusing (`bool`, defaults to `False`):
|
||||||
|
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||||
|
adapter_names (`List[str]`, *optional*):
|
||||||
|
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```py
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
import torch
|
||||||
|
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||||
|
).to("cuda")
|
||||||
|
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||||
|
pipeline.fuse_lora(lora_scale=0.7)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if "fuse_unet" in kwargs:
|
||||||
|
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
|
||||||
|
deprecate(
|
||||||
|
"fuse_unet",
|
||||||
|
"1.0.0",
|
||||||
|
depr_message,
|
||||||
|
)
|
||||||
|
if "fuse_transformer" in kwargs:
|
||||||
|
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
|
||||||
|
deprecate(
|
||||||
|
"fuse_transformer",
|
||||||
|
"1.0.0",
|
||||||
|
depr_message,
|
||||||
|
)
|
||||||
|
if "fuse_text_encoder" in kwargs:
|
||||||
|
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
|
||||||
|
deprecate(
|
||||||
|
"fuse_text_encoder",
|
||||||
|
"1.0.0",
|
||||||
|
depr_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(components) == 0:
|
||||||
|
raise ValueError("`components` cannot be an empty list.")
|
||||||
|
|
||||||
|
for fuse_component in components:
|
||||||
|
if fuse_component not in self._lora_loadable_modules:
|
||||||
|
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||||
|
|
||||||
|
model = getattr(self, fuse_component, None)
|
||||||
|
if model is not None:
|
||||||
|
# check if diffusers model
|
||||||
|
if issubclass(model.__class__, ModelMixin):
|
||||||
|
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||||
|
# handle transformers models.
|
||||||
|
if issubclass(model.__class__, PreTrainedModel):
|
||||||
|
fuse_text_encoder_lora(
|
||||||
|
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||||
|
)
|
||||||
|
|
||||||
|
self.num_fused_loras += 1
|
||||||
|
|
||||||
|
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
||||||
|
r"""
|
||||||
|
Reverses the effect of
|
||||||
|
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||||
|
|
||||||
|
<Tip warning={true}>
|
||||||
|
|
||||||
|
This is an experimental API.
|
||||||
|
|
||||||
|
</Tip>
|
||||||
|
|
||||||
|
Args:
|
||||||
|
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||||
|
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||||
|
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||||
|
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||||
|
LoRA parameters then it won't have any effect.
|
||||||
|
"""
|
||||||
|
if "unfuse_unet" in kwargs:
|
||||||
|
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
|
||||||
|
deprecate(
|
||||||
|
"unfuse_unet",
|
||||||
|
"1.0.0",
|
||||||
|
depr_message,
|
||||||
|
)
|
||||||
|
if "unfuse_transformer" in kwargs:
|
||||||
|
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
|
||||||
|
deprecate(
|
||||||
|
"unfuse_transformer",
|
||||||
|
"1.0.0",
|
||||||
|
depr_message,
|
||||||
|
)
|
||||||
|
if "unfuse_text_encoder" in kwargs:
|
||||||
|
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
|
||||||
|
deprecate(
|
||||||
|
"unfuse_text_encoder",
|
||||||
|
"1.0.0",
|
||||||
|
depr_message,
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(components) == 0:
|
||||||
|
raise ValueError("`components` cannot be an empty list.")
|
||||||
|
|
||||||
|
for fuse_component in components:
|
||||||
|
if fuse_component not in self._lora_loadable_modules:
|
||||||
|
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||||
|
|
||||||
|
model = getattr(self, fuse_component, None)
|
||||||
|
if model is not None:
|
||||||
|
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, BaseTunerLayer):
|
||||||
|
module.unmerge()
|
||||||
|
|
||||||
|
self.num_fused_loras -= 1
|
||||||
|
|
||||||
|
def set_adapters(
|
||||||
|
self,
|
||||||
|
adapter_names: Union[List[str], str],
|
||||||
|
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||||
|
):
|
||||||
|
if isinstance(adapter_weights, dict):
|
||||||
|
components_passed = set(adapter_weights.keys())
|
||||||
|
lora_components = set(self._lora_loadable_modules)
|
||||||
|
|
||||||
|
invalid_components = sorted(components_passed - lora_components)
|
||||||
|
if invalid_components:
|
||||||
|
logger.warning(
|
||||||
|
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
|
||||||
|
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
|
||||||
|
"to the invalid components will be removed and ignored."
|
||||||
|
)
|
||||||
|
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
|
||||||
|
|
||||||
|
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||||
|
adapter_weights = copy.deepcopy(adapter_weights)
|
||||||
|
|
||||||
|
# Expand weights into a list, one entry per adapter
|
||||||
|
if not isinstance(adapter_weights, list):
|
||||||
|
adapter_weights = [adapter_weights] * len(adapter_names)
|
||||||
|
|
||||||
|
if len(adapter_names) != len(adapter_weights):
|
||||||
|
raise ValueError(
|
||||||
|
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||||
|
# eg ["adapter1", "adapter2"]
|
||||||
|
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||||
|
missing_adapters = set(adapter_names) - all_adapters
|
||||||
|
if len(missing_adapters) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||||
|
invert_list_adapters = {
|
||||||
|
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||||
|
for adapter in all_adapters
|
||||||
|
}
|
||||||
|
|
||||||
|
# Decompose weights into weights for denoiser and text encoders.
|
||||||
|
_component_adapter_weights = {}
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component)
|
||||||
|
|
||||||
|
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||||
|
if isinstance(weights, dict):
|
||||||
|
component_adapter_weights = weights.pop(component, None)
|
||||||
|
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
|
||||||
|
logger.warning(
|
||||||
|
(
|
||||||
|
f"Lora weight dict for adapter '{adapter_name}' contains {component},"
|
||||||
|
f"but this will be ignored because {adapter_name} does not contain weights for {component}."
|
||||||
|
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
component_adapter_weights = weights
|
||||||
|
|
||||||
|
_component_adapter_weights.setdefault(component, [])
|
||||||
|
_component_adapter_weights[component].append(component_adapter_weights)
|
||||||
|
|
||||||
|
if issubclass(model.__class__, ModelMixin):
|
||||||
|
model.set_adapters(adapter_names, _component_adapter_weights[component])
|
||||||
|
elif issubclass(model.__class__, PreTrainedModel):
|
||||||
|
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
|
||||||
|
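For clarity, a hedged sketch of how the branch above resolves per-component weights when `set_adapters` is called from a pipeline; the adapter names are placeholders and assume both LoRAs were previously loaded with `load_lora_weights`.

```python
# Illustrative only, not part of the original file; "toy" and "pixel" are placeholder adapter names.
pipe.set_adapters(["toy", "pixel"], adapter_weights=[0.8, 0.5])
# A dict scopes a weight to specific components; components missing from the dict fall back to the default weight.
pipe.set_adapters(["toy", "pixel"], adapter_weights=[{"unet": 0.8, "text_encoder": 0.4}, 0.5])
```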
|
||||||
|
def disable_lora(self):
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError("PEFT backend is required for this method.")
|
||||||
|
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component, None)
|
||||||
|
if model is not None:
|
||||||
|
if issubclass(model.__class__, ModelMixin):
|
||||||
|
model.disable_lora()
|
||||||
|
elif issubclass(model.__class__, PreTrainedModel):
|
||||||
|
disable_lora_for_text_encoder(model)
|
||||||
|
|
||||||
|
def enable_lora(self):
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError("PEFT backend is required for this method.")
|
||||||
|
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component, None)
|
||||||
|
if model is not None:
|
||||||
|
if issubclass(model.__class__, ModelMixin):
|
||||||
|
model.enable_lora()
|
||||||
|
elif issubclass(model.__class__, PreTrainedModel):
|
||||||
|
enable_lora_for_text_encoder(model)
|
||||||
|
|
||||||
|
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
||||||
|
adapter_names (`Union[List[str], str]`):
|
||||||
|
The names of the adapter to delete. Can be a single string or a list of strings
|
||||||
|
"""
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError("PEFT backend is required for this method.")
|
||||||
|
|
||||||
|
if isinstance(adapter_names, str):
|
||||||
|
adapter_names = [adapter_names]
|
||||||
|
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component, None)
|
||||||
|
if model is not None:
|
||||||
|
if issubclass(model.__class__, ModelMixin):
|
||||||
|
model.delete_adapters(adapter_names)
|
||||||
|
elif issubclass(model.__class__, PreTrainedModel):
|
||||||
|
for adapter_name in adapter_names:
|
||||||
|
delete_adapter_layers(model, adapter_name)
|
||||||
|
|
||||||
|
def get_active_adapters(self) -> List[str]:
|
||||||
|
"""
|
||||||
|
Gets the list of the current active adapters.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from diffusers import DiffusionPipeline
|
||||||
|
|
||||||
|
pipeline = DiffusionPipeline.from_pretrained(
|
||||||
|
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||||
|
).to("cuda")
|
||||||
|
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||||
|
pipeline.get_active_adapters()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError(
|
||||||
|
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||||
|
)
|
||||||
|
|
||||||
|
active_adapters = []
|
||||||
|
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component, None)
|
||||||
|
if model is not None and issubclass(model.__class__, ModelMixin):
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, BaseTunerLayer):
|
||||||
|
active_adapters = module.active_adapters
|
||||||
|
break
|
||||||
|
|
||||||
|
return active_adapters
|
||||||
|
|
||||||
|
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||||
|
"""
|
||||||
|
Gets the current list of all available adapters in the pipeline.
|
||||||
|
"""
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError(
|
||||||
|
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||||
|
)
|
||||||
|
|
||||||
|
set_adapters = {}
|
||||||
|
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component, None)
|
||||||
|
if (
|
||||||
|
model is not None
|
||||||
|
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
|
||||||
|
and hasattr(model, "peft_config")
|
||||||
|
):
|
||||||
|
set_adapters[component] = list(model.peft_config.keys())
|
||||||
|
|
||||||
|
return set_adapters
|
||||||
|
|
||||||
|
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||||
|
"""
|
||||||
|
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||||
|
you want to load multiple adapters and free some GPU memory.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
adapter_names (`List[str]`):
|
||||||
|
List of adapters to move to the target device.
|
||||||
|
device (`Union[torch.device, str, int]`):
|
||||||
|
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||||
|
"""
|
||||||
|
if not USE_PEFT_BACKEND:
|
||||||
|
raise ValueError("PEFT backend is required for this method.")
|
||||||
|
|
||||||
|
for component in self._lora_loadable_modules:
|
||||||
|
model = getattr(self, component, None)
|
||||||
|
if model is not None:
|
||||||
|
for module in model.modules():
|
||||||
|
if isinstance(module, BaseTunerLayer):
|
||||||
|
for adapter_name in adapter_names:
|
||||||
|
module.lora_A[adapter_name].to(device)
|
||||||
|
module.lora_B[adapter_name].to(device)
|
||||||
|
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||||
|
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
|
||||||
|
if adapter_name in module.lora_magnitude_vector:
|
||||||
|
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
||||||
|
adapter_name
|
||||||
|
].to(device)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pack_weights(layers, prefix):
|
||||||
|
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||||
|
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||||
|
return layers_state_dict
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def write_lora_layers(
|
||||||
|
state_dict: Dict[str, torch.Tensor],
|
||||||
|
save_directory: str,
|
||||||
|
is_main_process: bool,
|
||||||
|
weight_name: str,
|
||||||
|
save_function: Callable,
|
||||||
|
safe_serialization: bool,
|
||||||
|
):
|
||||||
|
if os.path.isfile(save_directory):
|
||||||
|
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||||
|
return
|
||||||
|
|
||||||
|
if save_function is None:
|
||||||
|
if safe_serialization:
|
||||||
|
|
||||||
|
def save_function(weights, filename):
|
||||||
|
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||||
|
|
||||||
|
else:
|
||||||
|
save_function = torch.save
|
||||||
|
|
||||||
|
os.makedirs(save_directory, exist_ok=True)
|
||||||
|
|
||||||
|
if weight_name is None:
|
||||||
|
if safe_serialization:
|
||||||
|
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||||
|
else:
|
||||||
|
weight_name = LORA_WEIGHT_NAME
|
||||||
|
|
||||||
|
save_path = Path(save_directory, weight_name).as_posix()
|
||||||
|
save_function(state_dict, save_path)
|
||||||
|
logger.info(f"Model weights saved in {save_path}")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lora_scale(self) -> float:
|
||||||
|
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||||
|
# if _lora_scale has not been set, return 1
|
||||||
|
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||||
|
|
||||||
|
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||||
|
"""Enables the possibility to hotswap LoRA adapters.
|
||||||
|
|
||||||
|
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||||
|
the loaded adapters differ.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target_rank (`int`):
|
||||||
|
The highest rank among all the adapters that will be loaded.
|
||||||
|
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||||
|
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||||
|
options are:
|
||||||
|
- "error" (default): raise an error
|
||||||
|
- "warn": issue a warning
|
||||||
|
- "ignore": do nothing
|
||||||
|
"""
|
||||||
|
for key, component in self.components.items():
|
||||||
|
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||||
|
component.enable_lora_hotswap(**kwargs)
|
||||||
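Taken together, `set_lora_device` and `enable_lora_hotswap` support a workflow like the following hedged sketch; the pipeline and adapter names are illustrative, and `target_rank` must cover every adapter that will be hotswapped:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Opt in to hotswapping before loading/compiling; see the docstring above for `check_compiled`.
pipe.enable_lora_hotswap(target_rank=64)
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Park the adapter on CPU to free VRAM without unloading it, then bring it back when needed.
pipe.set_lora_device(["pixel"], device="cpu")
pipe.set_lora_device(["pixel"], device="cuda")
```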
@@ -17,7 +17,7 @@ from typing import List

 import torch

-from ..utils import is_peft_version, logging, state_dict_all_zero
+from ...utils import is_peft_version, logging, state_dict_all_zero


 logger = logging.get_logger(__name__)
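The only change in this hunk is the relative import gaining a dot because the module now lives one level deeper under `loaders/lora/`. A quick, hypothetical sanity check; the exact module name is an assumption inferred from the surrounding diff:

```python
# Hypothetical check that the folderized module resolves and sits under the new package.
import importlib

mod = importlib.import_module("diffusers.loaders.lora.lora_conversion_utils")
print(mod.__package__)  # expected: "diffusers.loaders.lora"
```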
src/diffusers/loaders/lora/lora_pipeline.py (new file, 5686 lines)
File diff suppressed because it is too large.
@@ -12,924 +12,66 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import copy
|
|
||||||
import inspect
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Callable, Dict, List, Optional, Union
|
|
||||||
|
|
||||||
import safetensors
|
from ..utils import deprecate
|
||||||
import torch
|
from .lora.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin # noqa: F401
|
||||||
import torch.nn as nn
|
|
||||||
from huggingface_hub import model_info
|
|
||||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
|
||||||
|
|
||||||
from ..models.modeling_utils import ModelMixin, load_state_dict
|
|
||||||
from ..utils import (
|
|
||||||
USE_PEFT_BACKEND,
|
|
||||||
_get_model_file,
|
|
||||||
convert_state_dict_to_diffusers,
|
|
||||||
convert_state_dict_to_peft,
|
|
||||||
delete_adapter_layers,
|
|
||||||
deprecate,
|
|
||||||
get_adapter_name,
|
|
||||||
get_peft_kwargs,
|
|
||||||
is_accelerate_available,
|
|
||||||
is_peft_available,
|
|
||||||
is_peft_version,
|
|
||||||
is_transformers_available,
|
|
||||||
is_transformers_version,
|
|
||||||
logging,
|
|
||||||
recurse_remove_peft_layers,
|
|
||||||
scale_lora_layers,
|
|
||||||
set_adapter_layers,
|
|
||||||
set_weights_and_activate_adapters,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_transformers_available():
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
|
||||||
|
|
||||||
if is_peft_available():
|
|
||||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
|
||||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
|
||||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
|
||||||
|
|
||||||
|
|
||||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||||
"""
|
from .lora.lora_base import fuse_text_encoder_lora
|
||||||
Fuses LoRAs for the text encoder.
|
|
||||||
|
|
||||||
Args:
|
deprecation_message = "Importing `fuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import fuse_text_encoder_lora` instead."
|
||||||
text_encoder (`torch.nn.Module`):
|
deprecate("diffusers.loaders.lora_base.fuse_text_encoder_lora", "0.36", deprecation_message)
|
||||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
|
||||||
attribute.
|
|
||||||
lora_scale (`float`, defaults to 1.0):
|
|
||||||
Controls how much to influence the outputs with the LoRA parameters.
|
|
||||||
safe_fusing (`bool`, defaults to `False`):
|
|
||||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
|
||||||
adapter_names (`List[str]` or `str`):
|
|
||||||
The names of the adapters to use.
|
|
||||||
"""
|
|
||||||
merge_kwargs = {"safe_merge": safe_fusing}
|
|
||||||
|
|
||||||
for module in text_encoder.modules():
|
return fuse_text_encoder_lora(
|
||||||
if isinstance(module, BaseTunerLayer):
|
text_encoder, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||||
if lora_scale != 1.0:
|
)
|
||||||
module.scale_layer(lora_scale)
|
|
||||||
|
|
||||||
# For BC with previous PEFT versions, we need to check the signature
|
|
||||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
|
||||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
|
||||||
if "adapter_names" in supported_merge_kwargs:
|
|
||||||
merge_kwargs["adapter_names"] = adapter_names
|
|
||||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
|
||||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
|
||||||
)
|
|
||||||
|
|
||||||
module.merge(**merge_kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def unfuse_text_encoder_lora(text_encoder):
|
def unfuse_text_encoder_lora(text_encoder):
|
||||||
"""
|
from .lora.lora_base import unfuse_text_encoder_lora
|
||||||
Unfuses LoRAs for the text encoder.
|
|
||||||
|
|
||||||
Args:
|
deprecation_message = "Importing `unfuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import unfuse_text_encoder_lora` instead."
|
||||||
text_encoder (`torch.nn.Module`):
|
deprecate("diffusers.loaders.lora_base.unfuse_text_encoder_lora", "0.36", deprecation_message)
|
||||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
|
||||||
attribute.
|
return unfuse_text_encoder_lora(text_encoder)
|
||||||
"""
|
|
||||||
for module in text_encoder.modules():
|
|
||||||
if isinstance(module, BaseTunerLayer):
|
|
||||||
module.unmerge()
|
|
||||||
|
|
||||||
|
|
||||||
def set_adapters_for_text_encoder(
|
def set_adapters_for_text_encoder(
|
||||||
adapter_names: Union[List[str], str],
|
adapter_names,
|
||||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
text_encoder=None,
|
||||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
text_encoder_weights=None,
|
||||||
):
|
):
|
||||||
"""
|
from .lora.lora_base import set_adapters_for_text_encoder
|
||||||
Sets the adapter layers for the text encoder.
|
|
||||||
|
|
||||||
Args:
|
deprecation_message = "Importing `set_adapters_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import set_adapters_for_text_encoder` instead."
|
||||||
adapter_names (`List[str]` or `str`):
|
deprecate("diffusers.loaders.lora_base.set_adapters_for_text_encoder", "0.36", deprecation_message)
|
||||||
The names of the adapters to use.
|
|
||||||
text_encoder (`torch.nn.Module`, *optional*):
|
|
||||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
|
||||||
attribute.
|
|
||||||
text_encoder_weights (`List[float]`, *optional*):
|
|
||||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
|
||||||
"""
|
|
||||||
if text_encoder is None:
|
|
||||||
raise ValueError(
|
|
||||||
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
|
||||||
)
|
|
||||||
|
|
||||||
def process_weights(adapter_names, weights):
|
return set_adapters_for_text_encoder(
|
||||||
# Expand weights into a list, one entry per adapter
|
adapter_names=adapter_names, text_encoder=text_encoder, text_encoder_weights=text_encoder_weights
|
||||||
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
|
||||||
if not isinstance(weights, list):
|
|
||||||
weights = [weights] * len(adapter_names)
|
|
||||||
|
|
||||||
if len(adapter_names) != len(weights):
|
|
||||||
raise ValueError(
|
|
||||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set None values to default of 1.0
|
|
||||||
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
|
||||||
weights = [w if w is not None else 1.0 for w in weights]
|
|
||||||
|
|
||||||
return weights
|
|
||||||
|
|
||||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
|
||||||
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
|
||||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
|
||||||
|
|
||||||
|
|
||||||
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
|
||||||
"""
|
|
||||||
Disables the LoRA layers for the text encoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text_encoder (`torch.nn.Module`, *optional*):
|
|
||||||
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
|
||||||
attribute.
|
|
||||||
"""
|
|
||||||
if text_encoder is None:
|
|
||||||
raise ValueError("Text Encoder not found.")
|
|
||||||
set_adapter_layers(text_encoder, enabled=False)
|
|
||||||
|
|
||||||
|
|
||||||
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
|
||||||
"""
|
|
||||||
Enables the LoRA layers for the text encoder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text_encoder (`torch.nn.Module`, *optional*):
|
|
||||||
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
|
||||||
attribute.
|
|
||||||
"""
|
|
||||||
if text_encoder is None:
|
|
||||||
raise ValueError("Text Encoder not found.")
|
|
||||||
set_adapter_layers(text_encoder, enabled=True)
|
|
||||||
|
|
||||||
|
|
||||||
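The replacement implementations in this diff all follow one pattern: keep the old import location working, emit a deprecation warning, and forward to the implementation at its new, folderized path. A generic, self-contained sketch of that pattern (hypothetical names, stdlib `warnings` instead of the diffusers `deprecate` helper):

```python
import warnings


def _do_work_new_location(x):
    # Stand-in for the function after it moved into a subpackage.
    return x * 2


def do_work(x):
    # Legacy entry point kept only for backward compatibility.
    warnings.warn(
        "Importing `do_work` from this module is deprecated; import it from its new location instead.",
        FutureWarning,
        stacklevel=2,
    )
    return _do_work_new_location(x)


assert do_work(3) == 6
```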
def _remove_text_encoder_monkey_patch(text_encoder):
|
|
||||||
recurse_remove_peft_layers(text_encoder)
|
|
||||||
if getattr(text_encoder, "peft_config", None) is not None:
|
|
||||||
del text_encoder.peft_config
|
|
||||||
text_encoder._hf_peft_config_loaded = None
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_state_dict(
|
|
||||||
pretrained_model_name_or_path_or_dict,
|
|
||||||
weight_name,
|
|
||||||
use_safetensors,
|
|
||||||
local_files_only,
|
|
||||||
cache_dir,
|
|
||||||
force_download,
|
|
||||||
proxies,
|
|
||||||
token,
|
|
||||||
revision,
|
|
||||||
subfolder,
|
|
||||||
user_agent,
|
|
||||||
allow_pickle,
|
|
||||||
):
|
|
||||||
model_file = None
|
|
||||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
|
||||||
# Let's first try to load .safetensors weights
|
|
||||||
if (use_safetensors and weight_name is None) or (
|
|
||||||
weight_name is not None and weight_name.endswith(".safetensors")
|
|
||||||
):
|
|
||||||
try:
|
|
||||||
# Here we're relaxing the loading check to enable more Inference API
|
|
||||||
# friendliness where sometimes, it's not at all possible to automatically
|
|
||||||
# determine `weight_name`.
|
|
||||||
if weight_name is None:
|
|
||||||
weight_name = _best_guess_weight_name(
|
|
||||||
pretrained_model_name_or_path_or_dict,
|
|
||||||
file_extension=".safetensors",
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
model_file = _get_model_file(
|
|
||||||
pretrained_model_name_or_path_or_dict,
|
|
||||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
token=token,
|
|
||||||
revision=revision,
|
|
||||||
subfolder=subfolder,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
|
||||||
except (IOError, safetensors.SafetensorError) as e:
|
|
||||||
if not allow_pickle:
|
|
||||||
raise e
|
|
||||||
# try loading non-safetensors weights
|
|
||||||
model_file = None
|
|
||||||
pass
|
|
||||||
|
|
||||||
if model_file is None:
|
|
||||||
if weight_name is None:
|
|
||||||
weight_name = _best_guess_weight_name(
|
|
||||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
|
||||||
)
|
|
||||||
model_file = _get_model_file(
|
|
||||||
pretrained_model_name_or_path_or_dict,
|
|
||||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
token=token,
|
|
||||||
revision=revision,
|
|
||||||
subfolder=subfolder,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
state_dict = load_state_dict(model_file)
|
|
||||||
else:
|
|
||||||
state_dict = pretrained_model_name_or_path_or_dict
|
|
||||||
|
|
||||||
return state_dict
|
|
||||||
|
|
||||||
|
|
||||||
def _best_guess_weight_name(
|
|
||||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
|
||||||
):
|
|
||||||
if local_files_only or HF_HUB_OFFLINE:
|
|
||||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
|
||||||
|
|
||||||
targeted_files = []
|
|
||||||
|
|
||||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
|
||||||
return
|
|
||||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
|
||||||
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
|
||||||
else:
|
|
||||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
|
||||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
|
||||||
if len(targeted_files) == 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
|
||||||
# "optimizer" does not correspond to a LoRA checkpoint
|
|
||||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
|
||||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
|
||||||
targeted_files = list(
|
|
||||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
|
||||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
|
||||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
|
||||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
|
||||||
|
|
||||||
if len(targeted_files) > 1:
|
def disable_lora_for_text_encoder(text_encoder=None):
|
||||||
raise ValueError(
|
from .lora.lora_base import disable_lora_for_text_encoder
|
||||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
|
||||||
)
|
|
||||||
weight_name = targeted_files[0]
|
|
||||||
return weight_name
|
|
||||||
|
|
||||||
|
deprecation_message = "Importing `disable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import disable_lora_for_text_encoder` instead."
|
||||||
|
deprecate("diffusers.loaders.lora_base.disable_lora_for_text_encoder", "0.36", deprecation_message)
|
||||||
|
|
||||||
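`_best_guess_weight_name` above filters candidate files by extension, drops anything that looks like a scheduler, optimizer, or training checkpoint artifact, and prefers the canonical LoRA file names. A rough, local-directory-only sketch of the same filtering (assumed constant values, not the library function):

```python
import os
from typing import Optional

LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"


def guess_weight_name(path: str, file_extension: str = ".safetensors") -> Optional[str]:
    if not os.path.isdir(path):
        return None
    candidates = [f for f in os.listdir(path) if f.endswith(file_extension)]
    # Drop files that clearly are not LoRA checkpoints.
    blocked = ("scheduler", "optimizer", "checkpoint")
    candidates = [f for f in candidates if not any(b in f for b in blocked)]
    # Prefer the canonical file names when present.
    for canonical in (LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE):
        exact = [f for f in candidates if f.endswith(canonical)]
        if exact:
            candidates = exact
            break
    if len(candidates) > 1:
        raise ValueError(f"Multiple weight files found in {path}; pass `weight_name` explicitly.")
    return candidates[0] if candidates else None
```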
def _load_lora_into_text_encoder(
|
return disable_lora_for_text_encoder(text_encoder=text_encoder)
|
||||||
state_dict,
|
|
||||||
network_alphas,
|
|
||||||
text_encoder,
|
|
||||||
prefix=None,
|
|
||||||
lora_scale=1.0,
|
|
||||||
text_encoder_name="text_encoder",
|
|
||||||
adapter_name=None,
|
|
||||||
_pipeline=None,
|
|
||||||
low_cpu_mem_usage=False,
|
|
||||||
hotswap: bool = False,
|
|
||||||
):
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError("PEFT backend is required for this method.")
|
|
||||||
|
|
||||||
peft_kwargs = {}
|
|
||||||
if low_cpu_mem_usage:
|
|
||||||
if not is_peft_version(">=", "0.13.1"):
|
|
||||||
raise ValueError(
|
|
||||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
|
||||||
)
|
|
||||||
if not is_transformers_version(">", "4.45.2"):
|
|
||||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
|
||||||
# https://github.com/huggingface/transformers/pull/33725/
|
|
||||||
raise ValueError(
|
|
||||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
|
||||||
)
|
|
||||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
|
||||||
|
|
||||||
from peft import LoraConfig
|
def enable_lora_for_text_encoder(text_encoder=None):
|
||||||
|
from .lora.lora_base import enable_lora_for_text_encoder
|
||||||
|
|
||||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
deprecation_message = "Importing `enable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import enable_lora_for_text_encoder` instead."
|
||||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
deprecate("diffusers.loaders.lora_base.enable_lora_for_text_encoder", "0.36", deprecation_message)
|
||||||
# their prefixes.
|
|
||||||
prefix = text_encoder_name if prefix is None else prefix
|
|
||||||
|
|
||||||
# Safe prefix to check with.
|
return enable_lora_for_text_encoder(text_encoder=text_encoder)
|
||||||
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
|
||||||
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
|
||||||
|
|
||||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
|
||||||
if prefix is not None:
|
|
||||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
|
||||||
|
|
||||||
if len(state_dict) > 0:
|
class LoraBaseMixin(LoraBaseMixin):
|
||||||
logger.info(f"Loading {prefix}.")
|
def __init__(self, *args, **kwargs):
|
||||||
rank = {}
|
deprecation_message = "Importing `LoraBaseMixin` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import LoraBaseMixin` instead."
|
||||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
deprecate("diffusers.loaders.lora_base.LoraBaseMixin", "0.36", deprecation_message)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
# convert state dict
|
|
||||||
state_dict = convert_state_dict_to_peft(state_dict)
|
|
||||||
|
|
||||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
|
||||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
|
||||||
rank_key = f"{name}.{module}.lora_B.weight"
|
|
||||||
if rank_key not in state_dict:
|
|
||||||
continue
|
|
||||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
|
||||||
|
|
||||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
|
||||||
for module in ("fc1", "fc2"):
|
|
||||||
rank_key = f"{name}.{module}.lora_B.weight"
|
|
||||||
if rank_key not in state_dict:
|
|
||||||
continue
|
|
||||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
|
||||||
|
|
||||||
if network_alphas is not None:
|
|
||||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
|
||||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
|
||||||
|
|
||||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
|
||||||
|
|
||||||
if "use_dora" in lora_config_kwargs:
|
|
||||||
if lora_config_kwargs["use_dora"]:
|
|
||||||
if is_peft_version("<", "0.9.0"):
|
|
||||||
raise ValueError(
|
|
||||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if is_peft_version("<", "0.9.0"):
|
|
||||||
lora_config_kwargs.pop("use_dora")
|
|
||||||
|
|
||||||
if "lora_bias" in lora_config_kwargs:
|
|
||||||
if lora_config_kwargs["lora_bias"]:
|
|
||||||
if is_peft_version("<=", "0.13.2"):
|
|
||||||
raise ValueError(
|
|
||||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if is_peft_version("<=", "0.13.2"):
|
|
||||||
lora_config_kwargs.pop("lora_bias")
|
|
||||||
|
|
||||||
lora_config = LoraConfig(**lora_config_kwargs)
|
|
||||||
|
|
||||||
# adapter_name
|
|
||||||
if adapter_name is None:
|
|
||||||
adapter_name = get_adapter_name(text_encoder)
|
|
||||||
|
|
||||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
|
||||||
|
|
||||||
# inject LoRA layers and load the state dict
|
|
||||||
# in transformers we automatically check whether the adapter name is already in use or not
|
|
||||||
text_encoder.load_adapter(
|
|
||||||
adapter_name=adapter_name,
|
|
||||||
adapter_state_dict=state_dict,
|
|
||||||
peft_config=lora_config,
|
|
||||||
**peft_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
# scale LoRA layers with `lora_scale`
|
|
||||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
|
||||||
|
|
||||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
|
||||||
|
|
||||||
# Offload back.
|
|
||||||
if is_model_cpu_offload:
|
|
||||||
_pipeline.enable_model_cpu_offload()
|
|
||||||
elif is_sequential_cpu_offload:
|
|
||||||
_pipeline.enable_sequential_cpu_offload()
|
|
||||||
# Unsafe code />
|
|
||||||
|
|
||||||
if prefix is not None and not state_dict:
|
|
||||||
logger.warning(
|
|
||||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
|
||||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
|
||||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
|
||||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
|
||||||
"https://github.com/huggingface/diffusers/issues/new"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
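`_load_lora_into_text_encoder` above reads the rank of each adapted layer off the second dimension of its `lora_B` weight, since `lora_B` has shape `(out_features, rank)`. A toy check of that convention:

```python
import torch

rank, d_in, d_out = 4, 768, 768
state_dict = {
    "text_model.encoder.layers.0.self_attn.q_proj.lora_A.weight": torch.zeros(rank, d_in),
    "text_model.encoder.layers.0.self_attn.q_proj.lora_B.weight": torch.zeros(d_out, rank),
}

# Collect the rank per module by inspecting the lora_B shapes.
ranks = {k: v.shape[1] for k, v in state_dict.items() if k.endswith("lora_B.weight")}
assert set(ranks.values()) == {4}
```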
def _func_optionally_disable_offloading(_pipeline):
|
|
||||||
is_model_cpu_offload = False
|
|
||||||
is_sequential_cpu_offload = False
|
|
||||||
|
|
||||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
|
||||||
for _, component in _pipeline.components.items():
|
|
||||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
|
||||||
if not is_model_cpu_offload:
|
|
||||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
|
||||||
if not is_sequential_cpu_offload:
|
|
||||||
is_sequential_cpu_offload = (
|
|
||||||
isinstance(component._hf_hook, AlignDevicesHook)
|
|
||||||
or hasattr(component._hf_hook, "hooks")
|
|
||||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
|
||||||
)
|
|
||||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
|
||||||
|
|
||||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
|
||||||
|
|
||||||
|
|
||||||
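In practice this helper is what lets LoRA loading coexist with CPU offloading: the accelerate hooks are stripped before the adapter is injected and the same offloading mode is re-enabled afterwards. A hedged usage sketch (the LoRA repo name is a placeholder):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# load_lora_weights() detects the accelerate hooks, removes them, loads the adapter,
# and then re-applies model CPU offloading, as the helper above does.
pipe.load_lora_weights("some-user/some-sd15-lora", weight_name="lora.safetensors")  # placeholder repo
```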
class LoraBaseMixin:
|
|
||||||
"""Utility class for handling LoRAs."""
|
|
||||||
|
|
||||||
_lora_loadable_modules = []
|
|
||||||
num_fused_loras = 0
|
|
||||||
|
|
||||||
def load_lora_weights(self, **kwargs):
|
|
||||||
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def save_lora_weights(cls, **kwargs):
|
|
||||||
raise NotImplementedError("`save_lora_weights()` not implemented.")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def lora_state_dict(cls, **kwargs):
|
|
||||||
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _optionally_disable_offloading(cls, _pipeline):
|
|
||||||
"""
|
|
||||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
_pipeline (`DiffusionPipeline`):
|
|
||||||
The pipeline to disable offloading for.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple:
|
|
||||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
|
||||||
"""
|
|
||||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _fetch_state_dict(cls, *args, **kwargs):
|
|
||||||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
|
||||||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
|
||||||
return _fetch_state_dict(*args, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _best_guess_weight_name(cls, *args, **kwargs):
|
|
||||||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
|
||||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
|
||||||
return _best_guess_weight_name(*args, **kwargs)
|
|
||||||
|
|
||||||
def unload_lora_weights(self):
|
|
||||||
"""
|
|
||||||
Unloads the LoRA parameters.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
```python
|
|
||||||
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
|
||||||
>>> pipeline.unload_lora_weights()
|
|
||||||
>>> ...
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError("PEFT backend is required for this method.")
|
|
||||||
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component, None)
|
|
||||||
if model is not None:
|
|
||||||
if issubclass(model.__class__, ModelMixin):
|
|
||||||
model.unload_lora()
|
|
||||||
elif issubclass(model.__class__, PreTrainedModel):
|
|
||||||
_remove_text_encoder_monkey_patch(model)
|
|
||||||
|
|
||||||
def fuse_lora(
|
|
||||||
self,
|
|
||||||
components: List[str] = [],
|
|
||||||
lora_scale: float = 1.0,
|
|
||||||
safe_fusing: bool = False,
|
|
||||||
adapter_names: Optional[List[str]] = None,
|
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
r"""
|
|
||||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
This is an experimental API.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Args:
|
|
||||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
|
||||||
lora_scale (`float`, defaults to 1.0):
|
|
||||||
Controls how much to influence the outputs with the LoRA parameters.
|
|
||||||
safe_fusing (`bool`, defaults to `False`):
|
|
||||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
|
||||||
adapter_names (`List[str]`, *optional*):
|
|
||||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```py
|
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
import torch
|
|
||||||
|
|
||||||
pipeline = DiffusionPipeline.from_pretrained(
|
|
||||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
|
||||||
).to("cuda")
|
|
||||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
|
||||||
pipeline.fuse_lora(lora_scale=0.7)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if "fuse_unet" in kwargs:
|
|
||||||
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
|
|
||||||
deprecate(
|
|
||||||
"fuse_unet",
|
|
||||||
"1.0.0",
|
|
||||||
depr_message,
|
|
||||||
)
|
|
||||||
if "fuse_transformer" in kwargs:
|
|
||||||
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
|
|
||||||
deprecate(
|
|
||||||
"fuse_transformer",
|
|
||||||
"1.0.0",
|
|
||||||
depr_message,
|
|
||||||
)
|
|
||||||
if "fuse_text_encoder" in kwargs:
|
|
||||||
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
|
|
||||||
deprecate(
|
|
||||||
"fuse_text_encoder",
|
|
||||||
"1.0.0",
|
|
||||||
depr_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(components) == 0:
|
|
||||||
raise ValueError("`components` cannot be an empty list.")
|
|
||||||
|
|
||||||
for fuse_component in components:
|
|
||||||
if fuse_component not in self._lora_loadable_modules:
|
|
||||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
|
||||||
|
|
||||||
model = getattr(self, fuse_component, None)
|
|
||||||
if model is not None:
|
|
||||||
# check if diffusers model
|
|
||||||
if issubclass(model.__class__, ModelMixin):
|
|
||||||
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
|
||||||
# handle transformers models.
|
|
||||||
if issubclass(model.__class__, PreTrainedModel):
|
|
||||||
fuse_text_encoder_lora(
|
|
||||||
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
|
||||||
)
|
|
||||||
|
|
||||||
self.num_fused_loras += 1
|
|
||||||
|
|
||||||
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
|
||||||
r"""
|
|
||||||
Reverses the effect of
|
|
||||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
|
||||||
|
|
||||||
<Tip warning={true}>
|
|
||||||
|
|
||||||
This is an experimental API.
|
|
||||||
|
|
||||||
</Tip>
|
|
||||||
|
|
||||||
Args:
|
|
||||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
|
||||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
|
||||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
|
||||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
|
||||||
LoRA parameters then it won't have any effect.
|
|
||||||
"""
|
|
||||||
if "unfuse_unet" in kwargs:
|
|
||||||
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
|
|
||||||
deprecate(
|
|
||||||
"unfuse_unet",
|
|
||||||
"1.0.0",
|
|
||||||
depr_message,
|
|
||||||
)
|
|
||||||
if "unfuse_transformer" in kwargs:
|
|
||||||
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
|
|
||||||
deprecate(
|
|
||||||
"unfuse_transformer",
|
|
||||||
"1.0.0",
|
|
||||||
depr_message,
|
|
||||||
)
|
|
||||||
if "unfuse_text_encoder" in kwargs:
|
|
||||||
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
|
|
||||||
deprecate(
|
|
||||||
"unfuse_text_encoder",
|
|
||||||
"1.0.0",
|
|
||||||
depr_message,
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(components) == 0:
|
|
||||||
raise ValueError("`components` cannot be an empty list.")
|
|
||||||
|
|
||||||
for fuse_component in components:
|
|
||||||
if fuse_component not in self._lora_loadable_modules:
|
|
||||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
|
||||||
|
|
||||||
model = getattr(self, fuse_component, None)
|
|
||||||
if model is not None:
|
|
||||||
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, BaseTunerLayer):
|
|
||||||
module.unmerge()
|
|
||||||
|
|
||||||
self.num_fused_loras -= 1
|
|
||||||
|
|
||||||
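`unfuse_lora` has no usage example in its docstring; a hedged round-trip sketch with the same checkpoint used in the `fuse_lora` example above:

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Merge the LoRA deltas into the base weights for faster inference...
pipeline.fuse_lora(lora_scale=0.7)

# ...and unmerge them again when the original weights are needed.
pipeline.unfuse_lora(components=["unet", "text_encoder", "text_encoder_2"])
```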
def set_adapters(
|
|
||||||
self,
|
|
||||||
adapter_names: Union[List[str], str],
|
|
||||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
|
||||||
):
|
|
||||||
if isinstance(adapter_weights, dict):
|
|
||||||
components_passed = set(adapter_weights.keys())
|
|
||||||
lora_components = set(self._lora_loadable_modules)
|
|
||||||
|
|
||||||
invalid_components = sorted(components_passed - lora_components)
|
|
||||||
if invalid_components:
|
|
||||||
logger.warning(
|
|
||||||
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
|
|
||||||
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
|
|
||||||
"to the invalid components will be removed and ignored."
|
|
||||||
)
|
|
||||||
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
|
|
||||||
|
|
||||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
|
||||||
adapter_weights = copy.deepcopy(adapter_weights)
|
|
||||||
|
|
||||||
# Expand weights into a list, one entry per adapter
|
|
||||||
if not isinstance(adapter_weights, list):
|
|
||||||
adapter_weights = [adapter_weights] * len(adapter_names)
|
|
||||||
|
|
||||||
if len(adapter_names) != len(adapter_weights):
|
|
||||||
raise ValueError(
|
|
||||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
|
||||||
# eg ["adapter1", "adapter2"]
|
|
||||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
|
||||||
missing_adapters = set(adapter_names) - all_adapters
|
|
||||||
if len(missing_adapters) > 0:
|
|
||||||
raise ValueError(
|
|
||||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
|
||||||
)
|
|
||||||
|
|
||||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
|
||||||
invert_list_adapters = {
|
|
||||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
|
||||||
for adapter in all_adapters
|
|
||||||
}
|
|
||||||
|
|
||||||
# Decompose weights into weights for denoiser and text encoders.
|
|
||||||
_component_adapter_weights = {}
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component)
|
|
||||||
|
|
||||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
|
||||||
if isinstance(weights, dict):
|
|
||||||
component_adapter_weights = weights.pop(component, None)
|
|
||||||
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
|
|
||||||
logger.warning(
|
|
||||||
(
|
|
||||||
f"Lora weight dict for adapter '{adapter_name}' contains {component},"
|
|
||||||
f"but this will be ignored because {adapter_name} does not contain weights for {component}."
|
|
||||||
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
component_adapter_weights = weights
|
|
||||||
|
|
||||||
_component_adapter_weights.setdefault(component, [])
|
|
||||||
_component_adapter_weights[component].append(component_adapter_weights)
|
|
||||||
|
|
||||||
if issubclass(model.__class__, ModelMixin):
|
|
||||||
model.set_adapters(adapter_names, _component_adapter_weights[component])
|
|
||||||
elif issubclass(model.__class__, PreTrainedModel):
|
|
||||||
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
|
|
||||||
|
|
||||||
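`set_adapters` accepts either a flat list of scales or, per adapter, a dict keyed by component; the decomposition loop above routes each entry to the matching model. A hedged sketch using the adapters from the docstring examples in this class:

```python
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# One scalar per adapter...
pipeline.set_adapters(["toy", "pixel"], adapter_weights=[0.8, 0.6])

# ...or a per-component dict; entries for components an adapter does not touch are ignored with a warning.
pipeline.set_adapters(["toy", "pixel"], adapter_weights=[{"unet": 0.8, "text_encoder": 0.5}, 0.6])
```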
def disable_lora(self):
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError("PEFT backend is required for this method.")
|
|
||||||
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component, None)
|
|
||||||
if model is not None:
|
|
||||||
if issubclass(model.__class__, ModelMixin):
|
|
||||||
model.disable_lora()
|
|
||||||
elif issubclass(model.__class__, PreTrainedModel):
|
|
||||||
disable_lora_for_text_encoder(model)
|
|
||||||
|
|
||||||
def enable_lora(self):
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError("PEFT backend is required for this method.")
|
|
||||||
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component, None)
|
|
||||||
if model is not None:
|
|
||||||
if issubclass(model.__class__, ModelMixin):
|
|
||||||
model.enable_lora()
|
|
||||||
elif issubclass(model.__class__, PreTrainedModel):
|
|
||||||
enable_lora_for_text_encoder(model)
|
|
||||||
|
|
||||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
|
||||||
adapter_names (`Union[List[str], str]`):
|
|
||||||
The names of the adapter to delete. Can be a single string or a list of strings
|
|
||||||
"""
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError("PEFT backend is required for this method.")
|
|
||||||
|
|
||||||
if isinstance(adapter_names, str):
|
|
||||||
adapter_names = [adapter_names]
|
|
||||||
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component, None)
|
|
||||||
if model is not None:
|
|
||||||
if issubclass(model.__class__, ModelMixin):
|
|
||||||
model.delete_adapters(adapter_names)
|
|
||||||
elif issubclass(model.__class__, PreTrainedModel):
|
|
||||||
for adapter_name in adapter_names:
|
|
||||||
delete_adapter_layers(model, adapter_name)
|
|
||||||
|
|
||||||
def get_active_adapters(self) -> List[str]:
|
|
||||||
"""
|
|
||||||
Gets the list of the current active adapters.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```python
|
|
||||||
from diffusers import DiffusionPipeline
|
|
||||||
|
|
||||||
pipeline = DiffusionPipeline.from_pretrained(
|
|
||||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
|
||||||
).to("cuda")
|
|
||||||
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
|
||||||
pipeline.get_active_adapters()
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError(
|
|
||||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
|
||||||
)
|
|
||||||
|
|
||||||
active_adapters = []
|
|
||||||
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component, None)
|
|
||||||
if model is not None and issubclass(model.__class__, ModelMixin):
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, BaseTunerLayer):
|
|
||||||
active_adapters = module.active_adapters
|
|
||||||
break
|
|
||||||
|
|
||||||
return active_adapters
|
|
||||||
|
|
||||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
|
||||||
"""
|
|
||||||
Gets the current list of all available adapters in the pipeline.
|
|
||||||
"""
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError(
|
|
||||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
|
||||||
)
|
|
||||||
|
|
||||||
set_adapters = {}
|
|
||||||
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component, None)
|
|
||||||
if (
|
|
||||||
model is not None
|
|
||||||
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
|
|
||||||
and hasattr(model, "peft_config")
|
|
||||||
):
|
|
||||||
set_adapters[component] = list(model.peft_config.keys())
|
|
||||||
|
|
||||||
return set_adapters
|
|
||||||
|
|
||||||
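A quick way to see the difference between the two queries above: `get_list_adapters` reports every adapter attached to each component, while `get_active_adapters` reports only the currently active set. Continuing from the `pipeline` with the `toy` and `pixel` adapters loaded in the previous sketch:

```python
print(pipeline.get_list_adapters())
# e.g. {"unet": ["toy", "pixel"], "text_encoder": ["toy", "pixel"]}

pipeline.set_adapters(["pixel"])
print(pipeline.get_active_adapters())
# e.g. ["pixel"]
```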
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
|
||||||
"""
|
|
||||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
|
||||||
you want to load multiple adapters and free some GPU memory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
adapter_names (`List[str]`):
|
|
||||||
List of adapters to send device to.
|
|
||||||
device (`Union[torch.device, str, int]`):
|
|
||||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
|
||||||
"""
|
|
||||||
if not USE_PEFT_BACKEND:
|
|
||||||
raise ValueError("PEFT backend is required for this method.")
|
|
||||||
|
|
||||||
for component in self._lora_loadable_modules:
|
|
||||||
model = getattr(self, component, None)
|
|
||||||
if model is not None:
|
|
||||||
for module in model.modules():
|
|
||||||
if isinstance(module, BaseTunerLayer):
|
|
||||||
for adapter_name in adapter_names:
|
|
||||||
module.lora_A[adapter_name].to(device)
|
|
||||||
module.lora_B[adapter_name].to(device)
|
|
||||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
|
||||||
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
|
|
||||||
if adapter_name in module.lora_magnitude_vector:
|
|
||||||
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
|
||||||
adapter_name
|
|
||||||
].to(device)
|
|
||||||
|
|
||||||
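`set_lora_device` moves only the adapter tensors, which is useful for parking unused LoRAs on the CPU without touching the base weights. Hedged sketch, again assuming the `pipeline` with `toy` and `pixel` adapters from the earlier sketch:

```python
# Keep only the adapter currently in use on the GPU.
pipeline.set_adapters(["pixel"])
pipeline.set_lora_device(["toy"], device="cpu")

# Bring it back later when it is needed again.
pipeline.set_lora_device(["toy"], device="cuda")
```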
@staticmethod
|
|
||||||
def pack_weights(layers, prefix):
|
|
||||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
|
||||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
|
||||||
return layers_state_dict
|
|
||||||
|
|
||||||
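`pack_weights` simply namespaces a component's state dict under its prefix so that several components can share one flat LoRA file. The equivalent operation written out by hand:

```python
import torch

layers = {"lora_A.weight": torch.zeros(4, 8), "lora_B.weight": torch.zeros(8, 4)}
packed = {f"unet.{name}": tensor for name, tensor in layers.items()}
assert sorted(packed) == ["unet.lora_A.weight", "unet.lora_B.weight"]
```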
@staticmethod
|
|
||||||
def write_lora_layers(
|
|
||||||
state_dict: Dict[str, torch.Tensor],
|
|
||||||
save_directory: str,
|
|
||||||
is_main_process: bool,
|
|
||||||
weight_name: str,
|
|
||||||
save_function: Callable,
|
|
||||||
safe_serialization: bool,
|
|
||||||
):
|
|
||||||
if os.path.isfile(save_directory):
|
|
||||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
|
||||||
return
|
|
||||||
|
|
||||||
if save_function is None:
|
|
||||||
if safe_serialization:
|
|
||||||
|
|
||||||
def save_function(weights, filename):
|
|
||||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
|
||||||
|
|
||||||
else:
|
|
||||||
save_function = torch.save
|
|
||||||
|
|
||||||
os.makedirs(save_directory, exist_ok=True)
|
|
||||||
|
|
||||||
if weight_name is None:
|
|
||||||
if safe_serialization:
|
|
||||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
|
||||||
else:
|
|
||||||
weight_name = LORA_WEIGHT_NAME
|
|
||||||
|
|
||||||
save_path = Path(save_directory, weight_name).as_posix()
|
|
||||||
save_function(state_dict, save_path)
|
|
||||||
logger.info(f"Model weights saved in {save_path}")
|
|
||||||
|
|
||||||
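`write_lora_layers` above defaults to safetensors serialization with a `{"format": "pt"}` metadata entry and the canonical file name. The same default, reduced to a toy state dict and a local output directory:

```python
import os

import safetensors.torch
import torch

state_dict = {"unet.lora_A.weight": torch.zeros(4, 8), "unet.lora_B.weight": torch.zeros(8, 4)}
os.makedirs("my_lora", exist_ok=True)
safetensors.torch.save_file(state_dict, "my_lora/pytorch_lora_weights.safetensors", metadata={"format": "pt"})
```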
@property
|
|
||||||
def lora_scale(self) -> float:
|
|
||||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
|
||||||
# if _lora_scale has not been set, return 1
|
|
||||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
|
||||||
|
|
||||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
|
||||||
"""Enables the possibility to hotswap LoRA adapters.
|
|
||||||
|
|
||||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
|
||||||
the loaded adapters differ.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_rank (`int`):
|
|
||||||
The highest rank among all the adapters that will be loaded.
|
|
||||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
|
||||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
|
||||||
options are:
|
|
||||||
- "error" (default): raise an error
|
|
||||||
- "warn": issue a warning
|
|
||||||
- "ignore": do nothing
|
|
||||||
"""
|
|
||||||
for key, component in self.components.items():
|
|
||||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
|
||||||
component.enable_lora_hotswap(**kwargs)
|
|
||||||
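A hedged sketch of the hotswap flow that `enable_lora_hotswap` prepares for: reserve capacity for the largest rank before compiling, then swap checkpoints into the same adapter slot without retriggering compilation. The repo names below are placeholders.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

pipe.enable_lora_hotswap(target_rank=64)
pipe.load_lora_weights("user/lora-one", adapter_name="default_0")  # placeholder repo
pipe.unet = torch.compile(pipe.unet)
image = pipe("a prompt").images[0]

# Swap a different LoRA into the same adapter slot without recompiling.
pipe.load_lora_weights("user/lora-two", adapter_name="default_0", hotswap=True)  # placeholder repo
image = pipe("a prompt").images[0]
```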
File diff suppressed because it is too large.
@@ -35,8 +35,8 @@ from ..utils import (
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
-from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
-from .unet_loader_utils import _maybe_expand_lora_scales
+from .lora.lora_base import _fetch_state_dict, _func_optionally_disable_offloading
+from .unet.unet_loader_utils import _maybe_expand_lora_scales


 logger = logging.get_logger(__name__)
@@ -99,7 +99,7 @@ class PeftAdapterMixin:
     _prepare_lora_hotswap_kwargs: Optional[dict] = None

     @classmethod
-    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
+    # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
     def _optionally_disable_offloading(cls, _pipeline):
         """
         Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
@@ -11,42 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import importlib
-import inspect
-import os
-
-import torch
-from huggingface_hub import snapshot_download
-from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
-from packaging import version
-from typing_extensions import Self
-
-from ..utils import deprecate, is_transformers_available, logging
-from .single_file_utils import (
-    SingleFileComponentError,
-    _is_legacy_scheduler_kwargs,
-    _is_model_weights_in_cached_folder,
-    _legacy_load_clip_tokenizer,
-    _legacy_load_safety_checker,
-    _legacy_load_scheduler,
-    create_diffusers_clip_model_from_ldm,
-    create_diffusers_t5_model_from_checkpoint,
-    fetch_diffusers_config,
-    fetch_original_config,
-    is_clip_model_in_single_file,
-    is_t5_in_single_file,
-    load_single_file_checkpoint,
-)
-
-
-logger = logging.get_logger(__name__)
-
-# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
-SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
-
-if is_transformers_available():
-    import transformers
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+from ..utils import deprecate
+from .single_file.single_file import FromSingleFileMixin


 def load_single_file_sub_model(
@@ -64,502 +30,30 @@ def load_single_file_sub_model(
|
|||||||
disable_mmap=False,
|
disable_mmap=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
if is_pipeline_module:
|
from .single_file.single_file import load_single_file_sub_model
|
||||||
pipeline_module = getattr(pipelines, library_name)
|
|
||||||
class_obj = getattr(pipeline_module, class_name)
|
|
||||||
else:
|
|
||||||
# else we just import it from the library.
|
|
||||||
library = importlib.import_module(library_name)
|
|
||||||
class_obj = getattr(library, class_name)
|
|
||||||
|
|
||||||
if is_transformers_available():
|
deprecation_message = "Importing `load_single_file_sub_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import load_single_file_sub_model` instead."
|
||||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
deprecate("diffusers.loaders.single_file.load_single_file_sub_model", "0.36", deprecation_message)
|
||||||
else:
|
|
||||||
transformers_version = "N/A"
|
|
||||||
|
|
||||||
is_transformers_model = (
|
return load_single_file_sub_model(
|
||||||
is_transformers_available()
|
library_name,
|
||||||
and issubclass(class_obj, PreTrainedModel)
|
class_name,
|
||||||
and transformers_version >= version.parse("4.20.0")
|
name,
|
||||||
)
|
checkpoint,
|
||||||
is_tokenizer = (
|
pipelines,
|
||||||
is_transformers_available()
|
is_pipeline_module,
|
||||||
and issubclass(class_obj, PreTrainedTokenizer)
|
cached_model_config_path,
|
||||||
and transformers_version >= version.parse("4.20.0")
|
original_config,
|
||||||
|
local_files_only,
|
||||||
|
torch_dtype,
|
||||||
|
is_legacy_loading,
|
||||||
|
disable_mmap,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
|
||||||
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
|
|
||||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
|
||||||
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
|
|
||||||
|
|
||||||
if is_diffusers_single_file_model:
|
class FromSingleFileMixin(FromSingleFileMixin):
|
||||||
load_method = getattr(class_obj, "from_single_file")
|
def __init__(self, *args, **kwargs):
|
||||||
|
deprecation_message = "Importing `FromSingleFileMixin` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import FromSingleFileMixin` instead."
|
||||||
# We cannot provide two different config options to the `from_single_file` method
|
deprecate("diffusers.loaders.single_file.FromSingleFileMixin", "0.36", deprecation_message)
|
||||||
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
|
super().__init__(*args, **kwargs)
|
||||||
if original_config:
|
|
||||||
cached_model_config_path = None
|
|
||||||
|
|
||||||
loaded_sub_model = load_method(
|
|
||||||
pretrained_model_link_or_path_or_dict=checkpoint,
|
|
||||||
original_config=original_config,
|
|
||||||
config=cached_model_config_path,
|
|
||||||
subfolder=name,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
disable_mmap=disable_mmap,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
|
|
||||||
loaded_sub_model = create_diffusers_clip_model_from_ldm(
|
|
||||||
class_obj,
|
|
||||||
checkpoint=checkpoint,
|
|
||||||
config=cached_model_config_path,
|
|
||||||
subfolder=name,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
is_legacy_loading=is_legacy_loading,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
|
||||||
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
|
||||||
class_obj,
|
|
||||||
checkpoint=checkpoint,
|
|
||||||
config=cached_model_config_path,
|
|
||||||
subfolder=name,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif is_tokenizer and is_legacy_loading:
|
|
||||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
|
||||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
|
||||||
)
|
|
||||||
|
|
||||||
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
|
||||||
loaded_sub_model = _legacy_load_scheduler(
|
|
||||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
if not hasattr(class_obj, "from_pretrained"):
|
|
||||||
raise ValueError(
|
|
||||||
(
|
|
||||||
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
|
|
||||||
" a supported loading method."
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
loading_kwargs = {}
|
|
||||||
loading_kwargs.update(
|
|
||||||
{
|
|
||||||
"pretrained_model_name_or_path": cached_model_config_path,
|
|
||||||
"subfolder": name,
|
|
||||||
"local_files_only": local_files_only,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Schedulers and Tokenizers don't make use of torch_dtype
|
|
||||||
# Skip passing it to those objects
|
|
||||||
if issubclass(class_obj, torch.nn.Module):
|
|
||||||
loading_kwargs.update({"torch_dtype": torch_dtype})
|
|
||||||
|
|
||||||
if is_diffusers_model or is_transformers_model:
|
|
||||||
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
|
|
||||||
raise SingleFileComponentError(
|
|
||||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
|
||||||
)
|
|
||||||
|
|
||||||
load_method = getattr(class_obj, "from_pretrained")
|
|
||||||
loaded_sub_model = load_method(**loading_kwargs)
|
|
||||||
|
|
||||||
return loaded_sub_model
|
|
||||||
|
|
||||||
|
|
||||||
def _map_component_types_to_config_dict(component_types):
|
|
||||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
|
||||||
config_dict = {}
|
|
||||||
component_types.pop("self", None)
|
|
||||||
|
|
||||||
if is_transformers_available():
|
|
||||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
|
||||||
else:
|
|
||||||
transformers_version = "N/A"
|
|
||||||
|
|
||||||
for component_name, component_value in component_types.items():
|
|
||||||
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
|
|
||||||
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
|
|
||||||
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
|
|
||||||
|
|
||||||
is_transformers_model = (
|
|
||||||
is_transformers_available()
|
|
||||||
and issubclass(component_value[0], PreTrainedModel)
|
|
||||||
and transformers_version >= version.parse("4.20.0")
|
|
||||||
)
|
|
||||||
is_transformers_tokenizer = (
|
|
||||||
is_transformers_available()
|
|
||||||
and issubclass(component_value[0], PreTrainedTokenizer)
|
|
||||||
and transformers_version >= version.parse("4.20.0")
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
|
||||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
|
||||||
|
|
||||||
elif is_scheduler_enum or is_scheduler:
|
|
||||||
if is_scheduler_enum:
|
|
||||||
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
|
|
||||||
# if the type hint is a KarrassDiffusionSchedulers enum
|
|
||||||
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
|
|
||||||
|
|
||||||
elif is_scheduler:
|
|
||||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
|
||||||
|
|
||||||
elif (
|
|
||||||
is_transformers_model or is_transformers_tokenizer
|
|
||||||
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
|
||||||
config_dict[component_name] = ["transformers", component_value[0].__name__]
|
|
||||||
|
|
||||||
else:
|
|
||||||
config_dict[component_name] = [None, None]
|
|
||||||
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_pipeline_config_dict(pipeline_class):
|
|
||||||
parameters = inspect.signature(pipeline_class.__init__).parameters
|
|
||||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
|
||||||
component_types = pipeline_class._get_signature_types()
|
|
||||||
|
|
||||||
# Ignore parameters that are not required for the pipeline
|
|
||||||
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
|
|
||||||
config_dict = _map_component_types_to_config_dict(component_types)
|
|
||||||
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
|
|
||||||
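`_infer_pipeline_config_dict` above keys the inferred config on the pipeline's required `__init__` parameters. The signature-inspection part in isolation (toy class, not the diffusers helper):

```python
import inspect


class ToyPipeline:
    def __init__(self, unet, scheduler, feature_extractor=None):
        self.unet = unet
        self.scheduler = scheduler
        self.feature_extractor = feature_extractor


params = inspect.signature(ToyPipeline.__init__).parameters
required = {name for name, p in params.items() if p.default is inspect.Parameter.empty and name != "self"}
assert required == {"unet", "scheduler"}
```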
def _download_diffusers_model_config_from_hub(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
cache_dir,
|
|
||||||
revision,
|
|
||||||
proxies,
|
|
||||||
force_download=None,
|
|
||||||
local_files_only=None,
|
|
||||||
token=None,
|
|
||||||
):
|
|
||||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
|
|
||||||
cached_model_path = snapshot_download(
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
revision=revision,
|
|
||||||
proxies=proxies,
|
|
||||||
force_download=force_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
token=token,
|
|
||||||
allow_patterns=allow_patterns,
|
|
||||||
)
|
|
||||||
|
|
||||||
return cached_model_path
|
|
||||||
|
|
||||||
|
|
||||||
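The helper above downloads only the lightweight config, tokenizer, and text files of a diffusers repo so that a single-file checkpoint can be matched to the right pipeline layout without pulling any weights. The core call, reduced to its essentials:

```python
from huggingface_hub import snapshot_download

config_path = snapshot_download(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"],
)
print(config_path)  # local folder containing only configs and tokenizer files, no model weights
```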
class FromSingleFileMixin:
|
|
||||||
"""
|
|
||||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@validate_hf_hub_args
|
|
||||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
|
|
||||||
r"""
|
|
||||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
|
||||||
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
|
||||||
Can be either:
|
|
||||||
- A link to the `.ckpt` file (for example
|
|
||||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
|
||||||
- A path to a *file* containing all pipeline weights.
|
|
||||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
|
||||||
Override the default `torch.dtype` and load the model with another dtype.
|
|
||||||
force_download (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
|
||||||
cached versions if they exist.
|
|
||||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
|
||||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
|
||||||
is not used.
|
|
||||||
|
|
||||||
proxies (`Dict[str, str]`, *optional*):
|
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
|
||||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
|
||||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
|
||||||
won't be downloaded from the Hub.
|
|
||||||
token (`str` or *bool*, *optional*):
|
|
||||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
|
||||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
|
||||||
revision (`str`, *optional*, defaults to `"main"`):
|
|
||||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
|
||||||
allowed by Git.
|
|
||||||
original_config_file (`str`, *optional*):
|
|
||||||
The path to the original config file that was used to train the model. If not provided, the config file
|
|
||||||
will be inferred from the checkpoint file.
|
|
||||||
config (`str`, *optional*):
|
|
||||||
Can be either:
|
|
||||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
|
||||||
hosted on the Hub.
|
|
||||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
|
||||||
component configs in Diffusers format.
|
|
||||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
|
||||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
|
||||||
is on a network mount or hard drive.
|
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
|
||||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
|
||||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
|
||||||
below for more information.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
|
|
||||||
```py
|
|
||||||
>>> from diffusers import StableDiffusionPipeline
|
|
||||||
|
|
||||||
>>> # Download pipeline from huggingface.co and cache.
|
|
||||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
|
||||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
|
||||||
... )
|
|
||||||
|
|
||||||
>>> # Download pipeline from local file
|
|
||||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
|
||||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
|
|
||||||
|
|
||||||
>>> # Enable float16 and move to GPU
|
|
||||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
|
||||||
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
|
||||||
... torch_dtype=torch.float16,
|
|
||||||
... )
|
|
||||||
>>> pipeline.to("cuda")
|
|
||||||
```
|
|
||||||
|
|
||||||
"""
|
|
||||||
original_config_file = kwargs.pop("original_config_file", None)
|
|
||||||
config = kwargs.pop("config", None)
|
|
||||||
original_config = kwargs.pop("original_config", None)
|
|
||||||
|
|
||||||
if original_config_file is not None:
|
|
||||||
deprecation_message = (
|
|
||||||
"`original_config_file` argument is deprecated and will be removed in future versions."
|
|
||||||
"please use the `original_config` argument instead."
|
|
||||||
)
|
|
||||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
|
||||||
original_config = original_config_file
|
|
||||||
|
|
||||||
force_download = kwargs.pop("force_download", False)
|
|
||||||
proxies = kwargs.pop("proxies", None)
|
|
||||||
token = kwargs.pop("token", None)
|
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
|
||||||
local_files_only = kwargs.pop("local_files_only", False)
|
|
||||||
revision = kwargs.pop("revision", None)
|
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
|
||||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
|
||||||
|
|
||||||
is_legacy_loading = False
|
|
||||||
|
|
||||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
|
||||||
torch_dtype = torch.float32
|
|
||||||
logger.warning(
|
|
||||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# We shouldn't allow configuring individual models components through a Pipeline creation method
|
|
||||||
# These model kwargs should be deprecated
|
|
||||||
scaling_factor = kwargs.get("scaling_factor", None)
|
|
||||||
if scaling_factor is not None:
|
|
||||||
deprecation_message = (
|
|
||||||
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
|
|
||||||
"and will be ignored in future versions."
|
|
||||||
)
|
|
||||||
deprecate("scaling_factor", "1.0.0", deprecation_message)
|
|
||||||
|
|
||||||
if original_config is not None:
|
|
||||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
|
||||||
|
|
||||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
|
||||||
|
|
||||||
pipeline_class = _get_pipeline_class(cls, config=None)
|
|
||||||
|
|
||||||
checkpoint = load_single_file_checkpoint(
|
|
||||||
pretrained_model_link_or_path,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
token=token,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
revision=revision,
|
|
||||||
disable_mmap=disable_mmap,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config is None:
|
|
||||||
config = fetch_diffusers_config(checkpoint)
|
|
||||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
|
||||||
else:
|
|
||||||
default_pretrained_model_config_name = config
|
|
||||||
|
|
||||||
if not os.path.isdir(default_pretrained_model_config_name):
|
|
||||||
# Provided config is a repo_id
|
|
||||||
if default_pretrained_model_config_name.count("/") > 1:
|
|
||||||
raise ValueError(
|
|
||||||
f'The provided config "{config}"'
|
|
||||||
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
# Attempt to download the config files for the pipeline
|
|
||||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
|
||||||
default_pretrained_model_config_name,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
revision=revision,
|
|
||||||
proxies=proxies,
|
|
||||||
force_download=force_download,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
|
||||||
|
|
||||||
except LocalEntryNotFoundError:
|
|
||||||
# `local_files_only=True` but a local diffusers format model config is not available in the cache
|
|
||||||
# If `original_config` is not provided, we need to override `local_files_only` to False
|
|
||||||
# to fetch the config files from the hub so that we have a way
|
|
||||||
# to configure the pipeline components.
|
|
||||||
|
|
||||||
if original_config is None:
|
|
||||||
logger.warning(
|
|
||||||
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
|
|
||||||
"Attempting to download the necessary config files for this pipeline.\n"
|
|
||||||
)
|
|
||||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
|
||||||
default_pretrained_model_config_name,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
revision=revision,
|
|
||||||
proxies=proxies,
|
|
||||||
force_download=force_download,
|
|
||||||
local_files_only=False,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
|
||||||
|
|
||||||
else:
|
|
||||||
# For backwards compatibility
|
|
||||||
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
|
|
||||||
logger.warning(
|
|
||||||
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
|
|
||||||
"This may lead to errors if the model components are not correctly inferred. \n"
|
|
||||||
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
|
|
||||||
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
|
|
||||||
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
|
||||||
"the necessary config files.\n"
|
|
||||||
)
|
|
||||||
is_legacy_loading = True
|
|
||||||
cached_model_config_path = None
|
|
||||||
|
|
||||||
config_dict = _infer_pipeline_config_dict(pipeline_class)
|
|
||||||
config_dict["_class_name"] = pipeline_class.__name__
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Provided config is a path to a local directory attempt to load directly.
|
|
||||||
cached_model_config_path = default_pretrained_model_config_name
|
|
||||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
|
||||||
|
|
||||||
# pop out "_ignore_files" as it is only needed for download
|
|
||||||
config_dict.pop("_ignore_files", None)
|
|
||||||
|
|
||||||
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
|
|
||||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
|
||||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
|
||||||
|
|
||||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
|
||||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
|
||||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
|
||||||
|
|
||||||
from diffusers import pipelines
|
|
||||||
|
|
||||||
# remove `null` components
|
|
||||||
def load_module(name, value):
|
|
||||||
if value[0] is None:
|
|
||||||
return False
|
|
||||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
|
||||||
return False
|
|
||||||
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
|
||||||
|
|
||||||
for name, (library_name, class_name) in logging.tqdm(
|
|
||||||
sorted(init_dict.items()), desc="Loading pipeline components..."
|
|
||||||
):
|
|
||||||
loaded_sub_model = None
|
|
||||||
is_pipeline_module = hasattr(pipelines, library_name)
|
|
||||||
|
|
||||||
if name in passed_class_obj:
|
|
||||||
loaded_sub_model = passed_class_obj[name]
|
|
||||||
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
loaded_sub_model = load_single_file_sub_model(
|
|
||||||
library_name=library_name,
|
|
||||||
class_name=class_name,
|
|
||||||
name=name,
|
|
||||||
checkpoint=checkpoint,
|
|
||||||
is_pipeline_module=is_pipeline_module,
|
|
||||||
cached_model_config_path=cached_model_config_path,
|
|
||||||
pipelines=pipelines,
|
|
||||||
torch_dtype=torch_dtype,
|
|
||||||
original_config=original_config,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
is_legacy_loading=is_legacy_loading,
|
|
||||||
disable_mmap=disable_mmap,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
except SingleFileComponentError as e:
|
|
||||||
raise SingleFileComponentError(
|
|
||||||
(
|
|
||||||
f"{e.message}\n"
|
|
||||||
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
|
|
||||||
f"\n"
|
|
||||||
f"{name} = {class_name}.from_pretrained('...')\n"
|
|
||||||
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
|
|
||||||
f"\n"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
init_kwargs[name] = loaded_sub_model
|
|
||||||
|
|
||||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
|
||||||
passed_modules = list(passed_class_obj.keys())
|
|
||||||
optional_modules = pipeline_class._optional_components
|
|
||||||
|
|
||||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
|
||||||
for module in missing_modules:
|
|
||||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
|
||||||
elif len(missing_modules) > 0:
|
|
||||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
|
||||||
raise ValueError(
|
|
||||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
|
||||||
)
|
|
||||||
|
|
||||||
# deprecated kwargs
|
|
||||||
load_safety_checker = kwargs.pop("load_safety_checker", None)
|
|
||||||
if load_safety_checker is not None:
|
|
||||||
deprecation_message = (
|
|
||||||
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
|
|
||||||
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
|
|
||||||
)
|
|
||||||
deprecate("load_safety_checker", "1.0.0", deprecation_message)
|
|
||||||
|
|
||||||
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
|
|
||||||
init_kwargs.update(safety_checker_components)
|
|
||||||
|
|
||||||
pipe = pipeline_class(**init_kwargs)
|
|
||||||
|
|
||||||
return pipe
|
|
||||||
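# Illustrative usage sketch (not part of the diff above): a minimal call into the
# `from_single_file` method implemented above. The checkpoint URL is a placeholder and the
# `config` repo id is an assumption; passing `config` explicitly avoids the config-inference path.
import torch

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_single_file(
    "https://huggingface.co/<repo_id>/blob/main/<checkpoint>.safetensors",  # placeholder checkpoint
    config="stabilityai/stable-diffusion-xl-base-1.0",  # assumed diffusers-format config repo
    torch_dtype=torch.float16,
)
pipe.to("cuda")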
|
|||||||
src/diffusers/loaders/single_file/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
|
|||||||
|
from ...utils import is_torch_available, is_transformers_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from .single_file_model import FromOriginalModelMixin
|
||||||
|
|
||||||
|
if is_transformers_available():
|
||||||
|
from .single_file import FromSingleFileMixin
|
||||||
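# Illustrative sketch, assuming this folderized layout ships as-is: the mixins become importable
# from the new subpackage path, guarded by the same availability checks used in the __init__ above.
from diffusers.utils import is_torch_available, is_transformers_available

if is_torch_available():
    from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin  # new path
if is_transformers_available():
    from diffusers.loaders.single_file.single_file import FromSingleFileMixin  # new path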
src/diffusers/loaders/single_file/single_file.py (new file, 565 lines)
@@ -0,0 +1,565 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import os
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
|
||||||
|
from packaging import version
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ...utils import deprecate, is_transformers_available, logging
|
||||||
|
from .single_file_utils import (
|
||||||
|
SingleFileComponentError,
|
||||||
|
_is_legacy_scheduler_kwargs,
|
||||||
|
_is_model_weights_in_cached_folder,
|
||||||
|
_legacy_load_clip_tokenizer,
|
||||||
|
_legacy_load_safety_checker,
|
||||||
|
_legacy_load_scheduler,
|
||||||
|
create_diffusers_clip_model_from_ldm,
|
||||||
|
create_diffusers_t5_model_from_checkpoint,
|
||||||
|
fetch_diffusers_config,
|
||||||
|
fetch_original_config,
|
||||||
|
is_clip_model_in_single_file,
|
||||||
|
is_t5_in_single_file,
|
||||||
|
load_single_file_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||||
|
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||||
|
|
||||||
|
if is_transformers_available():
|
||||||
|
import transformers
|
||||||
|
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_single_file_sub_model(
|
||||||
|
library_name,
|
||||||
|
class_name,
|
||||||
|
name,
|
||||||
|
checkpoint,
|
||||||
|
pipelines,
|
||||||
|
is_pipeline_module,
|
||||||
|
cached_model_config_path,
|
||||||
|
original_config=None,
|
||||||
|
local_files_only=False,
|
||||||
|
torch_dtype=None,
|
||||||
|
is_legacy_loading=False,
|
||||||
|
disable_mmap=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if is_pipeline_module:
|
||||||
|
pipeline_module = getattr(pipelines, library_name)
|
||||||
|
class_obj = getattr(pipeline_module, class_name)
|
||||||
|
else:
|
||||||
|
# else we just import it from the library.
|
||||||
|
library = importlib.import_module(library_name)
|
||||||
|
class_obj = getattr(library, class_name)
|
||||||
|
|
||||||
|
if is_transformers_available():
|
||||||
|
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||||
|
else:
|
||||||
|
transformers_version = "N/A"
|
||||||
|
|
||||||
|
is_transformers_model = (
|
||||||
|
is_transformers_available()
|
||||||
|
and issubclass(class_obj, PreTrainedModel)
|
||||||
|
and transformers_version >= version.parse("4.20.0")
|
||||||
|
)
|
||||||
|
is_tokenizer = (
|
||||||
|
is_transformers_available()
|
||||||
|
and issubclass(class_obj, PreTrainedTokenizer)
|
||||||
|
and transformers_version >= version.parse("4.20.0")
|
||||||
|
)
|
||||||
|
|
||||||
|
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||||
|
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
|
||||||
|
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||||
|
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
|
||||||
|
|
||||||
|
if is_diffusers_single_file_model:
|
||||||
|
load_method = getattr(class_obj, "from_single_file")
|
||||||
|
|
||||||
|
# We cannot provide two different config options to the `from_single_file` method
|
||||||
|
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
|
||||||
|
if original_config:
|
||||||
|
cached_model_config_path = None
|
||||||
|
|
||||||
|
loaded_sub_model = load_method(
|
||||||
|
pretrained_model_link_or_path_or_dict=checkpoint,
|
||||||
|
original_config=original_config,
|
||||||
|
config=cached_model_config_path,
|
||||||
|
subfolder=name,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
disable_mmap=disable_mmap,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
|
||||||
|
loaded_sub_model = create_diffusers_clip_model_from_ldm(
|
||||||
|
class_obj,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
config=cached_model_config_path,
|
||||||
|
subfolder=name,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
is_legacy_loading=is_legacy_loading,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||||
|
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||||
|
class_obj,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
config=cached_model_config_path,
|
||||||
|
subfolder=name,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif is_tokenizer and is_legacy_loading:
|
||||||
|
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||||
|
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||||
|
)
|
||||||
|
|
||||||
|
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
||||||
|
loaded_sub_model = _legacy_load_scheduler(
|
||||||
|
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
if not hasattr(class_obj, "from_pretrained"):
|
||||||
|
raise ValueError(
|
||||||
|
(
|
||||||
|
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
|
||||||
|
" a supported loading method."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
loading_kwargs = {}
|
||||||
|
loading_kwargs.update(
|
||||||
|
{
|
||||||
|
"pretrained_model_name_or_path": cached_model_config_path,
|
||||||
|
"subfolder": name,
|
||||||
|
"local_files_only": local_files_only,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Schedulers and Tokenizers don't make use of torch_dtype
|
||||||
|
# Skip passing it to those objects
|
||||||
|
if issubclass(class_obj, torch.nn.Module):
|
||||||
|
loading_kwargs.update({"torch_dtype": torch_dtype})
|
||||||
|
|
||||||
|
if is_diffusers_model or is_transformers_model:
|
||||||
|
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
|
||||||
|
raise SingleFileComponentError(
|
||||||
|
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
|
load_method = getattr(class_obj, "from_pretrained")
|
||||||
|
loaded_sub_model = load_method(**loading_kwargs)
|
||||||
|
|
||||||
|
return loaded_sub_model
|
||||||
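# Illustrative sketch (not part of this file): when `load_single_file_sub_model` raises
# `SingleFileComponentError` because a component's weights are missing from the checkpoint,
# the suggested recovery is to load that component separately and pass it in.
# The repo id and checkpoint path below are placeholders.
from diffusers import StableDiffusionXLPipeline
from transformers import CLIPTextModelWithProjection

text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder_2"
)
pipe = StableDiffusionXLPipeline.from_single_file(
    "<checkpoint path>.safetensors", text_encoder_2=text_encoder_2  # placeholder path
)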
|
|
||||||
|
|
||||||
|
def _map_component_types_to_config_dict(component_types):
|
||||||
|
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||||
|
config_dict = {}
|
||||||
|
component_types.pop("self", None)
|
||||||
|
|
||||||
|
if is_transformers_available():
|
||||||
|
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||||
|
else:
|
||||||
|
transformers_version = "N/A"
|
||||||
|
|
||||||
|
for component_name, component_value in component_types.items():
|
||||||
|
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
|
||||||
|
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
|
||||||
|
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
|
||||||
|
|
||||||
|
is_transformers_model = (
|
||||||
|
is_transformers_available()
|
||||||
|
and issubclass(component_value[0], PreTrainedModel)
|
||||||
|
and transformers_version >= version.parse("4.20.0")
|
||||||
|
)
|
||||||
|
is_transformers_tokenizer = (
|
||||||
|
is_transformers_available()
|
||||||
|
and issubclass(component_value[0], PreTrainedTokenizer)
|
||||||
|
and transformers_version >= version.parse("4.20.0")
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||||
|
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||||
|
|
||||||
|
elif is_scheduler_enum or is_scheduler:
|
||||||
|
if is_scheduler_enum:
|
||||||
|
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
|
||||||
|
# if the type hint is a KarrasDiffusionSchedulers enum
|
||||||
|
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
|
||||||
|
|
||||||
|
elif is_scheduler:
|
||||||
|
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||||
|
|
||||||
|
elif (
|
||||||
|
is_transformers_model or is_transformers_tokenizer
|
||||||
|
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||||
|
config_dict[component_name] = ["transformers", component_value[0].__name__]
|
||||||
|
|
||||||
|
else:
|
||||||
|
config_dict[component_name] = [None, None]
|
||||||
|
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_pipeline_config_dict(pipeline_class):
|
||||||
|
parameters = inspect.signature(pipeline_class.__init__).parameters
|
||||||
|
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||||
|
component_types = pipeline_class._get_signature_types()
|
||||||
|
|
||||||
|
# Ignore parameters that are not required for the pipeline
|
||||||
|
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
|
||||||
|
config_dict = _map_component_types_to_config_dict(component_types)
|
||||||
|
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _download_diffusers_model_config_from_hub(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
cache_dir,
|
||||||
|
revision,
|
||||||
|
proxies,
|
||||||
|
force_download=None,
|
||||||
|
local_files_only=None,
|
||||||
|
token=None,
|
||||||
|
):
|
||||||
|
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
|
||||||
|
cached_model_path = snapshot_download(
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
revision=revision,
|
||||||
|
proxies=proxies,
|
||||||
|
force_download=force_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
)
|
||||||
|
|
||||||
|
return cached_model_path
|
||||||
|
|
||||||
|
|
||||||
|
class FromSingleFileMixin:
|
||||||
|
"""
|
||||||
|
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@validate_hf_hub_args
|
||||||
|
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
|
||||||
|
r"""
|
||||||
|
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
||||||
|
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||||
|
Can be either:
|
||||||
|
- A link to the `.ckpt` file (for example
|
||||||
|
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||||
|
- A path to a *file* containing all pipeline weights.
|
||||||
|
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||||
|
Override the default `torch.dtype` and load the model with another dtype.
|
||||||
|
force_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||||
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||||
|
is not used.
|
||||||
|
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||||
|
won't be downloaded from the Hub.
|
||||||
|
token (`str` or *bool*, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||||
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||||
|
allowed by Git.
|
||||||
|
original_config_file (`str`, *optional*):
|
||||||
|
The path to the original config file that was used to train the model. If not provided, the config file
|
||||||
|
will be inferred from the checkpoint file.
|
||||||
|
config (`str`, *optional*):
|
||||||
|
Can be either:
|
||||||
|
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||||
|
hosted on the Hub.
|
||||||
|
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
||||||
|
component configs in Diffusers format.
|
||||||
|
disable_mmap (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||||
|
is on a network mount or hard drive.
|
||||||
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||||
|
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||||
|
class). The overwritten components are passed directly to the pipeline's `__init__` method. See example
|
||||||
|
below for more information.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> from diffusers import StableDiffusionPipeline
|
||||||
|
|
||||||
|
>>> # Download pipeline from huggingface.co and cache.
|
||||||
|
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||||
|
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||||
|
... )
|
||||||
|
|
||||||
|
>>> # Download pipeline from local file
|
||||||
|
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||||
|
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
|
||||||
|
|
||||||
|
>>> # Enable float16 and move to GPU
|
||||||
|
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||||
|
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||||
|
... torch_dtype=torch.float16,
|
||||||
|
... )
|
||||||
|
>>> pipeline.to("cuda")
|
||||||
|
```
|
||||||
|
|
||||||
|
"""
|
||||||
|
original_config_file = kwargs.pop("original_config_file", None)
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
original_config = kwargs.pop("original_config", None)
|
||||||
|
|
||||||
|
if original_config_file is not None:
|
||||||
|
deprecation_message = (
|
||||||
|
"`original_config_file` argument is deprecated and will be removed in future versions."
|
||||||
|
"please use the `original_config` argument instead."
|
||||||
|
)
|
||||||
|
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||||
|
original_config = original_config_file
|
||||||
|
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
token = kwargs.pop("token", None)
|
||||||
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", False)
|
||||||
|
revision = kwargs.pop("revision", None)
|
||||||
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||||
|
|
||||||
|
is_legacy_loading = False
|
||||||
|
|
||||||
|
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||||
|
torch_dtype = torch.float32
|
||||||
|
logger.warning(
|
||||||
|
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||||
|
)
|
||||||
|
|
||||||
|
# We shouldn't allow configuring individual model components through a Pipeline creation method
|
||||||
|
# These model kwargs should be deprecated
|
||||||
|
scaling_factor = kwargs.get("scaling_factor", None)
|
||||||
|
if scaling_factor is not None:
|
||||||
|
deprecation_message = (
|
||||||
|
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
|
||||||
|
"and will be ignored in future versions."
|
||||||
|
)
|
||||||
|
deprecate("scaling_factor", "1.0.0", deprecation_message)
|
||||||
|
|
||||||
|
if original_config is not None:
|
||||||
|
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||||
|
|
||||||
|
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||||
|
|
||||||
|
pipeline_class = _get_pipeline_class(cls, config=None)
|
||||||
|
|
||||||
|
checkpoint = load_single_file_checkpoint(
|
||||||
|
pretrained_model_link_or_path,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
disable_mmap=disable_mmap,
|
||||||
|
)
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = fetch_diffusers_config(checkpoint)
|
||||||
|
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||||
|
else:
|
||||||
|
default_pretrained_model_config_name = config
|
||||||
|
|
||||||
|
if not os.path.isdir(default_pretrained_model_config_name):
|
||||||
|
# Provided config is a repo_id
|
||||||
|
if default_pretrained_model_config_name.count("/") > 1:
|
||||||
|
raise ValueError(
|
||||||
|
f'The provided config "{config}"'
|
||||||
|
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
# Attempt to download the config files for the pipeline
|
||||||
|
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||||
|
default_pretrained_model_config_name,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
revision=revision,
|
||||||
|
proxies=proxies,
|
||||||
|
force_download=force_download,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||||
|
|
||||||
|
except LocalEntryNotFoundError:
|
||||||
|
# `local_files_only=True` but a local diffusers format model config is not available in the cache
|
||||||
|
# If `original_config` is not provided, we need to override `local_files_only` to False
|
||||||
|
# to fetch the config files from the hub so that we have a way
|
||||||
|
# to configure the pipeline components.
|
||||||
|
|
||||||
|
if original_config is None:
|
||||||
|
logger.warning(
|
||||||
|
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
|
||||||
|
"Attempting to download the necessary config files for this pipeline.\n"
|
||||||
|
)
|
||||||
|
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||||
|
default_pretrained_model_config_name,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
revision=revision,
|
||||||
|
proxies=proxies,
|
||||||
|
force_download=force_download,
|
||||||
|
local_files_only=False,
|
||||||
|
token=token,
|
||||||
|
)
|
||||||
|
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# For backwards compatibility
|
||||||
|
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
|
||||||
|
logger.warning(
|
||||||
|
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
|
||||||
|
"This may lead to errors if the model components are not correctly inferred. \n"
|
||||||
|
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
|
||||||
|
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
|
||||||
|
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
||||||
|
"the necessary config files.\n"
|
||||||
|
)
|
||||||
|
is_legacy_loading = True
|
||||||
|
cached_model_config_path = None
|
||||||
|
|
||||||
|
config_dict = _infer_pipeline_config_dict(pipeline_class)
|
||||||
|
config_dict["_class_name"] = pipeline_class.__name__
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Provided config is a path to a local directory attempt to load directly.
|
||||||
|
cached_model_config_path = default_pretrained_model_config_name
|
||||||
|
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||||
|
|
||||||
|
# pop out "_ignore_files" as it is only needed for download
|
||||||
|
config_dict.pop("_ignore_files", None)
|
||||||
|
|
||||||
|
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
|
||||||
|
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||||
|
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||||
|
|
||||||
|
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||||
|
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||||
|
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||||
|
|
||||||
|
from diffusers import pipelines
|
||||||
|
|
||||||
|
# remove `null` components
|
||||||
|
def load_module(name, value):
|
||||||
|
if value[0] is None:
|
||||||
|
return False
|
||||||
|
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||||
|
return False
|
||||||
|
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||||
|
|
||||||
|
for name, (library_name, class_name) in logging.tqdm(
|
||||||
|
sorted(init_dict.items()), desc="Loading pipeline components..."
|
||||||
|
):
|
||||||
|
loaded_sub_model = None
|
||||||
|
is_pipeline_module = hasattr(pipelines, library_name)
|
||||||
|
|
||||||
|
if name in passed_class_obj:
|
||||||
|
loaded_sub_model = passed_class_obj[name]
|
||||||
|
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
loaded_sub_model = load_single_file_sub_model(
|
||||||
|
library_name=library_name,
|
||||||
|
class_name=class_name,
|
||||||
|
name=name,
|
||||||
|
checkpoint=checkpoint,
|
||||||
|
is_pipeline_module=is_pipeline_module,
|
||||||
|
cached_model_config_path=cached_model_config_path,
|
||||||
|
pipelines=pipelines,
|
||||||
|
torch_dtype=torch_dtype,
|
||||||
|
original_config=original_config,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
is_legacy_loading=is_legacy_loading,
|
||||||
|
disable_mmap=disable_mmap,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
except SingleFileComponentError as e:
|
||||||
|
raise SingleFileComponentError(
|
||||||
|
(
|
||||||
|
f"{e.message}\n"
|
||||||
|
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
|
||||||
|
f"\n"
|
||||||
|
f"{name} = {class_name}.from_pretrained('...')\n"
|
||||||
|
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
|
||||||
|
f"\n"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
init_kwargs[name] = loaded_sub_model
|
||||||
|
|
||||||
|
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||||
|
passed_modules = list(passed_class_obj.keys())
|
||||||
|
optional_modules = pipeline_class._optional_components
|
||||||
|
|
||||||
|
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||||
|
for module in missing_modules:
|
||||||
|
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||||
|
elif len(missing_modules) > 0:
|
||||||
|
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||||
|
raise ValueError(
|
||||||
|
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||||
|
)
|
||||||
|
|
||||||
|
# deprecated kwargs
|
||||||
|
load_safety_checker = kwargs.pop("load_safety_checker", None)
|
||||||
|
if load_safety_checker is not None:
|
||||||
|
deprecation_message = (
|
||||||
|
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
|
||||||
|
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
|
||||||
|
)
|
||||||
|
deprecate("load_safety_checker", "1.0.0", deprecation_message)
|
||||||
|
|
||||||
|
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
|
||||||
|
init_kwargs.update(safety_checker_components)
|
||||||
|
|
||||||
|
pipe = pipeline_class(**init_kwargs)
|
||||||
|
|
||||||
|
return pipe
|
||||||
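# Illustrative usage sketch (not part of the diff): running `from_single_file` fully offline.
# Pointing `config` at a local diffusers-format directory keeps `local_files_only=True` on the
# fast path and avoids the legacy component-inference fallback implemented above. Paths are placeholders.
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.ckpt",  # placeholder local checkpoint
    config="./stable-diffusion-v1-5-local",  # placeholder local diffusers model directory
    local_files_only=True,
)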
src/diffusers/loaders/single_file/single_file_model.py (new file, 440 lines)
@@ -0,0 +1,440 @@
|
|||||||
|
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
import re
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from huggingface_hub.utils import validate_hf_hub_args
|
||||||
|
from typing_extensions import Self
|
||||||
|
|
||||||
|
from ... import __version__
|
||||||
|
from ...quantizers import DiffusersAutoQuantizer
|
||||||
|
from ...utils import deprecate, is_accelerate_available, logging
|
||||||
|
from .single_file_utils import (
|
||||||
|
SingleFileComponentError,
|
||||||
|
convert_animatediff_checkpoint_to_diffusers,
|
||||||
|
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||||
|
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||||
|
convert_controlnet_checkpoint,
|
||||||
|
convert_flux_transformer_checkpoint_to_diffusers,
|
||||||
|
convert_hunyuan_video_transformer_to_diffusers,
|
||||||
|
convert_ldm_unet_checkpoint,
|
||||||
|
convert_ldm_vae_checkpoint,
|
||||||
|
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||||
|
convert_ltx_vae_checkpoint_to_diffusers,
|
||||||
|
convert_lumina2_to_diffusers,
|
||||||
|
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||||
|
convert_sana_transformer_to_diffusers,
|
||||||
|
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||||
|
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||||
|
convert_wan_transformer_to_diffusers,
|
||||||
|
convert_wan_vae_to_diffusers,
|
||||||
|
create_controlnet_diffusers_config_from_ldm,
|
||||||
|
create_unet_diffusers_config_from_ldm,
|
||||||
|
create_vae_diffusers_config_from_ldm,
|
||||||
|
fetch_diffusers_config,
|
||||||
|
fetch_original_config,
|
||||||
|
load_single_file_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if is_accelerate_available():
|
||||||
|
from accelerate import dispatch_model, init_empty_weights
|
||||||
|
|
||||||
|
from ...models.modeling_utils import load_model_dict_into_meta
|
||||||
|
|
||||||
|
|
||||||
|
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||||
|
"StableCascadeUNet": {
|
||||||
|
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
||||||
|
},
|
||||||
|
"UNet2DConditionModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
||||||
|
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
||||||
|
"default_subfolder": "unet",
|
||||||
|
"legacy_kwargs": {
|
||||||
|
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"AutoencoderKL": {
|
||||||
|
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
||||||
|
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
||||||
|
"default_subfolder": "vae",
|
||||||
|
},
|
||||||
|
"ControlNetModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
||||||
|
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
||||||
|
},
|
||||||
|
"SD3Transformer2DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"MotionAdapter": {
|
||||||
|
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||||
|
},
|
||||||
|
"SparseControlNetModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||||
|
},
|
||||||
|
"FluxTransformer2DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"LTXVideoTransformer3DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"AutoencoderKLLTXVideo": {
|
||||||
|
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
||||||
|
"default_subfolder": "vae",
|
||||||
|
},
|
||||||
|
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||||
|
"MochiTransformer3DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"HunyuanVideoTransformer3DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"AuraFlowTransformer2DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"Lumina2Transformer2DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"SanaTransformer2DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"WanTransformer3DModel": {
|
||||||
|
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||||
|
"default_subfolder": "transformer",
|
||||||
|
},
|
||||||
|
"AutoencoderKLWan": {
|
||||||
|
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||||
|
"default_subfolder": "vae",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_single_file_loadable_mapping_class(cls):
|
||||||
|
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||||
|
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||||
|
loadable_class = getattr(diffusers_module, loadable_class_str)
|
||||||
|
|
||||||
|
if issubclass(cls, loadable_class):
|
||||||
|
return loadable_class_str
|
||||||
|
|
||||||
|
return None
|
||||||
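# Illustrative sketch (not part of this file): the helper above walks SINGLE_FILE_LOADABLE_CLASSES
# and returns the key whose class the given class subclasses, or None if single-file loading is
# not supported for it.
from diffusers import AutoencoderKL

key = _get_single_file_loadable_mapping_class(AutoencoderKL)  # -> "AutoencoderKL"
mapping_fns = SINGLE_FILE_LOADABLE_CLASSES[key]  # checkpoint/config mapping functions for the class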
|
|
||||||
|
|
||||||
|
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||||
|
parameters = inspect.signature(mapping_fn).parameters
|
||||||
|
|
||||||
|
mapping_kwargs = {}
|
||||||
|
for parameter in parameters:
|
||||||
|
if parameter in kwargs:
|
||||||
|
mapping_kwargs[parameter] = kwargs[parameter]
|
||||||
|
|
||||||
|
return mapping_kwargs
|
||||||
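# Illustrative sketch (not part of this file): `_get_mapping_function_kwargs` keeps only the
# kwargs that the mapping function declares in its signature, so unrelated kwargs are dropped
# rather than forwarded. The mapping function below is hypothetical.
def _example_mapping_fn(checkpoint, config=None, upcast_attention=False):
    return checkpoint

extra_kwargs = {"upcast_attention": True, "torch_dtype": "not-used-here"}
filtered = _get_mapping_function_kwargs(_example_mapping_fn, **extra_kwargs)
# filtered == {"upcast_attention": True}; "torch_dtype" is not in the function's signature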
|
|
||||||
|
|
||||||
|
class FromOriginalModelMixin:
|
||||||
|
"""
|
||||||
|
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@validate_hf_hub_args
|
||||||
|
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
|
||||||
|
r"""
|
||||||
|
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
|
||||||
|
is set in evaluation mode (`model.eval()`) by default.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
pretrained_model_link_or_path_or_dict (`str`, *optional*):
|
||||||
|
Can be either:
|
||||||
|
- A link to the `.safetensors` or `.ckpt` file (for example
|
||||||
|
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
|
||||||
|
- A path to a local *file* containing the weights of the component model.
|
||||||
|
- A state dict containing the component model weights.
|
||||||
|
config (`str`, *optional*):
|
||||||
|
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
|
||||||
|
on the Hub.
|
||||||
|
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
|
||||||
|
configs in Diffusers format.
|
||||||
|
subfolder (`str`, *optional*, defaults to `""`):
|
||||||
|
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||||
|
original_config (`str`, *optional*):
|
||||||
|
Dict or path to a yaml file containing the configuration for the model in its original format.
|
||||||
|
If a dict is provided, it will be used to initialize the model configuration.
|
||||||
|
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||||
|
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||||
|
dtype is automatically derived from the model's weights.
|
||||||
|
force_download (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||||
|
cached versions if they exist.
|
||||||
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||||
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||||
|
is not used.
|
||||||
|
|
||||||
|
proxies (`Dict[str, str]`, *optional*):
|
||||||
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||||
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||||
|
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||||
|
won't be downloaded from the Hub.
|
||||||
|
token (`str` or *bool*, *optional*):
|
||||||
|
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||||
|
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||||
|
revision (`str`, *optional*, defaults to `"main"`):
|
||||||
|
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||||
|
allowed by Git.
|
||||||
|
disable_mmap (`bool`, *optional*, defaults to `False`):
|
||||||
|
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||||
|
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||||
|
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||||
|
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||||
|
specific pipeline class). The overwritten components are directly passed to the pipeline's `__init__`
|
||||||
|
method. See example below for more information.
|
||||||
|
|
||||||
|
```py
|
||||||
|
>>> from diffusers import StableCascadeUNet
|
||||||
|
|
||||||
|
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
||||||
|
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
||||||
|
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||||
|
if mapping_class_name is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||||
|
)
|
||||||
|
|
||||||
|
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
|
||||||
|
if pretrained_model_link_or_path is not None:
|
||||||
|
deprecation_message = (
|
||||||
|
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
|
||||||
|
)
|
||||||
|
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
|
||||||
|
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
|
||||||
|
|
||||||
|
config = kwargs.pop("config", None)
|
||||||
|
original_config = kwargs.pop("original_config", None)
|
||||||
|
|
||||||
|
if config is not None and original_config is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||||
|
)
|
||||||
|
|
||||||
|
force_download = kwargs.pop("force_download", False)
|
||||||
|
proxies = kwargs.pop("proxies", None)
|
||||||
|
token = kwargs.pop("token", None)
|
||||||
|
cache_dir = kwargs.pop("cache_dir", None)
|
||||||
|
local_files_only = kwargs.pop("local_files_only", None)
|
||||||
|
subfolder = kwargs.pop("subfolder", None)
|
||||||
|
revision = kwargs.pop("revision", None)
|
||||||
|
config_revision = kwargs.pop("config_revision", None)
|
||||||
|
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||||
|
quantization_config = kwargs.pop("quantization_config", None)
|
||||||
|
device = kwargs.pop("device", None)
|
||||||
|
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||||
|
|
||||||
|
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||||
|
# In order to ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`
|
||||||
|
if quantization_config is not None:
|
||||||
|
user_agent["quant"] = quantization_config.quant_method.value
|
||||||
|
|
||||||
|
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||||
|
torch_dtype = torch.float32
|
||||||
|
logger.warning(
|
||||||
|
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||||
|
checkpoint = pretrained_model_link_or_path_or_dict
|
||||||
|
else:
|
||||||
|
checkpoint = load_single_file_checkpoint(
|
||||||
|
pretrained_model_link_or_path_or_dict,
|
||||||
|
force_download=force_download,
|
||||||
|
proxies=proxies,
|
||||||
|
token=token,
|
||||||
|
cache_dir=cache_dir,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
revision=revision,
|
||||||
|
disable_mmap=disable_mmap,
|
||||||
|
user_agent=user_agent,
|
||||||
|
)
|
||||||
|
if quantization_config is not None:
|
||||||
|
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||||
|
hf_quantizer.validate_environment()
|
||||||
|
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||||
|
|
||||||
|
else:
|
||||||
|
hf_quantizer = None
|
||||||
|
|
||||||
|
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||||
|
|
||||||
|
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||||
|
if original_config is not None:
|
||||||
|
if "config_mapping_fn" in mapping_functions:
|
||||||
|
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
||||||
|
else:
|
||||||
|
config_mapping_fn = None
|
||||||
|
|
||||||
|
if config_mapping_fn is None:
|
||||||
|
raise ValueError(
|
||||||
|
(
|
||||||
|
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
||||||
|
"was found to convert the original config to a Diffusers config in"
|
||||||
|
"`diffusers.loaders.single_file_utils`"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(original_config, str):
|
||||||
|
# If original_config is a URL or filepath fetch the original_config dict
|
||||||
|
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||||
|
|
||||||
|
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
|
||||||
|
diffusers_model_config = config_mapping_fn(
|
||||||
|
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if config is not None:
|
||||||
|
if isinstance(config, str):
|
||||||
|
default_pretrained_model_config_name = config
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
(
|
||||||
|
"Invalid `config` argument. Please provide a string representing a repo id"
|
||||||
|
"or path to a local Diffusers model repo."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
config = fetch_diffusers_config(checkpoint)
|
||||||
|
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||||
|
|
||||||
|
if "default_subfolder" in mapping_functions:
|
||||||
|
subfolder = mapping_functions["default_subfolder"]
|
||||||
|
|
||||||
|
subfolder = subfolder or config.pop(
|
||||||
|
"subfolder", None
|
||||||
|
) # some configs contain a subfolder key, e.g. StableCascadeUNet
|
||||||
|
|
||||||
|
diffusers_model_config = cls.load_config(
|
||||||
|
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
||||||
|
subfolder=subfolder,
|
||||||
|
local_files_only=local_files_only,
|
||||||
|
token=token,
|
||||||
|
revision=config_revision,
|
||||||
|
)
|
||||||
|
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||||
|
|
||||||
|
# Map legacy kwargs to new kwargs
|
||||||
|
if "legacy_kwargs" in mapping_functions:
|
||||||
|
legacy_kwargs = mapping_functions["legacy_kwargs"]
|
||||||
|
for legacy_key, new_key in legacy_kwargs.items():
|
||||||
|
if legacy_key in kwargs:
|
||||||
|
kwargs[new_key] = kwargs.pop(legacy_key)
|
||||||
|
|
||||||
|
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||||
|
diffusers_model_config.update(model_kwargs)
|
||||||
|
|
||||||
|
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||||
|
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||||
|
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||||
|
)
|
||||||
|
if not diffusers_format_checkpoint:
|
||||||
|
raise SingleFileComponentError(
|
||||||
|
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||||
|
with ctx():
|
||||||
|
model = cls.from_config(diffusers_model_config)
|
||||||
|
|
||||||
|
# Check if `_keep_in_fp32_modules` is not None
|
||||||
|
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||||
|
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||||
|
)
|
||||||
|
if use_keep_in_fp32_modules:
|
||||||
|
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||||
|
if not isinstance(keep_in_fp32_modules, list):
|
||||||
|
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||||
|
|
||||||
|
else:
|
||||||
|
keep_in_fp32_modules = []
|
||||||
|
|
||||||
|
if hf_quantizer is not None:
|
||||||
|
hf_quantizer.preprocess_model(
|
||||||
|
model=model,
|
||||||
|
device_map=None,
|
||||||
|
state_dict=diffusers_format_checkpoint,
|
||||||
|
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||||
|
)
|
||||||
|
|
||||||
|
device_map = None
|
||||||
|
if is_accelerate_available():
|
||||||
|
param_device = torch.device(device) if device else torch.device("cpu")
|
||||||
|
empty_state_dict = model.state_dict()
|
||||||
|
unexpected_keys = [
|
||||||
|
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
|
||||||
|
]
|
||||||
|
device_map = {"": param_device}
|
||||||
|
load_model_dict_into_meta(
|
||||||
|
model,
|
||||||
|
diffusers_format_checkpoint,
|
||||||
|
dtype=torch_dtype,
|
||||||
|
device_map=device_map,
|
||||||
|
hf_quantizer=hf_quantizer,
|
||||||
|
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||||
|
unexpected_keys=unexpected_keys,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||||
|
|
||||||
|
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||||
|
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||||
|
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||||
|
|
||||||
|
if len(unexpected_keys) > 0:
|
||||||
|
logger.warning(
|
||||||
|
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if hf_quantizer is not None:
|
||||||
|
hf_quantizer.postprocess_model(model)
|
||||||
|
model.hf_quantizer = hf_quantizer
|
||||||
|
|
||||||
|
if torch_dtype is not None and hf_quantizer is None:
|
||||||
|
model.to(torch_dtype)
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
if device_map is not None:
|
||||||
|
device_map_kwargs = {"device_map": device_map}
|
||||||
|
dispatch_model(model, **device_map_kwargs)
|
||||||
|
|
||||||
|
return model
|
||||||
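# Illustrative usage sketch (not part of the diff): loading a single model component through the
# `FromOriginalModelMixin.from_single_file` implemented above. The checkpoint URL is a placeholder.
import torch

from diffusers import FluxTransformer2DModel

transformer = FluxTransformer2DModel.from_single_file(
    "https://huggingface.co/<repo_id>/blob/main/flux1-dev.safetensors",  # placeholder
    torch_dtype=torch.bfloat16,
)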
src/diffusers/loaders/single_file/single_file_utils.py (new file, 3295 lines; diff suppressed because it is too large)
@@ -11,430 +11,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

(added: new contents of the module, now a thin deprecation shim)

from ..utils import deprecate
from .single_file.single_file_model import (
    SINGLE_FILE_LOADABLE_CLASSES,  # noqa: F401
    FromOriginalModelMixin,
)


class FromOriginalModelMixin(FromOriginalModelMixin):
    def __init__(self, *args, **kwargs):
        deprecation_message = "Importing `FromOriginalModelMixin` from diffusers.loaders.single_file_model has been deprecated. Please use `from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin` instead."
        deprecate("diffusers.loaders.single_file_model.FromOriginalModelMixin", "0.36", deprecation_message)
        super().__init__(*args, **kwargs)

(removed: previous contents of the module, continued below)

import importlib
import inspect
import re
from contextlib import nullcontext
from typing import Optional

import torch
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self

from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
    SingleFileComponentError,
    convert_animatediff_checkpoint_to_diffusers,
    convert_auraflow_transformer_checkpoint_to_diffusers,
    convert_autoencoder_dc_checkpoint_to_diffusers,
    convert_controlnet_checkpoint,
    convert_flux_transformer_checkpoint_to_diffusers,
    convert_hunyuan_video_transformer_to_diffusers,
    convert_ldm_unet_checkpoint,
    convert_ldm_vae_checkpoint,
    convert_ltx_transformer_checkpoint_to_diffusers,
    convert_ltx_vae_checkpoint_to_diffusers,
    convert_lumina2_to_diffusers,
    convert_mochi_transformer_checkpoint_to_diffusers,
    convert_sana_transformer_to_diffusers,
    convert_sd3_transformer_checkpoint_to_diffusers,
    convert_stable_cascade_unet_single_file_to_diffusers,
    convert_wan_transformer_to_diffusers,
    convert_wan_vae_to_diffusers,
    create_controlnet_diffusers_config_from_ldm,
    create_unet_diffusers_config_from_ldm,
    create_vae_diffusers_config_from_ldm,
    fetch_diffusers_config,
    fetch_original_config,
    load_single_file_checkpoint,
)


logger = logging.get_logger(__name__)


if is_accelerate_available():
    from accelerate import dispatch_model, init_empty_weights

    from ..models.modeling_utils import load_model_dict_into_meta

SINGLE_FILE_LOADABLE_CLASSES = {
|
|
||||||
"StableCascadeUNet": {
|
|
||||||
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
|
||||||
},
|
|
||||||
"UNet2DConditionModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
|
||||||
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
|
||||||
"default_subfolder": "unet",
|
|
||||||
"legacy_kwargs": {
|
|
||||||
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"AutoencoderKL": {
|
|
||||||
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
|
||||||
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
|
||||||
"default_subfolder": "vae",
|
|
||||||
},
|
|
||||||
"ControlNetModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
|
||||||
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
|
||||||
},
|
|
||||||
"SD3Transformer2DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"MotionAdapter": {
|
|
||||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
|
||||||
},
|
|
||||||
"SparseControlNetModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
|
||||||
},
|
|
||||||
"FluxTransformer2DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"LTXVideoTransformer3DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"AutoencoderKLLTXVideo": {
|
|
||||||
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
|
||||||
"default_subfolder": "vae",
|
|
||||||
},
|
|
||||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
|
||||||
"MochiTransformer3DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"HunyuanVideoTransformer3DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"AuraFlowTransformer2DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"Lumina2Transformer2DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"SanaTransformer2DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"WanTransformer3DModel": {
|
|
||||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
|
||||||
"default_subfolder": "transformer",
|
|
||||||
},
|
|
||||||
"AutoencoderKLWan": {
|
|
||||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
|
||||||
"default_subfolder": "vae",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _get_single_file_loadable_mapping_class(cls):
|
|
||||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
|
||||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
|
||||||
loadable_class = getattr(diffusers_module, loadable_class_str)
|
|
||||||
|
|
||||||
if issubclass(cls, loadable_class):
|
|
||||||
return loadable_class_str
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
|
||||||
parameters = inspect.signature(mapping_fn).parameters
|
|
||||||
|
|
||||||
mapping_kwargs = {}
|
|
||||||
for parameter in parameters:
|
|
||||||
if parameter in kwargs:
|
|
||||||
mapping_kwargs[parameter] = kwargs[parameter]
|
|
||||||
|
|
||||||
return mapping_kwargs
|
|
||||||
|
|
||||||
|
|
||||||
class FromOriginalModelMixin:
|
|
||||||
"""
|
|
||||||
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@validate_hf_hub_args
|
|
||||||
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
|
|
||||||
r"""
|
|
||||||
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
|
|
||||||
is set in evaluation mode (`model.eval()`) by default.
|
|
||||||
|
|
||||||
Parameters:
|
|
||||||
pretrained_model_link_or_path_or_dict (`str`, *optional*):
|
|
||||||
Can be either:
|
|
||||||
- A link to the `.safetensors` or `.ckpt` file (for example
|
|
||||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
|
|
||||||
- A path to a local *file* containing the weights of the component model.
|
|
||||||
- A state dict containing the component model weights.
|
|
||||||
config (`str`, *optional*):
|
|
||||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
|
|
||||||
on the Hub.
|
|
||||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
|
|
||||||
configs in Diffusers format.
|
|
||||||
subfolder (`str`, *optional*, defaults to `""`):
|
|
||||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
|
||||||
original_config (`str`, *optional*):
|
|
||||||
Dict or path to a yaml file containing the configuration for the model in its original format.
|
|
||||||
If a dict is provided, it will be used to initialize the model configuration.
|
|
||||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
|
||||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
|
||||||
dtype is automatically derived from the model's weights.
|
|
||||||
force_download (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
|
||||||
cached versions if they exist.
|
|
||||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
|
||||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
|
||||||
is not used.
|
|
||||||
|
|
||||||
proxies (`Dict[str, str]`, *optional*):
|
|
||||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
|
||||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
|
||||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
|
||||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
|
||||||
won't be downloaded from the Hub.
|
|
||||||
token (`str` or *bool*, *optional*):
|
|
||||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
|
||||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
|
||||||
revision (`str`, *optional*, defaults to `"main"`):
|
|
||||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
|
||||||
allowed by Git.
|
|
||||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
|
||||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
|
||||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
|
||||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
|
||||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
|
||||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
|
||||||
method. See example below for more information.
|
|
||||||
|
|
||||||
```py
|
|
||||||
>>> from diffusers import StableCascadeUNet
|
|
||||||
|
|
||||||
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
|
||||||
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
|
|
||||||
```
|
|
||||||
"""
|
|
||||||
|
|
||||||
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
|
||||||
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
|
||||||
if mapping_class_name is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
|
|
||||||
if pretrained_model_link_or_path is not None:
|
|
||||||
deprecation_message = (
|
|
||||||
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
|
|
||||||
)
|
|
||||||
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
|
|
||||||
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
|
|
||||||
|
|
||||||
config = kwargs.pop("config", None)
|
|
||||||
original_config = kwargs.pop("original_config", None)
|
|
||||||
|
|
||||||
if config is not None and original_config is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
|
||||||
)
|
|
||||||
|
|
||||||
force_download = kwargs.pop("force_download", False)
|
|
||||||
proxies = kwargs.pop("proxies", None)
|
|
||||||
token = kwargs.pop("token", None)
|
|
||||||
cache_dir = kwargs.pop("cache_dir", None)
|
|
||||||
local_files_only = kwargs.pop("local_files_only", None)
|
|
||||||
subfolder = kwargs.pop("subfolder", None)
|
|
||||||
revision = kwargs.pop("revision", None)
|
|
||||||
config_revision = kwargs.pop("config_revision", None)
|
|
||||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
|
||||||
quantization_config = kwargs.pop("quantization_config", None)
|
|
||||||
device = kwargs.pop("device", None)
|
|
||||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
|
||||||
|
|
||||||
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
|
||||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
|
||||||
if quantization_config is not None:
|
|
||||||
user_agent["quant"] = quantization_config.quant_method.value
|
|
||||||
|
|
||||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
|
||||||
torch_dtype = torch.float32
|
|
||||||
logger.warning(
|
|
||||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
|
||||||
checkpoint = pretrained_model_link_or_path_or_dict
|
|
||||||
else:
|
|
||||||
checkpoint = load_single_file_checkpoint(
|
|
||||||
pretrained_model_link_or_path_or_dict,
|
|
||||||
force_download=force_download,
|
|
||||||
proxies=proxies,
|
|
||||||
token=token,
|
|
||||||
cache_dir=cache_dir,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
revision=revision,
|
|
||||||
disable_mmap=disable_mmap,
|
|
||||||
user_agent=user_agent,
|
|
||||||
)
|
|
||||||
if quantization_config is not None:
|
|
||||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
|
||||||
hf_quantizer.validate_environment()
|
|
||||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
|
||||||
|
|
||||||
else:
|
|
||||||
hf_quantizer = None
|
|
||||||
|
|
||||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
|
||||||
|
|
||||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
|
||||||
if original_config is not None:
|
|
||||||
if "config_mapping_fn" in mapping_functions:
|
|
||||||
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
|
||||||
else:
|
|
||||||
config_mapping_fn = None
|
|
||||||
|
|
||||||
if config_mapping_fn is None:
|
|
||||||
raise ValueError(
|
|
||||||
(
|
|
||||||
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
|
||||||
"was found to convert the original config to a Diffusers config in"
|
|
||||||
"`diffusers.loaders.single_file_utils`"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(original_config, str):
|
|
||||||
# If original_config is a URL or filepath fetch the original_config dict
|
|
||||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
|
||||||
|
|
||||||
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
|
|
||||||
diffusers_model_config = config_mapping_fn(
|
|
||||||
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if config is not None:
|
|
||||||
if isinstance(config, str):
|
|
||||||
default_pretrained_model_config_name = config
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
(
|
|
||||||
"Invalid `config` argument. Please provide a string representing a repo id"
|
|
||||||
"or path to a local Diffusers model repo."
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
else:
|
|
||||||
config = fetch_diffusers_config(checkpoint)
|
|
||||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
|
||||||
|
|
||||||
if "default_subfolder" in mapping_functions:
|
|
||||||
subfolder = mapping_functions["default_subfolder"]
|
|
||||||
|
|
||||||
subfolder = subfolder or config.pop(
|
|
||||||
"subfolder", None
|
|
||||||
) # some configs contain a subfolder key, e.g. StableCascadeUNet
|
|
||||||
|
|
||||||
diffusers_model_config = cls.load_config(
|
|
||||||
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
|
||||||
subfolder=subfolder,
|
|
||||||
local_files_only=local_files_only,
|
|
||||||
token=token,
|
|
||||||
revision=config_revision,
|
|
||||||
)
|
|
||||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
|
||||||
|
|
||||||
# Map legacy kwargs to new kwargs
|
|
||||||
if "legacy_kwargs" in mapping_functions:
|
|
||||||
legacy_kwargs = mapping_functions["legacy_kwargs"]
|
|
||||||
for legacy_key, new_key in legacy_kwargs.items():
|
|
||||||
if legacy_key in kwargs:
|
|
||||||
kwargs[new_key] = kwargs.pop(legacy_key)
|
|
||||||
|
|
||||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
|
||||||
diffusers_model_config.update(model_kwargs)
|
|
||||||
|
|
||||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
|
||||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
|
||||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
|
||||||
)
|
|
||||||
if not diffusers_format_checkpoint:
|
|
||||||
raise SingleFileComponentError(
|
|
||||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
|
||||||
with ctx():
|
|
||||||
model = cls.from_config(diffusers_model_config)
|
|
||||||
|
|
||||||
# Check if `_keep_in_fp32_modules` is not None
|
|
||||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
|
||||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
|
||||||
)
|
|
||||||
if use_keep_in_fp32_modules:
|
|
||||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
|
||||||
if not isinstance(keep_in_fp32_modules, list):
|
|
||||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
|
||||||
|
|
||||||
else:
|
|
||||||
keep_in_fp32_modules = []
|
|
||||||
|
|
||||||
if hf_quantizer is not None:
|
|
||||||
hf_quantizer.preprocess_model(
|
|
||||||
model=model,
|
|
||||||
device_map=None,
|
|
||||||
state_dict=diffusers_format_checkpoint,
|
|
||||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
|
||||||
)
|
|
||||||
|
|
||||||
device_map = None
|
|
||||||
if is_accelerate_available():
|
|
||||||
param_device = torch.device(device) if device else torch.device("cpu")
|
|
||||||
empty_state_dict = model.state_dict()
|
|
||||||
unexpected_keys = [
|
|
||||||
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
|
|
||||||
]
|
|
||||||
device_map = {"": param_device}
|
|
||||||
load_model_dict_into_meta(
|
|
||||||
model,
|
|
||||||
diffusers_format_checkpoint,
|
|
||||||
dtype=torch_dtype,
|
|
||||||
device_map=device_map,
|
|
||||||
hf_quantizer=hf_quantizer,
|
|
||||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
|
||||||
unexpected_keys=unexpected_keys,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
|
||||||
|
|
||||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
|
||||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
|
||||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
|
||||||
logger.warning(
|
|
||||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if hf_quantizer is not None:
|
|
||||||
hf_quantizer.postprocess_model(model)
|
|
||||||
model.hf_quantizer = hf_quantizer
|
|
||||||
|
|
||||||
if torch_dtype is not None and hf_quantizer is None:
|
|
||||||
model.to(torch_dtype)
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
|
|
||||||
if device_map is not None:
|
|
||||||
device_map_kwargs = {"device_map": device_map}
|
|
||||||
dispatch_model(model, **device_map_kwargs)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|||||||
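The hunk above reduces `single_file_model.py` to a thin deprecation shim; per its message, the full implementation now lives in `diffusers.loaders.single_file.single_file_model`, and the old import keeps working with a warning until 0.36. A hedged sketch of the migration, reusing the checkpoint from the `from_single_file` docstring shown in this diff:

```py
# Sketch of the migration implied by the shim above (assumed behavior: the old
# import location still works but emits a deprecation warning until 0.36).
from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin  # new path

from diffusers import StableCascadeUNet

ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
model = StableCascadeUNet.from_single_file(ckpt_path)  # public API unchanged by the refactor
```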
File diff suppressed because it is too large
@@ -11,170 +11,13 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from contextlib import nullcontext
|
|
||||||
|
|
||||||
from ..models.embeddings import (
|
from ..utils import deprecate
|
||||||
ImageProjection,
|
from .ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin
|
||||||
MultiIPAdapterImageProjection,
|
|
||||||
)
|
|
||||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
|
||||||
from ..utils import (
|
|
||||||
is_accelerate_available,
|
|
||||||
is_torch_version,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if is_accelerate_available():
|
class FluxTransformer2DLoadersMixin(FluxTransformer2DLoadersMixin):
|
||||||
pass
|
def __init__(self, *args, **kwargs):
|
||||||
|
deprecation_message = "Importing `FluxTransformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin` instead."
|
||||||
logger = logging.get_logger(__name__)
|
deprecate("diffusers.loaders.ip_adapter.FluxTransformer2DLoadersMixin", "0.36", deprecation_message)
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
class FluxTransformer2DLoadersMixin:
|
|
||||||
"""
|
|
||||||
Load layers into a [`FluxTransformer2DModel`].
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
|
||||||
if low_cpu_mem_usage:
|
|
||||||
if is_accelerate_available():
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
|
|
||||||
else:
|
|
||||||
low_cpu_mem_usage = False
|
|
||||||
logger.warning(
|
|
||||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
|
||||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
|
||||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
|
||||||
" install accelerate\n```\n."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `low_cpu_mem_usage=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
updated_state_dict = {}
|
|
||||||
image_projection = None
|
|
||||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
|
||||||
|
|
||||||
if "proj.weight" in state_dict:
|
|
||||||
# IP-Adapter
|
|
||||||
num_image_text_embeds = 4
|
|
||||||
if state_dict["proj.weight"].shape[0] == 65536:
|
|
||||||
num_image_text_embeds = 16
|
|
||||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
|
||||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
|
||||||
|
|
||||||
with init_context():
|
|
||||||
image_projection = ImageProjection(
|
|
||||||
cross_attention_dim=cross_attention_dim,
|
|
||||||
image_embed_dim=clip_embeddings_dim,
|
|
||||||
num_image_text_embeds=num_image_text_embeds,
|
|
||||||
)
|
|
||||||
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
diffusers_name = key.replace("proj", "image_embeds")
|
|
||||||
updated_state_dict[diffusers_name] = value
|
|
||||||
|
|
||||||
if not low_cpu_mem_usage:
|
|
||||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
|
||||||
else:
|
|
||||||
device_map = {"": self.device}
|
|
||||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
|
||||||
|
|
||||||
return image_projection
|
|
||||||
|
|
||||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
|
||||||
from ..models.attention_processor import (
|
|
||||||
FluxIPAdapterJointAttnProcessor2_0,
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
|
||||||
if is_accelerate_available():
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
|
|
||||||
else:
|
|
||||||
low_cpu_mem_usage = False
|
|
||||||
logger.warning(
|
|
||||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
|
||||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
|
||||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
|
||||||
" install accelerate\n```\n."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `low_cpu_mem_usage=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# set ip-adapter cross-attention processors & load state_dict
|
|
||||||
attn_procs = {}
|
|
||||||
key_id = 0
|
|
||||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
|
||||||
for name in self.attn_processors.keys():
|
|
||||||
if name.startswith("single_transformer_blocks"):
|
|
||||||
attn_processor_class = self.attn_processors[name].__class__
|
|
||||||
attn_procs[name] = attn_processor_class()
|
|
||||||
else:
|
|
||||||
cross_attention_dim = self.config.joint_attention_dim
|
|
||||||
hidden_size = self.inner_dim
|
|
||||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
|
||||||
num_image_text_embeds = []
|
|
||||||
for state_dict in state_dicts:
|
|
||||||
if "proj.weight" in state_dict["image_proj"]:
|
|
||||||
num_image_text_embed = 4
|
|
||||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
|
||||||
num_image_text_embed = 16
|
|
||||||
# IP-Adapter
|
|
||||||
num_image_text_embeds += [num_image_text_embed]
|
|
||||||
|
|
||||||
with init_context():
|
|
||||||
attn_procs[name] = attn_processor_class(
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
cross_attention_dim=cross_attention_dim,
|
|
||||||
scale=1.0,
|
|
||||||
num_tokens=num_image_text_embeds,
|
|
||||||
dtype=self.dtype,
|
|
||||||
device=self.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
value_dict = {}
|
|
||||||
for i, state_dict in enumerate(state_dicts):
|
|
||||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
|
||||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
|
||||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
|
||||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
|
||||||
|
|
||||||
if not low_cpu_mem_usage:
|
|
||||||
attn_procs[name].load_state_dict(value_dict)
|
|
||||||
else:
|
|
||||||
device_map = {"": self.device}
|
|
||||||
dtype = self.dtype
|
|
||||||
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
|
|
||||||
|
|
||||||
key_id += 1
|
|
||||||
|
|
||||||
return attn_procs
|
|
||||||
|
|
||||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
|
||||||
if not isinstance(state_dicts, list):
|
|
||||||
state_dicts = [state_dicts]
|
|
||||||
|
|
||||||
self.encoder_hid_proj = None
|
|
||||||
|
|
||||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
|
||||||
self.set_attn_processor(attn_procs)
|
|
||||||
|
|
||||||
image_projection_layers = []
|
|
||||||
for state_dict in state_dicts:
|
|
||||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
|
||||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
|
||||||
)
|
|
||||||
image_projection_layers.append(image_projection_layer)
|
|
||||||
|
|
||||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
|
||||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
|
||||||
|
|||||||
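The hunk above turns the Flux IP-Adapter loader module into the same kind of shim: the old class subclasses the relocated one and warns on construction. Its deprecation message gives the new canonical path (the SD3 variant in the next hunk follows the same pattern):

```py
# Sketch only: canonical import after the folderize refactor; the old location
# keeps working but emits a deprecation warning scheduled for removal in 0.36.
from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin
```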
@@ -11,160 +11,12 @@
|
|||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from contextlib import nullcontext
|
from ..utils import deprecate
|
||||||
from typing import Dict
|
from .ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin
|
||||||
|
|
||||||
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
|
||||||
from ..models.embeddings import IPAdapterTimeImageProjection
|
|
||||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
|
||||||
from ..utils import is_accelerate_available, is_torch_version, logging
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
class SD3Transformer2DLoadersMixin(SD3Transformer2DLoadersMixin):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
deprecation_message = "Importing `SD3Transformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin` instead."
|
||||||
class SD3Transformer2DLoadersMixin:
|
deprecate("diffusers.loaders.ip_adapter.SD3Transformer2DLoadersMixin", "0.36", deprecation_message)
|
||||||
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
def _convert_ip_adapter_attn_to_diffusers(
|
|
||||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
|
||||||
) -> Dict:
|
|
||||||
if low_cpu_mem_usage:
|
|
||||||
if is_accelerate_available():
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
|
|
||||||
else:
|
|
||||||
low_cpu_mem_usage = False
|
|
||||||
logger.warning(
|
|
||||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
|
||||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
|
||||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
|
||||||
" install accelerate\n```\n."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `low_cpu_mem_usage=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
# IP-Adapter cross attention parameters
|
|
||||||
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
|
|
||||||
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
|
|
||||||
timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
|
|
||||||
|
|
||||||
# Dict where key is transformer layer index, value is attention processor's state dict
|
|
||||||
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
|
|
||||||
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
|
|
||||||
for key, weights in state_dict.items():
|
|
||||||
idx, name = key.split(".", maxsplit=1)
|
|
||||||
layer_state_dict[int(idx)][name] = weights
|
|
||||||
|
|
||||||
# Create IP-Adapter attention processor & load state_dict
|
|
||||||
attn_procs = {}
|
|
||||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
|
||||||
for idx, name in enumerate(self.attn_processors.keys()):
|
|
||||||
with init_context():
|
|
||||||
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
|
|
||||||
hidden_size=hidden_size,
|
|
||||||
ip_hidden_states_dim=ip_hidden_states_dim,
|
|
||||||
head_dim=self.config.attention_head_dim,
|
|
||||||
timesteps_emb_dim=timesteps_emb_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not low_cpu_mem_usage:
|
|
||||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
|
||||||
else:
|
|
||||||
device_map = {"": self.device}
|
|
||||||
load_model_dict_into_meta(
|
|
||||||
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
return attn_procs
|
|
||||||
|
|
||||||
def _convert_ip_adapter_image_proj_to_diffusers(
|
|
||||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
|
||||||
) -> IPAdapterTimeImageProjection:
|
|
||||||
if low_cpu_mem_usage:
|
|
||||||
if is_accelerate_available():
|
|
||||||
from accelerate import init_empty_weights
|
|
||||||
|
|
||||||
else:
|
|
||||||
low_cpu_mem_usage = False
|
|
||||||
logger.warning(
|
|
||||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
|
||||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
|
||||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
|
||||||
" install accelerate\n```\n."
|
|
||||||
)
|
|
||||||
|
|
||||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
|
||||||
raise NotImplementedError(
|
|
||||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
|
||||||
" `low_cpu_mem_usage=False`."
|
|
||||||
)
|
|
||||||
|
|
||||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
|
||||||
|
|
||||||
# Convert to diffusers
|
|
||||||
updated_state_dict = {}
|
|
||||||
for key, value in state_dict.items():
|
|
||||||
# InstantX/SD3.5-Large-IP-Adapter
|
|
||||||
if key.startswith("layers."):
|
|
||||||
idx = key.split(".")[1]
|
|
||||||
key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
|
|
||||||
key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
|
|
||||||
key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
|
|
||||||
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
|
|
||||||
key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
|
|
||||||
key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
|
|
||||||
key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
|
|
||||||
key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
|
|
||||||
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
|
|
||||||
updated_state_dict[key] = value
|
|
||||||
|
|
||||||
# Image projetion parameters
|
|
||||||
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
|
|
||||||
output_dim = updated_state_dict["proj_out.weight"].shape[0]
|
|
||||||
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
|
|
||||||
heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
|
|
||||||
num_queries = updated_state_dict["latents"].shape[1]
|
|
||||||
timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
|
|
||||||
|
|
||||||
# Image projection
|
|
||||||
with init_context():
|
|
||||||
image_proj = IPAdapterTimeImageProjection(
|
|
||||||
embed_dim=embed_dim,
|
|
||||||
output_dim=output_dim,
|
|
||||||
hidden_dim=hidden_dim,
|
|
||||||
heads=heads,
|
|
||||||
num_queries=num_queries,
|
|
||||||
timestep_in_dim=timestep_in_dim,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not low_cpu_mem_usage:
|
|
||||||
image_proj.load_state_dict(updated_state_dict, strict=True)
|
|
||||||
else:
|
|
||||||
device_map = {"": self.device}
|
|
||||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
|
||||||
|
|
||||||
return image_proj
|
|
||||||
|
|
||||||
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
|
|
||||||
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
state_dict (`Dict`):
|
|
||||||
State dict with keys "ip_adapter", which contains parameters for attention processors, and
|
|
||||||
"image_proj", which contains parameters for image projection net.
|
|
||||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
|
||||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
|
||||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
|
||||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
|
||||||
argument to `True` will raise an error.
|
|
||||||
"""
|
|
||||||
|
|
||||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
|
|
||||||
self.set_attn_processor(attn_procs)
|
|
||||||
|
|
||||||
self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
|
|
||||||
|
|||||||
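Same pattern for the SD3 mixin. Its `_load_ip_adapter_weights` docstring (kept in the relocated module) describes the state dict it expects, summarized here as a hedged sketch:

```py
# Sketch only: new import path plus the state_dict layout described by the
# _load_ip_adapter_weights docstring above. The two top-level keys are as
# documented; the nested contents depend on the concrete IP-Adapter checkpoint.
from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin

# state_dict = {
#     "image_proj": {...},  # image projection net parameters
#     "ip_adapter": {...},  # attention processor parameters, keys like "0.norm_ip.linear.weight"
# }
# transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=True)
```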
src/diffusers/loaders/unet/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
|
|||||||
|
from ...utils import is_torch_available
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from .unet import UNet2DConditionLoadersMixin
|
||||||
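The new `loaders/unet/__init__.py` shown above gates the re-export on torch availability; downstream code can then import from the package path. A sketch, assuming the top-level `diffusers.loaders` re-exports are kept in sync elsewhere in this PR:

```py
# Sketch only: package-level import introduced by the __init__ above; the mixin
# is only exposed when torch is installed, mirroring the is_torch_available() guard.
from diffusers.loaders.unet import UNet2DConditionLoadersMixin
```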
@@ -22,7 +22,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from huggingface_hub.utils import validate_hf_hub_args
|
from huggingface_hub.utils import validate_hf_hub_args
|
||||||
|
|
||||||
from ..models.embeddings import (
|
from ...models.embeddings import (
|
||||||
ImageProjection,
|
ImageProjection,
|
||||||
IPAdapterFaceIDImageProjection,
|
IPAdapterFaceIDImageProjection,
|
||||||
IPAdapterFaceIDPlusImageProjection,
|
IPAdapterFaceIDPlusImageProjection,
|
||||||
@@ -30,8 +30,8 @@ from ..models.embeddings import (
|
|||||||
IPAdapterPlusImageProjection,
|
IPAdapterPlusImageProjection,
|
||||||
MultiIPAdapterImageProjection,
|
MultiIPAdapterImageProjection,
|
||||||
)
|
)
|
||||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
|
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
|
||||||
from ..utils import (
|
from ...utils import (
|
||||||
USE_PEFT_BACKEND,
|
USE_PEFT_BACKEND,
|
||||||
_get_model_file,
|
_get_model_file,
|
||||||
convert_unet_state_dict_to_peft,
|
convert_unet_state_dict_to_peft,
|
||||||
@@ -43,9 +43,9 @@ from ..utils import (
|
|||||||
is_torch_version,
|
is_torch_version,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from .lora_base import _func_optionally_disable_offloading
|
from ..lora.lora_base import _func_optionally_disable_offloading
|
||||||
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
|
||||||
from .utils import AttnProcsLayers
|
from ..utils import AttnProcsLayers
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -247,7 +247,7 @@ class UNet2DConditionLoadersMixin:
|
|||||||
# Unsafe code />
|
# Unsafe code />
|
||||||
|
|
||||||
def _process_custom_diffusion(self, state_dict):
|
def _process_custom_diffusion(self, state_dict):
|
||||||
from ..models.attention_processor import CustomDiffusionAttnProcessor
|
from ...models.attention_processor import CustomDiffusionAttnProcessor
|
||||||
|
|
||||||
attn_processors = {}
|
attn_processors = {}
|
||||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||||
@@ -395,7 +395,7 @@ class UNet2DConditionLoadersMixin:
|
|||||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
# Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||||
def _optionally_disable_offloading(cls, _pipeline):
|
def _optionally_disable_offloading(cls, _pipeline):
|
||||||
"""
|
"""
|
||||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||||
@@ -451,7 +451,7 @@ class UNet2DConditionLoadersMixin:
|
|||||||
pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
|
pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
from ..models.attention_processor import (
|
from ...models.attention_processor import (
|
||||||
CustomDiffusionAttnProcessor,
|
CustomDiffusionAttnProcessor,
|
||||||
CustomDiffusionAttnProcessor2_0,
|
CustomDiffusionAttnProcessor2_0,
|
||||||
CustomDiffusionXFormersAttnProcessor,
|
CustomDiffusionXFormersAttnProcessor,
|
||||||
@@ -513,7 +513,7 @@ class UNet2DConditionLoadersMixin:
|
|||||||
logger.info(f"Model weights saved in {save_path}")
|
logger.info(f"Model weights saved in {save_path}")
|
||||||
|
|
||||||
def _get_custom_diffusion_state_dict(self):
|
def _get_custom_diffusion_state_dict(self):
|
||||||
from ..models.attention_processor import (
|
from ...models.attention_processor import (
|
||||||
CustomDiffusionAttnProcessor,
|
CustomDiffusionAttnProcessor,
|
||||||
CustomDiffusionAttnProcessor2_0,
|
CustomDiffusionAttnProcessor2_0,
|
||||||
CustomDiffusionXFormersAttnProcessor,
|
CustomDiffusionXFormersAttnProcessor,
|
||||||
@@ -759,7 +759,7 @@ class UNet2DConditionLoadersMixin:
|
|||||||
return image_projection
|
return image_projection
|
||||||
|
|
||||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||||
from ..models.attention_processor import (
|
from ...models.attention_processor import (
|
||||||
IPAdapterAttnProcessor,
|
IPAdapterAttnProcessor,
|
||||||
IPAdapterAttnProcessor2_0,
|
IPAdapterAttnProcessor2_0,
|
||||||
IPAdapterXFormersAttnProcessor,
|
IPAdapterXFormersAttnProcessor,
|
||||||
@@ -14,12 +14,12 @@
|
|||||||
import copy
|
import copy
|
||||||
from typing import TYPE_CHECKING, Dict, List, Union
|
from typing import TYPE_CHECKING, Dict, List, Union
|
||||||
|
|
||||||
from ..utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
# import here to avoid circular imports
|
# import here to avoid circular imports
|
||||||
from ..models import UNet2DConditionModel
|
from ...models import UNet2DConditionModel
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
@@ -17,8 +17,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders import PeftAdapterMixin
|
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
|
||||||
from ...utils import deprecate
|
from ...utils import deprecate
|
||||||
from ...utils.accelerate_utils import apply_forward_hook
|
from ...utils.accelerate_utils import apply_forward_hook
|
||||||
from ..attention_processor import (
|
from ..attention_processor import (
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
from ...loaders import FromOriginalModelMixin
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ...utils.accelerate_utils import apply_forward_hook
|
from ...utils.accelerate_utils import apply_forward_hook
|
||||||
from ..activations import get_activation
|
from ..activations import get_activation
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from torch import nn
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
from ...loaders import FromOriginalModelMixin
|
||||||
from ...utils import BaseOutput, logging
|
from ...utils import BaseOutput, logging
|
||||||
from ..attention_processor import (
|
from ..attention_processor import (
|
||||||
ADDED_KV_ATTENTION_PROCESSORS,
|
ADDED_KV_ATTENTION_PROCESSORS,
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
from ...loaders import FromOriginalModelMixin
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..attention_processor import (
|
from ..attention_processor import (
|
||||||
ADDED_KV_ATTENTION_PROCESSORS,
|
ADDED_KV_ATTENTION_PROCESSORS,
|
||||||
|
|||||||
@@ -20,8 +20,7 @@ import torch.nn as nn
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders import PeftAdapterMixin
|
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
|
||||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||||
from ..attention import LuminaFeedForward
|
from ..attention import LuminaFeedForward
|
||||||
from ..attention_processor import Attention
|
from ..attention_processor import Attention
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders import PeftAdapterMixin
|
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
|
||||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||||
from ...utils.torch_utils import maybe_allow_in_graph
|
from ...utils.torch_utils import maybe_allow_in_graph
|
||||||
from ..attention import FeedForward
|
from ..attention import FeedForward
|
||||||
|
|||||||
@@ -19,8 +19,7 @@ import torch.nn as nn
|
|||||||
import torch.utils.checkpoint
|
import torch.utils.checkpoint
|
||||||
|
|
||||||
from ...configuration_utils import ConfigMixin, register_to_config
|
from ...configuration_utils import ConfigMixin, register_to_config
|
||||||
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
|
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
|
||||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
|
||||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||||
from ..activations import get_activation
|
from ..activations import get_activation
|
||||||
from ..attention_processor import (
|
from ..attention_processor import (
|
||||||
|
|||||||
@@ -352,7 +352,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
|
|
||||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||||
@@ -403,7 +403,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||||
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# Change the transformer config to mimic a real use case.
|
# Change the transformer config to mimic a real use case.
|
||||||
@@ -486,7 +486,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||||
@@ -541,7 +541,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||||
@@ -590,7 +590,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
pipe = pipe.to(torch_device)
|
pipe = pipe.to(torch_device)
|
||||||
pipe.set_progress_bar_config(disable=None)
|
pipe.set_progress_bar_config(disable=None)
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
out_features, in_features = pipe.transformer.x_embedder.weight.shape
|
||||||
@@ -653,7 +653,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
"transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
|
||||||
}
|
}
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
with CaptureLogger(logger) as cap_logger:
|
with CaptureLogger(logger) as cap_logger:
|
||||||
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
pipe.load_lora_weights(lora_state_dict, "adapter-1")
|
||||||
@@ -668,7 +668,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
def test_lora_unload_with_parameter_expanded_shapes(self):
|
def test_lora_unload_with_parameter_expanded_shapes(self):
|
||||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# Change the transformer config to mimic a real use case.
|
# Change the transformer config to mimic a real use case.
|
||||||
@@ -734,7 +734,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
|||||||
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
|
def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
|
||||||
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.DEBUG)
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
# Change the transformer config to mimic a real use case.
|
# Change the transformer config to mimic a real use case.
|
||||||
|
|||||||
@@ -1017,7 +1017,7 @@ class PeftLoraLoaderMixinTests:
|
|||||||
)
|
)
|
||||||
|
|
||||||
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
|
scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_base")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_base")
|
||||||
logger.setLevel(30)
|
logger.setLevel(30)
|
||||||
with CaptureLogger(logger) as cap_logger:
|
with CaptureLogger(logger) as cap_logger:
|
||||||
pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
|
pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
|
||||||
@@ -1824,7 +1824,7 @@ class PeftLoraLoaderMixinTests:
|
|||||||
elif lora_module == "text_encoder_2":
|
elif lora_module == "text_encoder_2":
|
||||||
prefix = "text_encoder_2"
|
prefix = "text_encoder_2"
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_base")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_base")
|
||||||
logger.setLevel(logging.WARNING)
|
logger.setLevel(logging.WARNING)
|
||||||
|
|
||||||
with CaptureLogger(logger) as cap_logger:
|
with CaptureLogger(logger) as cap_logger:
|
||||||
@@ -1925,7 +1925,7 @@ class PeftLoraLoaderMixinTests:
|
|||||||
|
|
||||||
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
_, _, inputs = self.get_dummy_inputs(with_generator=False)
|
||||||
|
|
||||||
logger = logging.get_logger("diffusers.loaders.lora_pipeline")
|
logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
|
||||||
logger.setLevel(logging.INFO)
|
logger.setLevel(logging.INFO)
|
||||||
|
|
||||||
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
|
||||||
|
|||||||
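The LoRA test updates above only touch logger names: diffusers loggers are keyed by module path, so the folderize refactor adds a `.lora` segment to the dotted name. A self-contained sketch of the capture pattern those tests rely on (the logged message is a stand-in for the `pipe.load_lora_weights(...)` call in the real tests):

```py
# Sketch of the logger-name convention used by the updated tests; the message
# below is a stand-in for output emitted by the lora_pipeline module.
from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
logger.setLevel(logging.INFO)

with CaptureLogger(logger) as cap_logger:
    logger.info("loaded adapter-1 via lora_pipeline")

assert "lora_pipeline" in cap_logger.out
```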
@@ -5,7 +5,7 @@ import requests
|
|||||||
import torch
|
import torch
|
||||||
from huggingface_hub import hf_hub_download, snapshot_download
|
from huggingface_hub import hf_hub_download, snapshot_download
|
||||||
|
|
||||||
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
|
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
|
||||||
from diffusers.models.attention_processor import AttnProcessor
|
from diffusers.models.attention_processor import AttnProcessor
|
||||||
from diffusers.utils.testing_utils import (
|
from diffusers.utils.testing_utils import (
|
||||||
numpy_cosine_similarity_distance,
|
numpy_cosine_similarity_distance,
|
||||||
|
|||||||
@@ -18,9 +18,7 @@ import unittest
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import AutoencoderKL
|
||||||
AutoencoderKL,
|
|
||||||
)
|
|
||||||
from diffusers.utils.testing_utils import (
|
from diffusers.utils.testing_utils import (
|
||||||
backend_empty_cache,
|
backend_empty_cache,
|
||||||
enable_full_determinism,
|
enable_full_determinism,
|
||||||
|
|||||||
@@ -16,9 +16,7 @@
|
|||||||
import gc
|
import gc
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from diffusers import (
|
from diffusers import AutoencoderKLWan
|
||||||
AutoencoderKLWan,
|
|
||||||
)
|
|
||||||
from diffusers.utils.testing_utils import (
|
from diffusers.utils.testing_utils import (
|
||||||
backend_empty_cache,
|
backend_empty_cache,
|
||||||
enable_full_determinism,
|
enable_full_determinism,
|
||||||
|
|||||||
@@ -18,9 +18,7 @@ import unittest

 import torch

-from diffusers import (
-    WanTransformer3DModel,
-)
+from diffusers import WanTransformer3DModel
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
@@ -3,9 +3,7 @@ import unittest

 import torch

-from diffusers import (
-    SanaTransformer2DModel,
-)
+from diffusers import SanaTransformer2DModel
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
@@ -5,7 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
@@ -5,7 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
@@ -5,7 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
@@ -5,7 +5,7 @@ import unittest
 import torch

 from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
@@ -8,7 +8,7 @@ from diffusers import (
     StableDiffusionXLAdapterPipeline,
     T2IAdapter,
 )
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
@@ -5,7 +5,7 @@ import unittest
 import torch

 from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
-from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
+from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
 from diffusers.utils import load_image
 from diffusers.utils.testing_utils import (
     backend_empty_cache,
@@ -98,7 +98,7 @@ if __name__ == "__main__":
     },
     "LoRA Mixins": {
         "doc_path": "docs/source/en/api/loaders/lora.md",
-        "src_path": "src/diffusers/loaders/lora_pipeline.py",
+        "src_path": "src/diffusers/loaders/lora/lora_pipeline.py",
         "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
         "src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
     },
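The final hunk repoints `src_path` in the support-list check config so the doc/source comparison keeps working after the move. A rough sketch of the comparison that config entry drives, not the actual `utils/check_support_list.py` implementation, assuming it is run from the repository root on this branch:

```python
import re

# Paths and regexes copied from the config entry above; src_path reflects the new folder layout.
doc_path = "docs/source/en/api/loaders/lora.md"
src_path = "src/diffusers/loaders/lora/lora_pipeline.py"
doc_regex = r"\[\[autodoc\]\]\s([^\n]+)"
src_regex = r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]"

with open(doc_path) as f:
    # "[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin" -> "FluxLoraLoaderMixin"
    documented = {entry.strip().split(".")[-1] for entry in re.findall(doc_regex, f.read())}
with open(src_path) as f:
    defined = re.findall(src_regex, f.read())

undocumented = [name for name in defined if name not in documented]
if undocumented:
    raise ValueError(f"LoRA loader mixins missing from {doc_path}: {undocumented}")
```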