Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 21:14:44 +08:00)

Compare commits: 31 commits, branches `vae-tests-...` and `folderize-...`
Commits in this compare (SHA1):
8d8621ec72, 0c37895440, 9bebdf225d, c05114d5ec, a57a5ab4c0, 4b1c7dc81a, 1590325a60, e4dd7c5333,
d6430c79a3, 1597ae6ac9, 11a23d11fe, 6b8b225aca, 27d2401e59, 1ddfe14220, 0e8d1d25eb, 546446ae21,
ea3f0b8d68, f0ea9ff2e2, 1b7c286974, 6138cc1720, ea0ce4bfab, f2aa2f91dc, 4faac73219, d870e3c9a6,
178b884673, 2da3cb4a8c, ea3ba4f431, 21b2566933, a71334b861, eb47a67d50, 8267677a24
@@ -22,11 +22,11 @@ Learn how to load an IP-Adapter checkpoint and image in the IP-Adapter [loading]
 
 ## IPAdapterMixin
 
-[[autodoc]] loaders.ip_adapter.IPAdapterMixin
+[[autodoc]] loaders.ip_adapter.ip_adapter.IPAdapterMixin
 
 ## SD3IPAdapterMixin
 
-[[autodoc]] loaders.ip_adapter.SD3IPAdapterMixin
+[[autodoc]] loaders.ip_adapter.ip_adapter.SD3IPAdapterMixin
     - all
     - is_ip_adapter_active
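The `IPAdapterMixin` documented above is mixed into the Stable Diffusion pipelines, so the user-facing call is `pipeline.load_ip_adapter(...)` regardless of which internal module the class lives in. A minimal usage sketch; the base model, adapter repo, weight file, and image URL are illustrative placeholders rather than anything pinned by this diff:

```python
import torch

from diffusers import AutoPipelineForText2Image
from diffusers.utils import load_image

# Illustrative checkpoints; any SD 1.5 base model plus a matching IP-Adapter works the same way.
pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.set_ip_adapter_scale(0.6)

reference = load_image("https://example.com/reference.png")  # placeholder URL
image = pipeline(prompt="a cat, best quality", ip_adapter_image=reference).images[0]
```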
@@ -39,58 +39,66 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
 ## StableDiffusionLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.StableDiffusionLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionLoraLoaderMixin
 
 ## StableDiffusionXLLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.StableDiffusionXLLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.StableDiffusionXLLoraLoaderMixin
 
 ## SD3LoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.SD3LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.SD3LoraLoaderMixin
 
 ## FluxLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.FluxLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.FluxLoraLoaderMixin
 
 ## CogVideoXLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.CogVideoXLoraLoaderMixin
 
 ## Mochi1LoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.Mochi1LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.Mochi1LoraLoaderMixin
 
 ## AuraFlowLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.AuraFlowLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.AuraFlowLoraLoaderMixin
 
 ## LTXVideoLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.LTXVideoLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.LTXVideoLoraLoaderMixin
 
 ## SanaLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.SanaLoraLoaderMixin
 
 ## HunyuanVideoLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.HunyuanVideoLoraLoaderMixin
 
 ## Lumina2LoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.Lumina2LoraLoaderMixin
-
-## CogView4LoraLoaderMixin
-
-[[autodoc]] loaders.lora_pipeline.CogView4LoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.Lumina2LoraLoaderMixin
 
 ## WanLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.WanLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.WanLoraLoaderMixin
+
+## CogView4LoraLoaderMixin
+
+[[autodoc]] loaders.lora.lora_pipeline.CogView4LoraLoaderMixin
 
 ## AmusedLoraLoaderMixin
 
-[[autodoc]] loaders.lora_pipeline.AmusedLoraLoaderMixin
+[[autodoc]] loaders.lora.lora_pipeline.AmusedLoraLoaderMixin
 
 ## HiDreamImageLoraLoaderMixin
 
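All of the loader mixins listed above expose the same pipeline-level entry point, `load_lora_weights`, so only the autodoc paths change here. A short sketch against SDXL; the LoRA repo id and weight file name are placeholders:

```python
import torch

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Placeholder repo id / file name for any SDXL-compatible LoRA checkpoint.
pipeline.load_lora_weights("some-user/some-sdxl-lora", weight_name="pytorch_lora_weights.safetensors")

pipeline.fuse_lora(lora_scale=0.8)   # optionally bake the LoRA into the base weights
image = pipeline("a corgi astronaut, highly detailed").images[0]
pipeline.unfuse_lora()               # restore the original base weights
```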
@@ -98,4 +106,4 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
 ## LoraBaseMixin
 
-[[autodoc]] loaders.lora_base.LoraBaseMixin
+[[autodoc]] loaders.lora.lora_base.LoraBaseMixin
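`LoraBaseMixin` carries the adapter-management helpers shared by the pipeline-specific mixins above. A sketch of the multi-adapter workflow, assuming an SDXL pipeline and placeholder LoRA repos:

```python
import torch

from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Adapter names are arbitrary labels chosen at load time; repo ids are placeholders.
pipeline.load_lora_weights("some-user/style-lora", adapter_name="style")
pipeline.load_lora_weights("some-user/subject-lora", adapter_name="subject")

pipeline.set_adapters(["style", "subject"], adapter_weights=[0.8, 0.5])
print(pipeline.get_active_adapters())  # e.g. ["style", "subject"]

pipeline.disable_lora()          # switch all LoRA layers off without removing them
pipeline.enable_lora()           # switch them back on
pipeline.unload_lora_weights()   # remove the LoRA layers entirely
```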
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
 
 # SD3Transformer2D
 
-This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and SD3Transformer2DModel, check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
+This class is useful when *only* loading weights into a [`SD3Transformer2DModel`]. If you need to load weights into the text encoder or a text encoder and [`SD3Transformer2DModel`], check [`SD3LoraLoaderMixin`](lora#diffusers.loaders.SD3LoraLoaderMixin) class instead.
 
 The [`SD3Transformer2DLoadersMixin`] class currently only loads IP-Adapter weights, but will be used in the future to save weights and load LoRAs.
 
@@ -24,6 +24,6 @@ To learn more about how to load LoRA weights, see the [LoRA](../../using-diffuse
 
 ## SD3Transformer2DLoadersMixin
 
-[[autodoc]] loaders.transformer_sd3.SD3Transformer2DLoadersMixin
+[[autodoc]] loaders.ip_adapter.transformer_sd3.SD3Transformer2DLoadersMixin
     - all
     - _load_ip_adapter_weights
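`SD3Transformer2DLoadersMixin._load_ip_adapter_weights` is the transformer-side half of the SD3 IP-Adapter flow; the pipeline-side half is `SD3IPAdapterMixin.load_ip_adapter`, which calls into it. A sketch of the end-to-end call, assuming an SD3.5 base model and an adapter repo laid out the way `load_ip_adapter`'s defaults expect (illustrative ids):

```python
import torch

from diffusers import StableDiffusion3Pipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.bfloat16
).to("cuda")

# Illustrative adapter repo; the defaults look for `ip-adapter.safetensors` plus an
# `image_encoder` folder, per the signature documented above.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipe.set_ip_adapter_scale(0.6)

reference = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(prompt="a winter landscape, soft light", ip_adapter_image=reference).images[0]
```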
@@ -54,14 +54,14 @@ if is_transformers_available():
 _import_structure = {}
 
 if is_torch_available():
-    _import_structure["single_file_model"] = ["FromOriginalModelMixin"]
-    _import_structure["transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
-    _import_structure["transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
-    _import_structure["unet"] = ["UNet2DConditionLoadersMixin"]
+    _import_structure["ip_adapter.transformer_flux"] = ["FluxTransformer2DLoadersMixin"]
+    _import_structure["ip_adapter.transformer_sd3"] = ["SD3Transformer2DLoadersMixin"]
+    _import_structure["single_file.single_file_model"] = ["FromOriginalModelMixin"]
+    _import_structure["unet.unet"] = ["UNet2DConditionLoadersMixin"]
     _import_structure["utils"] = ["AttnProcsLayers"]
     if is_transformers_available():
-        _import_structure["single_file"] = ["FromSingleFileMixin"]
-        _import_structure["lora_pipeline"] = [
+        _import_structure["single_file.single_file"] = ["FromSingleFileMixin"]
+        _import_structure["lora.lora_pipeline"] = [
             "AmusedLoraLoaderMixin",
             "StableDiffusionLoraLoaderMixin",
             "SD3LoraLoaderMixin",
@@ -80,7 +80,7 @@ if is_torch_available():
             "HiDreamImageLoraLoaderMixin",
         ]
         _import_structure["textual_inversion"] = ["TextualInversionLoaderMixin"]
-        _import_structure["ip_adapter"] = [
+        _import_structure["ip_adapter.ip_adapter"] = [
             "IPAdapterMixin",
             "FluxIPAdapterMixin",
             "SD3IPAdapterMixin",
@@ -91,19 +91,14 @@ _import_structure["peft"] = ["PeftAdapterMixin"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     if is_torch_available():
-        from .single_file_model import FromOriginalModelMixin
-        from .transformer_flux import FluxTransformer2DLoadersMixin
-        from .transformer_sd3 import SD3Transformer2DLoadersMixin
+        from .ip_adapter import FluxTransformer2DLoadersMixin, SD3Transformer2DLoadersMixin
+        from .single_file import FromOriginalModelMixin
         from .unet import UNet2DConditionLoadersMixin
         from .utils import AttnProcsLayers
 
         if is_transformers_available():
-            from .ip_adapter import (
-                FluxIPAdapterMixin,
-                IPAdapterMixin,
-                SD3IPAdapterMixin,
-            )
-            from .lora_pipeline import (
+            from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
+            from .lora import (
                 AmusedLoraLoaderMixin,
                 AuraFlowLoraLoaderMixin,
                 CogVideoXLoraLoaderMixin,
@@ -111,6 +106,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
                 FluxLoraLoaderMixin,
                 HiDreamImageLoraLoaderMixin,
                 HunyuanVideoLoraLoaderMixin,
                 LoraBaseMixin,
+                LoraLoaderMixin,
                 LTXVideoLoraLoaderMixin,
                 Lumina2LoraLoaderMixin,
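The keys in `_import_structure` are dotted module paths that are resolved lazily, so renaming them to the new subpackages does not change what `diffusers.loaders` itself exports. A sketch of the equivalence, where the fully-qualified paths are the new locations introduced by this diff:

```python
# Public, stable import path: unchanged by the folderize refactor.
from diffusers.loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin

# New module locations (per this diff); these should resolve to the very same classes.
from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin as _IPAdapterMixin
from diffusers.loaders.lora.lora_pipeline import StableDiffusionLoraLoaderMixin as _SDLoraMixin

assert IPAdapterMixin is _IPAdapterMixin
assert StableDiffusionLoraLoaderMixin is _SDLoraMixin
```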
@@ -12,868 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from pathlib import Path
-from typing import Dict, List, Optional, Union
-
-import torch
-import torch.nn.functional as F
-from huggingface_hub.utils import validate_hf_hub_args
-from safetensors import safe_open
+from ..utils import deprecate
+from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
 
-from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
-from ..utils import (
-    USE_PEFT_BACKEND,
-    _get_detailed_type,
-    _get_model_file,
-    _is_valid_type,
-    is_accelerate_available,
-    is_torch_version,
-    is_transformers_available,
-    logging,
-)
-from .unet_loader_utils import _maybe_expand_lora_scales
 
+class IPAdapterMixin(IPAdapterMixin):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import IPAdapterMixin` instead."
+        deprecate("diffusers.loaders.ip_adapter.IPAdapterMixin", "0.36", deprecation_message)
+        super().__init__(*args, **kwargs)
 
-if is_transformers_available():
-    from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
 
-    from ..models.attention_processor import (
-        AttnProcessor,
-        AttnProcessor2_0,
-        FluxAttnProcessor2_0,
-        FluxIPAdapterJointAttnProcessor2_0,
-        IPAdapterAttnProcessor,
-        IPAdapterAttnProcessor2_0,
-        IPAdapterXFormersAttnProcessor,
-        JointAttnProcessor2_0,
-        SD3IPAdapterJointAttnProcessor2_0,
-    )
+class FluxIPAdapterMixin(FluxIPAdapterMixin):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `FluxIPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import FluxIPAdapterMixin` instead."
+        deprecate("diffusers.loaders.ip_adapter.FluxIPAdapterMixin", "0.36", deprecation_message)
+        super().__init__(*args, **kwargs)
 
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class IPAdapterMixin:
|
||||
"""Mixin for handling IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=self.dtype,
|
||||
).to(self.device)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into unet
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
extra_loras = unet._load_ip_adapter_loras(state_dicts)
|
||||
if extra_loras != {}:
|
||||
if not USE_PEFT_BACKEND:
|
||||
logger.warning("PEFT backend is required to load these weights.")
|
||||
else:
|
||||
# apply the IP Adapter Face ID LoRA weights
|
||||
peft_config = getattr(unet, "peft_config", {})
|
||||
for k, lora in extra_loras.items():
|
||||
if f"faceid_{k}" not in peft_config:
|
||||
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
|
||||
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style block only
|
||||
scale = {
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style+layout blocks
|
||||
scale = {
|
||||
"down": {"block_2": [0.0, 1.0]},
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style and layout from 2 reference images
|
||||
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
|
||||
pipeline.set_ip_adapter_scale(scales)
|
||||
```
|
||||
"""
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale]
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
if isinstance(scale_config, dict):
|
||||
for k, s in scale_config.items():
|
||||
if attn_name.startswith(k):
|
||||
attn_processor.scale[i] = s
|
||||
else:
|
||||
attn_processor.scale[i] = scale_config
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.unet.encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = None
|
||||
|
||||
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
|
||||
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
|
||||
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
|
||||
self.unet.text_encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = "text_proj"
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.unet.attn_processors.items():
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
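The list handling at the top of `load_ip_adapter` above is what makes multi-adapter loading work: `pretrained_model_name_or_path_or_dict`, `subfolder`, and `weight_name` are broadcast to the same length and then zipped. A sketch with two SD 1.5 adapters; repo and file names are illustrative:

```python
import torch

from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Two adapters from the same repo; the single repo id is repeated to match `weight_name`.
pipeline.load_ip_adapter(
    "h94/IP-Adapter",
    subfolder=["models", "models"],
    weight_name=["ip-adapter_sd15.safetensors", "ip-adapter-plus_sd15.safetensors"],
)

# One scale per adapter; a dict per adapter could instead target specific down/up blocks,
# as in the `set_ip_adapter_scale` docstring above.
pipeline.set_ip_adapter_scale([0.7, 0.3])
```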
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`weight_name`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
|
||||
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
|
||||
for key in f.keys():
|
||||
if any(key.startswith(prefix) for prefix in image_proj_keys):
|
||||
diffusers_name = ".".join(key.split(".")[1:])
|
||||
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
|
||||
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
|
||||
diffusers_name = (
|
||||
".".join(key.split(".")[1:])
|
||||
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
|
||||
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
|
||||
.replace("processor.", "")
|
||||
)
|
||||
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_pretrained_model_name_or_path is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
|
||||
image_encoder = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_pretrained_model_name_or_path,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=image_encoder_dtype,
|
||||
)
|
||||
.to(self.device)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder so it's not present, in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
`float` is converted to list and repeated for the number of blocks and the number of IP adapters. `List[float]`
|
||||
length match the number of blocks, it is repeated for each IP adapter. `List[List[float]]` must match the
|
||||
number of IP adapters and each must match the number of blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
|
||||
def LinearStrengthModel(start, finish, size):
|
||||
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
|
||||
|
||||
|
||||
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
|
||||
scale_type = Union[int, float]
|
||||
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
num_layers = self.transformer.config.num_layers
|
||||
|
||||
# Single value for all layers of all IP-Adapters
|
||||
if isinstance(scale, scale_type):
|
||||
scale = [scale for _ in range(num_ip_adapters)]
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||
scale = [scale]
|
||||
# Invalid scale type
|
||||
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||
|
||||
if len(scale) != num_ip_adapters:
|
||||
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
|
||||
|
||||
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
|
||||
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
|
||||
raise ValueError(
|
||||
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
|
||||
)
|
||||
|
||||
# Scalars are transformed to lists with length num_layers
|
||||
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
|
||||
|
||||
# Set scales. zip over scale_configs prevents going into single transformer layers
|
||||
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.transformer.encoder_hid_proj = None
|
||||
self.transformer.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Transformer attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.transformer.attn_processors.items():
|
||||
attn_processor_class = FluxAttnProcessor2_0()
|
||||
attn_procs[name] = (
|
||||
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
|
||||
)
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
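`FluxIPAdapterMixin.load_ip_adapter` differs from the UNet variant above mainly in that the image encoder is given as a standalone repo id rather than a subfolder of the adapter repo. A usage sketch with illustrative checkpoint names:

```python
import torch

from diffusers import FluxPipeline
from diffusers.utils import load_image

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Illustrative repos/files; note the separate image-encoder argument from the signature above.
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipe.set_ip_adapter_scale(0.7)

reference = load_image("https://example.com/reference.png")  # placeholder URL
image = pipe(prompt="a product photo in the reference style", ip_adapter_image=reference).images[0]
```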
class SD3IPAdapterMixin:
|
||||
"""Mixin for handling StableDiffusion 3 IP Adapters."""
|
||||
|
||||
@property
|
||||
def is_ip_adapter_active(self) -> bool:
|
||||
"""Checks if IP-Adapter is loaded and scale > 0.
|
||||
|
||||
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
|
||||
the image context is irrelevant.
|
||||
|
||||
Returns:
|
||||
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
|
||||
"""
|
||||
scales = [
|
||||
attn_proc.scale
|
||||
for attn_proc in self.transformer.attn_processors.values()
|
||||
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
|
||||
]
|
||||
|
||||
return len(scales) > 0 and any(scale > 0 for scale in scales)
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
weight_name (`str`, defaults to "ip-adapter.safetensors"):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
# Load the main state dict first
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
# Commons args for loading image encoder and image processor
|
||||
kwargs = {
|
||||
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# Load IP-Adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: float) -> None:
|
||||
"""
|
||||
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
|
||||
conditioned on the image prompt, and 0.0 only conditioned by the text prompt. Lowering this value encourages
|
||||
the model to produce more diverse images, but they may not be as aligned with the image prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.set_ip_adapter_scale(0.6)
|
||||
>>> ...
|
||||
```
|
||||
|
||||
Args:
|
||||
scale (float):
|
||||
IP-Adapter scale to be set.
|
||||
|
||||
"""
|
||||
for attn_processor in self.transformer.attn_processors.values():
|
||||
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self) -> None:
|
||||
"""
|
||||
Unloads the IP Adapter weights.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# Remove image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=None)
|
||||
|
||||
# Remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=None)
|
||||
|
||||
# Remove image projection
|
||||
self.transformer.image_proj = None
|
||||
|
||||
# Restore original attention processors layers
|
||||
attn_procs = {
|
||||
name: (
|
||||
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
|
||||
)
|
||||
for name, value in self.transformer.attn_processors.items()
|
||||
}
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||

+class SD3IPAdapterMixin(SD3IPAdapterMixin):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `SD3IPAdapterMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.ip_adapter import SD3IPAdapterMixin` instead."
+        deprecate("diffusers.loaders.ip_adapter.SD3IPAdapterMixin", "0.36", deprecation_message)
+        super().__init__(*args, **kwargs)
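The three shim classes added in this file all follow the same pattern: the class is re-exported under its old name and emits a deprecation message when instantiated, while the real implementation now lives in the new subpackage. A self-contained sketch of that pattern, with `warnings.warn` standing in for diffusers' `deprecate` helper:

```python
import warnings


class IPAdapterMixin:
    """Stand-in for the implementation that moved to the new module."""


class _DeprecatedIPAdapterMixin(IPAdapterMixin):
    """Kept importable at the old path so existing code does not break immediately."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "Importing `IPAdapterMixin` from the old module path is deprecated; "
            "import it from the new location instead.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)


# Existing code keeps working, but now surfaces a FutureWarning at construction time.
_DeprecatedIPAdapterMixin()
```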
src/diffusers/loaders/ip_adapter/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
+from ...utils.import_utils import is_torch_available, is_transformers_available
+
+
+if is_torch_available():
+    from .transformer_flux import FluxTransformer2DLoadersMixin
+    from .transformer_sd3 import SD3Transformer2DLoadersMixin
+
+if is_transformers_available():
+    from .ip_adapter import FluxIPAdapterMixin, IPAdapterMixin, SD3IPAdapterMixin
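The `is_torch_available` / `is_transformers_available` guards keep the new subpackage importable even when the optional dependencies are missing; only the guarded names disappear. A small sketch of how downstream code typically copes with that (the module path is the one introduced by this diff):

```python
from diffusers.utils.import_utils import is_transformers_available

if is_transformers_available():
    # Only resolvable when `transformers` is installed, mirroring the guard above.
    from diffusers.loaders.ip_adapter import IPAdapterMixin
else:
    IPAdapterMixin = None  # callers must handle the missing-dependency case explicitly
```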
src/diffusers/loaders/ip_adapter/ip_adapter.py (new file, 879 lines)
@@ -0,0 +1,879 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from safetensors import safe_open
|
||||
|
||||
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_detailed_type,
|
||||
_get_model_file,
|
||||
_is_valid_type,
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
from ..unet.unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection, SiglipImageProcessor, SiglipVisionModel
|
||||
|
||||
from ...models.attention_processor import (
|
||||
AttnProcessor,
|
||||
AttnProcessor2_0,
|
||||
FluxAttnProcessor2_0,
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
IPAdapterAttnProcessor,
|
||||
IPAdapterAttnProcessor2_0,
|
||||
IPAdapterXFormersAttnProcessor,
|
||||
JointAttnProcessor2_0,
|
||||
SD3IPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class IPAdapterMixin:
|
||||
"""Mixin for handling IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
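
Example (an illustrative sketch; the base checkpoint, adapter repository, `subfolder`, and `weight_name`
below are placeholders for whichever IP-Adapter checkpoint you actually use):

```py
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
pipeline.set_ip_adapter_scale(0.6)
```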
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
torch_dtype=self.dtype,
|
||||
).to(self.device)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder, so it's not present; in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into unet
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
extra_loras = unet._load_ip_adapter_loras(state_dicts)
|
||||
if extra_loras != {}:
|
||||
if not USE_PEFT_BACKEND:
|
||||
logger.warning("PEFT backend is required to load these weights.")
|
||||
else:
|
||||
# apply the IP Adapter Face ID LoRA weights
|
||||
peft_config = getattr(unet, "peft_config", {})
|
||||
for k, lora in extra_loras.items():
|
||||
if f"faceid_{k}" not in peft_config:
|
||||
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
|
||||
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style block only
|
||||
scale = {
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style+layout blocks
|
||||
scale = {
|
||||
"down": {"block_2": [0.0, 1.0]},
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style and layout from 2 reference images
|
||||
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
|
||||
pipeline.set_ip_adapter_scale(scales)
|
||||
```
|
||||
"""
|
||||
unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale]
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
if isinstance(scale_config, dict):
|
||||
for k, s in scale_config.items():
|
||||
if attn_name.startswith(k):
|
||||
attn_processor.scale[i] = s
|
||||
else:
|
||||
attn_processor.scale[i] = scale_config
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.unet.encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = None
|
||||
|
||||
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
|
||||
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
|
||||
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
|
||||
self.unet.text_encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = "text_proj"
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.unet.attn_processors.items():
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
weight_name: Union[str, List[str]],
|
||||
subfolder: Optional[Union[str, List[str]]] = "",
|
||||
image_encoder_pretrained_model_name_or_path: Optional[str] = "image_encoder",
|
||||
image_encoder_subfolder: Optional[str] = "",
|
||||
image_encoder_dtype: torch.dtype = torch.float16,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
image_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `./image_encoder`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `openai/clip-vit-large-patch14`) of a pretrained model
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
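
Example (an illustrative sketch; the base model, adapter repository, weight file, and image encoder below
are placeholders, not requirements):

```py
import torch
from diffusers import FluxPipeline

pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16).to("cuda")
pipeline.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)
pipeline.set_ip_adapter_scale(0.7)
```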
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
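# Non-diffusers Flux IP-Adapter checkpoints keep the image projection under an `ip_adapter_proj_model.`
# prefix and the per-block attention weights under `double_blocks.`; both are remapped below to the
# diffusers layout (`image_proj` / `ip_adapter` with `to_k_ip` / `to_v_ip` weights).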
image_proj_keys = ["ip_adapter_proj_model.", "image_proj."]
|
||||
ip_adapter_keys = ["double_blocks.", "ip_adapter."]
|
||||
for key in f.keys():
|
||||
if any(key.startswith(prefix) for prefix in image_proj_keys):
|
||||
diffusers_name = ".".join(key.split(".")[1:])
|
||||
state_dict["image_proj"][diffusers_name] = f.get_tensor(key)
|
||||
elif any(key.startswith(prefix) for prefix in ip_adapter_keys):
|
||||
diffusers_name = (
|
||||
".".join(key.split(".")[1:])
|
||||
.replace("ip_adapter_double_stream_k_proj", "to_k_ip")
|
||||
.replace("ip_adapter_double_stream_v_proj", "to_v_ip")
|
||||
.replace("processor.", "")
|
||||
)
|
||||
state_dict["ip_adapter"][diffusers_name] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if keys != ["image_proj", "ip_adapter"]:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
# load CLIP image encoder here if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_pretrained_model_name_or_path is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {image_encoder_pretrained_model_name_or_path}")
|
||||
image_encoder = (
|
||||
CLIPVisionModelWithProjection.from_pretrained(
|
||||
image_encoder_pretrained_model_name_or_path,
|
||||
subfolder=image_encoder_subfolder,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
dtype=image_encoder_dtype,
|
||||
)
|
||||
.to(self.device)
|
||||
.eval()
|
||||
)
|
||||
self.register_modules(image_encoder=image_encoder)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# create feature extractor if it has not been registered to the pipeline yet
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
|
||||
# FaceID IP adapters don't need the image encoder, so it's not present; in this case we default to 224
|
||||
default_clip_size = 224
|
||||
clip_image_size = (
|
||||
self.image_encoder.config.image_size if self.image_encoder is not None else default_clip_size
|
||||
)
|
||||
feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
|
||||
self.register_modules(feature_extractor=feature_extractor)
|
||||
|
||||
# load ip-adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: Union[float, List[float], List[List[float]]]):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a list.
|
||||
|
||||
A `float` is converted to a list and repeated for the number of blocks and the number of IP adapters. A
`List[float]` must have the same length as the number of blocks; it is repeated for each IP adapter. A
`List[List[float]]` must match the number of IP adapters, and each inner list must match the number of blocks.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
|
||||
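# Or pass one scale per transformer block for per-layer control, e.g. a linear ramp across the blocks: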
def LinearStrengthModel(start, finish, size):
|
||||
return [(start + (finish - start) * (i / (size - 1))) for i in range(size)]
|
||||
|
||||
|
||||
ip_strengths = LinearStrengthModel(0.3, 0.92, 19)
|
||||
pipeline.set_ip_adapter_scale(ip_strengths)
|
||||
```
|
||||
"""
|
||||
|
||||
scale_type = Union[int, float]
|
||||
num_ip_adapters = self.transformer.encoder_hid_proj.num_ip_adapters
|
||||
num_layers = self.transformer.config.num_layers
|
||||
|
||||
# Single value for all layers of all IP-Adapters
|
||||
if isinstance(scale, scale_type):
|
||||
scale = [scale for _ in range(num_ip_adapters)]
|
||||
# List of per-layer scales for a single IP-Adapter
|
||||
elif _is_valid_type(scale, List[scale_type]) and num_ip_adapters == 1:
|
||||
scale = [scale]
|
||||
# Invalid scale type
|
||||
elif not _is_valid_type(scale, List[Union[scale_type, List[scale_type]]]):
|
||||
raise TypeError(f"Unexpected type {_get_detailed_type(scale)} for scale.")
|
||||
|
||||
if len(scale) != num_ip_adapters:
|
||||
raise ValueError(f"Cannot assign {len(scale)} scales to {num_ip_adapters} IP-Adapters.")
|
||||
|
||||
if any(len(s) != num_layers for s in scale if isinstance(s, list)):
|
||||
invalid_scale_sizes = {len(s) for s in scale if isinstance(s, list)} - {num_layers}
|
||||
raise ValueError(
|
||||
f"Expected list of {num_layers} scales, got {', '.join(str(x) for x in invalid_scale_sizes)}."
|
||||
)
|
||||
|
||||
# Scalars are transformed to lists with length num_layers
|
||||
scale_configs = [[s] * num_layers if isinstance(s, scale_type) else s for s in scale]
|
||||
|
||||
# Set scales. zip over scale_configs prevents going into single transformer layers
|
||||
for attn_processor, *scale in zip(self.transformer.attn_processors.values(), *scale_configs):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# remove CLIP image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=[None, None])
|
||||
|
||||
# remove feature extractor only when safety_checker is None as safety_checker uses
|
||||
# the feature_extractor later
|
||||
if not hasattr(self, "safety_checker"):
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=[None, None])
|
||||
|
||||
# remove hidden encoder
|
||||
self.transformer.encoder_hid_proj = None
|
||||
self.transformer.config.encoder_hid_dim_type = None
|
||||
|
||||
# restore original Transformer attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.transformer.attn_processors.items():
|
||||
attn_processor_class = FluxAttnProcessor2_0()
|
||||
attn_procs[name] = (
|
||||
attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
|
||||
)
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class SD3IPAdapterMixin:
|
||||
"""Mixin for handling StableDiffusion 3 IP Adapters."""
|
||||
|
||||
@property
|
||||
def is_ip_adapter_active(self) -> bool:
|
||||
"""Checks if IP-Adapter is loaded and scale > 0.
|
||||
|
||||
IP-Adapter scale controls the influence of the image prompt versus text prompt. When this value is set to 0,
|
||||
the image context is irrelevant.
|
||||
|
||||
Returns:
|
||||
`bool`: True when IP-Adapter is loaded and any layer has scale > 0.
|
||||
"""
|
||||
scales = [
|
||||
attn_proc.scale
|
||||
for attn_proc in self.transformer.attn_processors.values()
|
||||
if isinstance(attn_proc, SD3IPAdapterJointAttnProcessor2_0)
|
||||
]
|
||||
|
||||
return len(scales) > 0 and any(scale > 0 for scale in scales)
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
weight_name: str = "ip-adapter.safetensors",
|
||||
subfolder: Optional[str] = None,
|
||||
image_encoder_folder: Optional[str] = "image_encoder",
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
|
||||
Can be either:
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
weight_name (`str`, defaults to "ip-adapter.safetensors"):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
|
||||
The subfolder location of the image encoder within a larger model repository on the Hub or locally.
|
||||
Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
|
||||
`subfolder`, you only need to pass the name of the folder that contains image encoder weights, e.g.
|
||||
`image_encoder_folder="image_encoder"`. If the image encoder is located in a folder other than
|
||||
`subfolder`, you should pass the path to the folder that contains image encoder weights, for example,
|
||||
`image_encoder_folder="different_subfolder/image_encoder"`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
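
Example (an illustrative sketch; the pipeline checkpoint is a placeholder, and the adapter repository is
only named because the conversion code in this mixin targets that checkpoint layout):

```py
import torch
from diffusers import StableDiffusion3Pipeline

pipeline = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
pipeline.set_ip_adapter_scale(0.6)
```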
"""
|
||||
# Load the main state dict first
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
# Load image_encoder and feature_extractor here if they haven't been registered to the pipeline yet
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is None:
|
||||
if image_encoder_folder is not None:
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
logger.info(f"loading image_encoder from {pretrained_model_name_or_path_or_dict}")
|
||||
if image_encoder_folder.count("/") == 0:
|
||||
image_encoder_subfolder = Path(subfolder, image_encoder_folder).as_posix()
|
||||
else:
|
||||
image_encoder_subfolder = Path(image_encoder_folder).as_posix()
|
||||
|
||||
# Common args for loading the image encoder and image processor
|
||||
kwargs = {
|
||||
"low_cpu_mem_usage": low_cpu_mem_usage,
|
||||
"cache_dir": cache_dir,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
|
||||
self.register_modules(
|
||||
feature_extractor=SiglipImageProcessor.from_pretrained(image_encoder_subfolder, **kwargs),
|
||||
image_encoder=SiglipVisionModel.from_pretrained(
|
||||
image_encoder_subfolder, torch_dtype=self.dtype, **kwargs
|
||||
).to(self.device),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"`image_encoder` cannot be loaded because `pretrained_model_name_or_path_or_dict` is a state dict."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
|
||||
"Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
|
||||
)
|
||||
|
||||
# Load IP-Adapter into transformer
|
||||
self.transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
def set_ip_adapter_scale(self, scale: float) -> None:
|
||||
"""
|
||||
Set IP-Adapter scale, which controls image prompt conditioning. A value of 1.0 means the model is only
|
||||
conditioned on the image prompt, and 0.0 means it is only conditioned on the text prompt. Lowering this value encourages
|
||||
the model to produce more diverse images, but they may not be as aligned with the image prompt.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.set_ip_adapter_scale(0.6)
|
||||
>>> ...
|
||||
```
|
||||
|
||||
Args:
|
||||
scale (float):
|
||||
IP-Adapter scale to be set.
|
||||
|
||||
"""
|
||||
for attn_processor in self.transformer.attn_processors.values():
|
||||
if isinstance(attn_processor, SD3IPAdapterJointAttnProcessor2_0):
|
||||
attn_processor.scale = scale
|
||||
|
||||
def unload_ip_adapter(self) -> None:
|
||||
"""
|
||||
Unloads the IP Adapter weights.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
# Remove image encoder
|
||||
if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
|
||||
self.image_encoder = None
|
||||
self.register_to_config(image_encoder=None)
|
||||
|
||||
# Remove feature extractor
|
||||
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is not None:
|
||||
self.feature_extractor = None
|
||||
self.register_to_config(feature_extractor=None)
|
||||
|
||||
# Remove image projection
|
||||
self.transformer.image_proj = None
|
||||
|
||||
# Restore original attention processors layers
|
||||
attn_procs = {
|
||||
name: (
|
||||
JointAttnProcessor2_0() if isinstance(value, SD3IPAdapterJointAttnProcessor2_0) else value.__class__()
|
||||
)
|
||||
for name, value in self.transformer.attn_processors.items()
|
||||
}
|
||||
self.transformer.set_attn_processor(attn_procs)
|
||||
src/diffusers/loaders/ip_adapter/transformer_flux.py (new file, 168 lines)
@@ -0,0 +1,168 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ...models.embeddings import ImageProjection, MultiIPAdapterImageProjection
|
||||
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ...utils import is_accelerate_available, is_torch_version, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FluxTransformer2DLoadersMixin:
|
||||
"""
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
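# `proj.weight` has shape (num_image_text_embeds * cross_attention_dim, clip_embeddings_dim);
# 65536 output rows correspond to 16 image tokens, otherwise the default of 4 is kept.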
if state_dict["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embeds = 16
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
||||
|
||||
with init_context():
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
from ...models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 0
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for name in self.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
cross_attention_dim = self.config.joint_attention_dim
|
||||
hidden_size = self.inner_dim
|
||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
num_image_text_embed = 4
|
||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embed = 16
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [num_image_text_embed]
|
||||
|
||||
with init_context():
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
src/diffusers/loaders/ip_adapter/transformer_sd3.py (new file, 170 lines)
@@ -0,0 +1,170 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
from typing import Dict
|
||||
|
||||
from ...models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
||||
from ...models.embeddings import IPAdapterTimeImageProjection
|
||||
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ...utils import is_accelerate_available, is_torch_version, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class SD3Transformer2DLoadersMixin:
|
||||
"""Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(
|
||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
) -> Dict:
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# IP-Adapter cross attention parameters
|
||||
hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
|
||||
timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]
|
||||
|
||||
# Dict where key is transformer layer index, value is attention processor's state dict
|
||||
# ip_adapter state dict keys example: "0.norm_ip.linear.weight"
|
||||
layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
|
||||
for key, weights in state_dict.items():
|
||||
idx, name = key.split(".", maxsplit=1)
|
||||
layer_state_dict[int(idx)][name] = weights
|
||||
|
||||
# Create IP-Adapter attention processor & load state_dict
|
||||
attn_procs = {}
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for idx, name in enumerate(self.attn_processors.keys()):
|
||||
with init_context():
|
||||
attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
|
||||
hidden_size=hidden_size,
|
||||
ip_hidden_states_dim=ip_hidden_states_dim,
|
||||
head_dim=self.config.attention_head_dim,
|
||||
timesteps_emb_dim=timesteps_emb_dim,
|
||||
)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(
|
||||
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
||||
)
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(
|
||||
self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
) -> IPAdapterTimeImageProjection:
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
# Convert to diffusers
|
||||
updated_state_dict = {}
|
||||
for key, value in state_dict.items():
|
||||
# InstantX/SD3.5-Large-IP-Adapter
|
||||
if key.startswith("layers."):
|
||||
idx = key.split(".")[1]
|
||||
key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
|
||||
key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
|
||||
key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
|
||||
key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
|
||||
key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
|
||||
key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
|
||||
key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
|
||||
key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
|
||||
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
|
||||
updated_state_dict[key] = value
|
||||
|
||||
# Image projection parameters
|
||||
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
|
||||
output_dim = updated_state_dict["proj_out.weight"].shape[0]
|
||||
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
|
||||
heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
|
||||
num_queries = updated_state_dict["latents"].shape[1]
|
||||
timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]
|
||||
|
||||
# Image projection
|
||||
with init_context():
|
||||
image_proj = IPAdapterTimeImageProjection(
|
||||
embed_dim=embed_dim,
|
||||
output_dim=output_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
heads=heads,
|
||||
num_queries=num_queries,
|
||||
timestep_in_dim=timestep_in_dim,
|
||||
)
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_proj.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_proj
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
|
||||
"""Sets IP-Adapter attention processors, image projection, and loads state_dict.
|
||||
|
||||
Args:
|
||||
state_dict (`Dict`):
|
||||
State dict with keys "ip_adapter", which contains parameters for attention processors, and
|
||||
"image_proj", which contains parameters for image projection net.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
|
||||
src/diffusers/loaders/lora/__init__.py (new file, 25 lines)
@@ -0,0 +1,25 @@
|
||||
from ...utils import is_peft_available, is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .lora_base import LoraBaseMixin
|
||||
|
||||
if is_transformers_available():
|
||||
from .lora_pipeline import (
|
||||
AmusedLoraLoaderMixin,
|
||||
AuraFlowLoraLoaderMixin,
|
||||
CogVideoXLoraLoaderMixin,
|
||||
CogView4LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Lumina2LoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
SanaLoraLoaderMixin,
|
||||
SD3LoraLoaderMixin,
|
||||
StableDiffusionLoraLoaderMixin,
|
||||
StableDiffusionXLLoraLoaderMixin,
|
||||
WanLoraLoaderMixin,
|
||||
)
|
||||
src/diffusers/loaders/lora/lora_base.py (new file, 935 lines)
@@ -0,0 +1,935 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
|
||||
from ...models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ...models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_peft_available():
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
|
||||
"""
|
||||
Fuses LoRAs for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
"""
|
||||
merge_kwargs = {"safe_merge": safe_fusing}
|
||||
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
if lora_scale != 1.0:
|
||||
module.scale_layer(lora_scale)
|
||||
|
||||
# For BC with previous PEFT versions, we need to check the signature
|
||||
# of the `merge` method to see if it supports the `adapter_names` argument.
|
||||
supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
|
||||
if "adapter_names" in supported_merge_kwargs:
|
||||
merge_kwargs["adapter_names"] = adapter_names
|
||||
elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
|
||||
raise ValueError(
|
||||
"The `adapter_names` argument is not supported with your PEFT version. "
|
||||
"Please upgrade to the latest version of PEFT. `pip install -U peft`"
|
||||
)
|
||||
|
||||
module.merge(**merge_kwargs)
|
||||
|
||||
|
||||
def unfuse_text_encoder_lora(text_encoder):
|
||||
"""
|
||||
Unfuses LoRAs for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
for module in text_encoder.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
|
||||
def set_adapters_for_text_encoder(
|
||||
adapter_names: Union[List[str], str],
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError(
|
||||
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
||||
)
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
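# --- Illustrative sketch (not part of this module) ---
# `process_weights` above expands a scalar or partially specified list into one weight per adapter,
# e.g. for two adapters: 0.5 -> [0.5, 0.5] and [0.8, None] -> [0.8, 1.0]. A hypothetical call
# (adapter names invented for illustration):
#
#     set_adapters_for_text_encoder(
#         ["style", "detail"], text_encoder=pipe.text_encoder, text_encoder_weights=[0.8, None]
#     )
#     # activates both adapters on the text encoder with weights 0.8 and 1.0 respectively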
|
||||
|
||||
|
||||
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Disables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=False)
|
||||
|
||||
|
||||
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Enables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=True)
|
||||
|
||||
|
||||
def _remove_text_encoder_monkey_patch(text_encoder):
|
||||
recurse_remove_peft_layers(text_encoder)
|
||||
if getattr(text_encoder, "peft_config", None) is not None:
|
||||
del text_encoder.peft_config
|
||||
text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
|
||||
def _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weight_name,
|
||||
use_safetensors,
|
||||
local_files_only,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
revision,
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
):
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
)
|
||||
|
||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
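# --- Illustrative sketch (not part of this module) ---
# Example of how the guess above could resolve, using the repo from the `fuse_lora` docstring
# example further below and assuming it exposes a single matching `.safetensors` file:
#
#     _best_guess_weight_name("nerijs/pixel-art-xl", file_extension=".safetensors")
#     # -> "pixel-art-xl.safetensors"; with more than one candidate a ValueError asks the caller
#     # to pass `weight_name` explicitly, and offline mode always requires an explicit name.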
|
||||
|
||||
|
||||
def _load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
text_encoder_name="text_encoder",
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
|
||||
# Safe prefix to check with.
|
||||
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
||||
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||
|
||||
# convert state dict
|
||||
state_dict = convert_state_dict_to_peft(state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
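# --- Illustrative sketch (not part of this module) ---
# The helper above is used around LoRA loading roughly like this (simplified, hypothetical caller),
# mirroring the restore step in `_load_lora_into_text_encoder`:
#
#     is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(pipe)
#     # ... inject the LoRA layers while no accelerate hooks are attached ...
#     if is_model_cpu_offload:
#         pipe.enable_model_cpu_offload()
#     elif is_sequential_cpu_offload:
#         pipe.enable_sequential_cpu_offload()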
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
_lora_loadable_modules = []
|
||||
num_fused_loras = 0
|
||||
|
||||
def load_lora_weights(self, **kwargs):
|
||||
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(cls, **kwargs):
|
||||
raise NotImplementedError("`save_lora_weights()` not implemented.")
|
||||
|
||||
@classmethod
|
||||
def lora_state_dict(cls, **kwargs):
|
||||
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
||||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
||||
return _fetch_state_dict(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||
return _best_guess_weight_name(*args, **kwargs)
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
||||
>>> pipeline.unload_lora_weights()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.unload_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
_remove_text_encoder_monkey_patch(model)
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = [],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
|
||||
Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
if "fuse_unet" in kwargs:
|
||||
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_transformer" in kwargs:
|
||||
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
# check if diffusers model
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
# handle transformers models.
|
||||
if issubclass(model.__class__, PreTrainedModel):
|
||||
fuse_text_encoder_lora(
|
||||
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
self.num_fused_loras += 1
|
||||
|
||||
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
if "unfuse_unet" in kwargs:
|
||||
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_transformer" in kwargs:
|
||||
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
self.num_fused_loras -= 1
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
):
|
||||
if isinstance(adapter_weights, dict):
|
||||
components_passed = set(adapter_weights.keys())
|
||||
lora_components = set(self._lora_loadable_modules)
|
||||
|
||||
invalid_components = sorted(components_passed - lora_components)
|
||||
if invalid_components:
|
||||
logger.warning(
|
||||
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
|
||||
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
|
||||
"to the invalid components will be removed and ignored."
|
||||
)
|
||||
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
adapter_weights = copy.deepcopy(adapter_weights)
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
if not isinstance(adapter_weights, list):
|
||||
adapter_weights = [adapter_weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(adapter_weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||
)
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
}
|
||||
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component)
|
||||
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
component_adapter_weights = weights.pop(component, None)
|
||||
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
|
||||
logger.warning(
|
||||
(
|
||||
f"Lora weight dict for adapter '{adapter_name}' contains {component},"
|
||||
f"but this will be ignored because {adapter_name} does not contain weights for {component}."
|
||||
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
component_adapter_weights = weights
|
||||
|
||||
_component_adapter_weights.setdefault(component, [])
|
||||
_component_adapter_weights[component].append(component_adapter_weights)
|
||||
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.set_adapters(adapter_names, _component_adapter_weights[component])
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
|
||||
|
||||
def disable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.disable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
disable_lora_for_text_encoder(model)
|
||||
|
||||
def enable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.enable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
enable_lora_for_text_encoder(model)
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
|
||||
"""
|
||||
Args:
|
||||
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).
|
||||
adapter_names (`Union[List[str], str]`):
|
||||
The names of the adapter to delete. Can be a single string or a list of strings
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.delete_adapters(adapter_names)
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(model, adapter_name)
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
pipeline.get_active_adapters()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
active_adapters = []
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None and issubclass(model.__class__, ModelMixin):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
set_adapters = {}
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if (
|
||||
model is not None
|
||||
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
|
||||
and hasattr(model, "peft_config")
|
||||
):
|
||||
set_adapters[component] = list(model.peft_config.keys())
|
||||
|
||||
return set_adapters
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
module.lora_A[adapter_name].to(device)
|
||||
module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
|
||||
if adapter_name in module.lora_magnitude_vector:
|
||||
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
].to(device)
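# --- Illustrative sketch (not part of this module) ---
# A hypothetical way to park an adapter that is not currently needed on the CPU and bring it back
# (the "pixel" adapter name comes from the `fuse_lora` docstring example above):
#
#     pipe.set_lora_device(adapter_names=["pixel"], device="cpu")   # frees GPU memory
#     pipe.set_lora_device(adapter_names=["pixel"], device="cuda")  # move it back before use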
|
||||
|
||||
@staticmethod
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
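# --- Illustrative sketch (not part of this module) ---
# `pack_weights` only namespaces a state dict under the given prefix; the key shown is hypothetical:
#
#     pack_weights(unet_lora_layers, "unet")
#     # {"down_blocks.0....lora_A.weight": ...} -> {"unet.down_blocks.0....lora_A.weight": ...}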
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
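# --- Illustrative sketch (not part of this module) ---
# Hotswapping sketch, assuming two LoRAs of possibly different rank and a compiled denoiser
# (repo ids, adapter names, and the rank are hypothetical):
#
#     pipe.enable_lora_hotswap(target_rank=64)  # call before loading the first adapter
#     pipe.load_lora_weights("user/lora-a", adapter_name="a")
#     pipe.unet = torch.compile(pipe.unet)
#     pipe.load_lora_weights("user/lora-b", adapter_name="a", hotswap=True)  # replaces "a" in place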
|
||||
@@ -17,7 +17,7 @@ from typing import List

 import torch

-from ..utils import is_peft_version, logging, state_dict_all_zero
+from ...utils import is_peft_version, logging, state_dict_all_zero


 logger = logging.get_logger(__name__)
5686  src/diffusers/loaders/lora/lora_pipeline.py  (Normal file; diff suppressed because it is too large)
@@ -12,924 +12,66 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import safetensors
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import model_info
|
||||
from huggingface_hub.constants import HF_HUB_OFFLINE
|
||||
|
||||
from ..models.modeling_utils import ModelMixin, load_state_dict
|
||||
from ..utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
_get_model_file,
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
delete_adapter_layers,
|
||||
deprecate,
|
||||
get_adapter_name,
|
||||
get_peft_kwargs,
|
||||
is_accelerate_available,
|
||||
is_peft_available,
|
||||
is_peft_version,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
logging,
|
||||
recurse_remove_peft_layers,
|
||||
scale_lora_layers,
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from ..models.lora import text_encoder_attn_modules, text_encoder_mlp_modules
|
||||
|
||||
if is_peft_available():
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
|
||||
LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
from ..utils import deprecate
|
||||
from .lora.lora_base import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, LoraBaseMixin # noqa: F401
|
||||
|
||||
|
||||
def fuse_text_encoder_lora(text_encoder, lora_scale=1.0, safe_fusing=False, adapter_names=None):
    from .lora.lora_base import fuse_text_encoder_lora

    deprecation_message = "Importing `fuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import fuse_text_encoder_lora` instead."
    deprecate("diffusers.loaders.lora_base.fuse_text_encoder_lora", "0.36", deprecation_message)

    return fuse_text_encoder_lora(
        text_encoder, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
    )


def unfuse_text_encoder_lora(text_encoder):
    from .lora.lora_base import unfuse_text_encoder_lora

    deprecation_message = "Importing `unfuse_text_encoder_lora()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import unfuse_text_encoder_lora` instead."
    deprecate("diffusers.loaders.lora_base.unfuse_text_encoder_lora", "0.36", deprecation_message)

    return unfuse_text_encoder_lora(text_encoder)


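# --- Illustrative note (not part of this module) ---
# With shims like the two above, the old import path keeps working but warns on use, e.g.:
#
#     from diffusers.loaders.lora_base import fuse_text_encoder_lora  # import itself still works
#     fuse_text_encoder_lora(text_encoder)  # calling it emits the deprecation message via `deprecate()`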
def set_adapters_for_text_encoder(
|
||||
adapter_names: Union[List[str], str],
|
||||
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
|
||||
text_encoder_weights: Optional[Union[float, List[float], List[None]]] = None,
|
||||
adapter_names,
|
||||
text_encoder=None,
|
||||
text_encoder_weights=None,
|
||||
):
|
||||
"""
|
||||
Sets the adapter layers for the text encoder.
|
||||
from .lora.lora_base import set_adapters_for_text_encoder
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]` or `str`):
|
||||
The names of the adapters to use.
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to set the adapter layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
text_encoder_weights (`List[float]`, *optional*):
|
||||
The weights to use for the text encoder. If `None`, the weights are set to `1.0` for all the adapters.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError(
|
||||
"The pipeline does not have a default `pipe.text_encoder` class. Please make sure to pass a `text_encoder` instead."
|
||||
)
|
||||
deprecation_message = "Importing `set_adapters_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import set_adapters_for_text_encoder` instead."
|
||||
deprecate("diffusers.loaders.lora_base.set_adapters_for_text_encoder", "0.36", deprecation_message)
|
||||
|
||||
def process_weights(adapter_names, weights):
|
||||
# Expand weights into a list, one entry per adapter
|
||||
# e.g. for 2 adapters: 7 -> [7,7] ; [3, None] -> [3, None]
|
||||
if not isinstance(weights, list):
|
||||
weights = [weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(weights)}"
|
||||
)
|
||||
|
||||
# Set None values to default of 1.0
|
||||
# e.g. [7,7] -> [7,7] ; [3, None] -> [3,1]
|
||||
weights = [w if w is not None else 1.0 for w in weights]
|
||||
|
||||
return weights
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
text_encoder_weights = process_weights(adapter_names, text_encoder_weights)
|
||||
set_weights_and_activate_adapters(text_encoder, adapter_names, text_encoder_weights)
|
||||
|
||||
|
||||
def disable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Disables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to disable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=False)
|
||||
|
||||
|
||||
def enable_lora_for_text_encoder(text_encoder: Optional["PreTrainedModel"] = None):
|
||||
"""
|
||||
Enables the LoRA layers for the text encoder.
|
||||
|
||||
Args:
|
||||
text_encoder (`torch.nn.Module`, *optional*):
|
||||
The text encoder module to enable the LoRA layers for. If `None`, it will try to get the `text_encoder`
|
||||
attribute.
|
||||
"""
|
||||
if text_encoder is None:
|
||||
raise ValueError("Text Encoder not found.")
|
||||
set_adapter_layers(text_encoder, enabled=True)
|
||||
|
||||
|
||||
def _remove_text_encoder_monkey_patch(text_encoder):
|
||||
recurse_remove_peft_layers(text_encoder)
|
||||
if getattr(text_encoder, "peft_config", None) is not None:
|
||||
del text_encoder.peft_config
|
||||
text_encoder._hf_peft_config_loaded = None
|
||||
|
||||
|
||||
def _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weight_name,
|
||||
use_safetensors,
|
||||
local_files_only,
|
||||
cache_dir,
|
||||
force_download,
|
||||
proxies,
|
||||
token,
|
||||
revision,
|
||||
subfolder,
|
||||
user_agent,
|
||||
allow_pickle,
|
||||
):
|
||||
model_file = None
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
# Let's first try to load .safetensors weights
|
||||
if (use_safetensors and weight_name is None) or (
|
||||
weight_name is not None and weight_name.endswith(".safetensors")
|
||||
):
|
||||
try:
|
||||
# Here we're relaxing the loading check to enable more Inference API
|
||||
# friendliness where sometimes, it's not at all possible to automatically
|
||||
# determine `weight_name`.
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
file_extension=".safetensors",
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = safetensors.torch.load_file(model_file, device="cpu")
|
||||
except (IOError, safetensors.SafetensorError) as e:
|
||||
if not allow_pickle:
|
||||
raise e
|
||||
# try loading non-safetensors weights
|
||||
model_file = None
|
||||
pass
|
||||
|
||||
if model_file is None:
|
||||
if weight_name is None:
|
||||
weight_name = _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".bin", local_files_only=local_files_only
|
||||
)
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name or LORA_WEIGHT_NAME,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def _best_guess_weight_name(
|
||||
pretrained_model_name_or_path_or_dict, file_extension=".safetensors", local_files_only=False
|
||||
):
|
||||
if local_files_only or HF_HUB_OFFLINE:
|
||||
raise ValueError("When using the offline mode, you must specify a `weight_name`.")
|
||||
|
||||
targeted_files = []
|
||||
|
||||
if os.path.isfile(pretrained_model_name_or_path_or_dict):
|
||||
return
|
||||
elif os.path.isdir(pretrained_model_name_or_path_or_dict):
|
||||
targeted_files = [f for f in os.listdir(pretrained_model_name_or_path_or_dict) if f.endswith(file_extension)]
|
||||
else:
|
||||
files_in_repo = model_info(pretrained_model_name_or_path_or_dict).siblings
|
||||
targeted_files = [f.rfilename for f in files_in_repo if f.rfilename.endswith(file_extension)]
|
||||
if len(targeted_files) == 0:
|
||||
return
|
||||
|
||||
# "scheduler" does not correspond to a LoRA checkpoint.
|
||||
# "optimizer" does not correspond to a LoRA checkpoint
|
||||
# only top-level checkpoints are considered and not the other ones, hence "checkpoint".
|
||||
unallowed_substrings = {"scheduler", "optimizer", "checkpoint"}
|
||||
targeted_files = list(
|
||||
filter(lambda x: all(substring not in x for substring in unallowed_substrings), targeted_files)
|
||||
return set_adapters_for_text_encoder(
|
||||
adapter_names=adapter_names, text_encoder=text_encoder, text_encoder_weights=text_encoder_weights
|
||||
)
|
||||
|
||||
if any(f.endswith(LORA_WEIGHT_NAME) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME), targeted_files))
|
||||
elif any(f.endswith(LORA_WEIGHT_NAME_SAFE) for f in targeted_files):
|
||||
targeted_files = list(filter(lambda x: x.endswith(LORA_WEIGHT_NAME_SAFE), targeted_files))
|
||||
|
||||
if len(targeted_files) > 1:
|
||||
raise ValueError(
|
||||
f"Provided path contains more than one weights file in the {file_extension} format. Either specify `weight_name` in `load_lora_weights` or make sure there's only one `.safetensors` or `.bin` file in {pretrained_model_name_or_path_or_dict}."
|
||||
)
|
||||
weight_name = targeted_files[0]
|
||||
return weight_name
|
||||
def disable_lora_for_text_encoder(text_encoder=None):
|
||||
from .lora.lora_base import disable_lora_for_text_encoder
|
||||
|
||||
deprecation_message = "Importing `disable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import disable_lora_for_text_encoder` instead."
|
||||
deprecate("diffusers.loaders.lora_base.disable_lora_for_text_encoder", "0.36", deprecation_message)
|
||||
|
||||
def _load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
text_encoder,
|
||||
prefix=None,
|
||||
lora_scale=1.0,
|
||||
text_encoder_name="text_encoder",
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
return disable_lora_for_text_encoder(text_encoder=text_encoder)
|
||||
|
||||
peft_kwargs = {}
|
||||
if low_cpu_mem_usage:
|
||||
if not is_peft_version(">=", "0.13.1"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
if not is_transformers_version(">", "4.45.2"):
|
||||
# Note from sayakpaul: It's not in `transformers` stable yet.
|
||||
# https://github.com/huggingface/transformers/pull/33725/
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `transformers` version. Please update it with `pip install -U transformers`."
|
||||
)
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
|
||||
from peft import LoraConfig
|
||||
def enable_lora_for_text_encoder(text_encoder=None):
|
||||
from .lora.lora_base import enable_lora_for_text_encoder
|
||||
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `unet_name` and/or `text_encoder_name` as
|
||||
# their prefixes.
|
||||
prefix = text_encoder_name if prefix is None else prefix
|
||||
deprecation_message = "Importing `enable_lora_for_text_encoder()` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import enable_lora_for_text_encoder` instead."
|
||||
deprecate("diffusers.loaders.lora_base.enable_lora_for_text_encoder", "0.36", deprecation_message)
|
||||
|
||||
# Safe prefix to check with.
|
||||
if hotswap and any(text_encoder_name in key for key in state_dict.keys()):
|
||||
raise ValueError("At the moment, hotswapping is not supported for text encoders, please pass `hotswap=False`.")
|
||||
return enable_lora_for_text_encoder(text_encoder=text_encoder)
|
||||
|
||||
# Load the layers corresponding to text encoder and make necessary adjustments.
|
||||
if prefix is not None:
|
||||
state_dict = {k[len(f"{prefix}.") :]: v for k, v in state_dict.items() if k.startswith(f"{prefix}.")}
|
||||
|
||||
if len(state_dict) > 0:
|
||||
logger.info(f"Loading {prefix}.")
|
||||
rank = {}
|
||||
state_dict = convert_state_dict_to_diffusers(state_dict)
|
||||
|
||||
# convert state dict
|
||||
state_dict = convert_state_dict_to_peft(state_dict)
|
||||
|
||||
for name, _ in text_encoder_attn_modules(text_encoder):
|
||||
for module in ("out_proj", "q_proj", "k_proj", "v_proj"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
for name, _ in text_encoder_mlp_modules(text_encoder):
|
||||
for module in ("fc1", "fc2"):
|
||||
rank_key = f"{name}.{module}.lora_B.weight"
|
||||
if rank_key not in state_dict:
|
||||
continue
|
||||
rank[rank_key] = state_dict[rank_key].shape[1]
|
||||
|
||||
if network_alphas is not None:
|
||||
alpha_keys = [k for k in network_alphas.keys() if k.startswith(prefix) and k.split(".")[0] == prefix]
|
||||
network_alphas = {k.replace(f"{prefix}.", ""): v for k, v in network_alphas.items() if k in alpha_keys}
|
||||
|
||||
lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=False)
|
||||
|
||||
if "use_dora" in lora_config_kwargs:
|
||||
if lora_config_kwargs["use_dora"]:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<", "0.9.0"):
|
||||
lora_config_kwargs.pop("use_dora")
|
||||
|
||||
if "lora_bias" in lora_config_kwargs:
|
||||
if lora_config_kwargs["lora_bias"]:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
raise ValueError(
|
||||
"You need `peft` 0.14.0 at least to use `bias` in LoRAs. Please upgrade your installation of `peft`."
|
||||
)
|
||||
else:
|
||||
if is_peft_version("<=", "0.13.2"):
|
||||
lora_config_kwargs.pop("lora_bias")
|
||||
|
||||
lora_config = LoraConfig(**lora_config_kwargs)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
adapter_name=adapter_name,
|
||||
adapter_state_dict=state_dict,
|
||||
peft_config=lora_config,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
# scale LoRA layers with `lora_scale`
|
||||
scale_lora_layers(text_encoder, weight=lora_scale)
|
||||
|
||||
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
|
||||
|
||||
# Offload back.
|
||||
if is_model_cpu_offload:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
logger.warning(
|
||||
f"No LoRA keys associated to {text_encoder.__class__.__name__} found with the {prefix=}. "
|
||||
"This is safe to ignore if LoRA state dict didn't originally have any "
|
||||
f"{text_encoder.__class__.__name__} related params. You can also try specifying `prefix=None` "
|
||||
"to resolve the warning. Otherwise, open an issue if you think it's unexpected: "
|
||||
"https://github.com/huggingface/diffusers/issues/new"
|
||||
)
|
||||
|
||||
|
||||
def _func_optionally_disable_offloading(_pipeline):
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
"""Utility class for handling LoRAs."""
|
||||
|
||||
_lora_loadable_modules = []
|
||||
num_fused_loras = 0
|
||||
|
||||
def load_lora_weights(self, **kwargs):
|
||||
raise NotImplementedError("`load_lora_weights()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(cls, **kwargs):
|
||||
raise NotImplementedError("`save_lora_weights()` not implemented.")
|
||||
|
||||
@classmethod
|
||||
def lora_state_dict(cls, **kwargs):
|
||||
raise NotImplementedError("`lora_state_dict()` is not implemented.")
|
||||
|
||||
@classmethod
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
|
||||
|
||||
Args:
|
||||
_pipeline (`DiffusionPipeline`):
|
||||
The pipeline to disable offloading for.
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
"""
|
||||
return _func_optionally_disable_offloading(_pipeline=_pipeline)
|
||||
|
||||
@classmethod
|
||||
def _fetch_state_dict(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_fetch_state_dict()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _fetch_state_dict`."
|
||||
deprecate("_fetch_state_dict", "0.35.0", deprecation_message)
|
||||
return _fetch_state_dict(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _best_guess_weight_name(cls, *args, **kwargs):
|
||||
deprecation_message = f"Using the `_best_guess_weight_name()` method from {cls} has been deprecated and will be removed in a future version. Please use `from diffusers.loaders.lora_base import _best_guess_weight_name`."
|
||||
deprecate("_best_guess_weight_name", "0.35.0", deprecation_message)
|
||||
return _best_guess_weight_name(*args, **kwargs)
|
||||
|
||||
def unload_lora_weights(self):
|
||||
"""
|
||||
Unloads the LoRA parameters.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the LoRA parameters.
|
||||
>>> pipeline.unload_lora_weights()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.unload_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
_remove_text_encoder_monkey_patch(model)
|
||||
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = [],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Fuses the LoRA parameters into the original parameters of the corresponding blocks.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
|
||||
lora_scale (`float`, defaults to 1.0):
|
||||
Controls how much to influence the outputs with the LoRA parameters.
|
||||
safe_fusing (`bool`, defaults to `False`):
Whether to check the fused weights for NaN values before fusing, and to skip fusing any weights whose values are NaN.
|
||||
adapter_names (`List[str]`, *optional*):
|
||||
Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
|
||||
pipeline.fuse_lora(lora_scale=0.7)
|
||||
```
|
||||
"""
|
||||
if "fuse_unet" in kwargs:
|
||||
depr_message = "Passing `fuse_unet` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_transformer" in kwargs:
|
||||
depr_message = "Passing `fuse_transformer` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "fuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `fuse_text_encoder` to `fuse_lora()` is deprecated and will be ignored. Please use the `components` argument and provide a list of the components whose LoRAs are to be fused. `fuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"fuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
# check if diffusers model
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.fuse_lora(lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names)
|
||||
# handle transformers models.
|
||||
if issubclass(model.__class__, PreTrainedModel):
|
||||
fuse_text_encoder_lora(
|
||||
model, lora_scale=lora_scale, safe_fusing=safe_fusing, adapter_names=adapter_names
|
||||
)
|
||||
|
||||
self.num_fused_loras += 1
|
||||
|
||||
def unfuse_lora(self, components: List[str] = [], **kwargs):
|
||||
r"""
|
||||
Reverses the effect of
|
||||
[`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
This is an experimental API.
|
||||
|
||||
</Tip>
|
||||
|
||||
Args:
|
||||
components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
|
||||
unfuse_unet (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
|
||||
unfuse_text_encoder (`bool`, defaults to `True`):
|
||||
Whether to unfuse the text encoder LoRA parameters. If the text encoder wasn't monkey-patched with the
|
||||
LoRA parameters then it won't have any effect.
|
||||
"""
|
||||
if "unfuse_unet" in kwargs:
|
||||
depr_message = "Passing `unfuse_unet` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_unet` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_unet",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_transformer" in kwargs:
|
||||
depr_message = "Passing `unfuse_transformer` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_transformer` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_transformer",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
if "unfuse_text_encoder" in kwargs:
|
||||
depr_message = "Passing `unfuse_text_encoder` to `unfuse_lora()` is deprecated and will be ignored. Please use the `components` argument. `unfuse_text_encoder` will be removed in a future version."
|
||||
deprecate(
|
||||
"unfuse_text_encoder",
|
||||
"1.0.0",
|
||||
depr_message,
|
||||
)
|
||||
|
||||
if len(components) == 0:
|
||||
raise ValueError("`components` cannot be an empty list.")
|
||||
|
||||
for fuse_component in components:
|
||||
if fuse_component not in self._lora_loadable_modules:
|
||||
raise ValueError(f"{fuse_component} is not found in {self._lora_loadable_modules=}.")
|
||||
|
||||
model = getattr(self, fuse_component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, (ModelMixin, PreTrainedModel)):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
module.unmerge()
|
||||
|
||||
self.num_fused_loras -= 1
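A hedged example of the fuse/unfuse round trip this method reverses; the checkpoint and adapter names follow the `fuse_lora` docstring above, and the `components` list assumes an SDXL-style pipeline.

```py
import torch
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

# Merge the LoRA into the base weights for faster inference ...
pipeline.fuse_lora(components=["unet", "text_encoder"], lora_scale=0.7)
image = pipeline("pixel art of a corgi").images[0]

# ... and unmerge it again to restore the original parameters.
pipeline.unfuse_lora(components=["unet", "text_encoder"])
```
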
|
||||
|
||||
def set_adapters(
|
||||
self,
|
||||
adapter_names: Union[List[str], str],
|
||||
adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
|
||||
):
|
||||
if isinstance(adapter_weights, dict):
|
||||
components_passed = set(adapter_weights.keys())
|
||||
lora_components = set(self._lora_loadable_modules)
|
||||
|
||||
invalid_components = sorted(components_passed - lora_components)
|
||||
if invalid_components:
|
||||
logger.warning(
|
||||
f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
|
||||
f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
|
||||
"to the invalid components will be removed and ignored."
|
||||
)
|
||||
adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
|
||||
|
||||
adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
|
||||
adapter_weights = copy.deepcopy(adapter_weights)
|
||||
|
||||
# Expand weights into a list, one entry per adapter
|
||||
if not isinstance(adapter_weights, list):
|
||||
adapter_weights = [adapter_weights] * len(adapter_names)
|
||||
|
||||
if len(adapter_names) != len(adapter_weights):
|
||||
raise ValueError(
|
||||
f"Length of adapter names {len(adapter_names)} is not equal to the length of the weights {len(adapter_weights)}"
|
||||
)
|
||||
|
||||
list_adapters = self.get_list_adapters() # eg {"unet": ["adapter1", "adapter2"], "text_encoder": ["adapter2"]}
|
||||
# eg ["adapter1", "adapter2"]
|
||||
all_adapters = {adapter for adapters in list_adapters.values() for adapter in adapters}
|
||||
missing_adapters = set(adapter_names) - all_adapters
|
||||
if len(missing_adapters) > 0:
|
||||
raise ValueError(
|
||||
f"Adapter name(s) {missing_adapters} not in the list of present adapters: {all_adapters}."
|
||||
)
|
||||
|
||||
# eg {"adapter1": ["unet"], "adapter2": ["unet", "text_encoder"]}
|
||||
invert_list_adapters = {
|
||||
adapter: [part for part, adapters in list_adapters.items() if adapter in adapters]
|
||||
for adapter in all_adapters
|
||||
}
|
||||
|
||||
# Decompose weights into weights for denoiser and text encoders.
|
||||
_component_adapter_weights = {}
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component)
|
||||
|
||||
for adapter_name, weights in zip(adapter_names, adapter_weights):
|
||||
if isinstance(weights, dict):
|
||||
component_adapter_weights = weights.pop(component, None)
|
||||
if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
|
||||
logger.warning(
(
f"Lora weight dict for adapter '{adapter_name}' contains {component}, "
f"but this will be ignored because {adapter_name} does not contain weights for {component}. "
f"Valid parts for {adapter_name} are: {invert_list_adapters[adapter_name]}."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
component_adapter_weights = weights
|
||||
|
||||
_component_adapter_weights.setdefault(component, [])
|
||||
_component_adapter_weights[component].append(component_adapter_weights)
|
||||
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.set_adapters(adapter_names, _component_adapter_weights[component])
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
set_adapters_for_text_encoder(adapter_names, model, _component_adapter_weights[component])
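A hedged usage sketch of the decomposition above, assuming an SDXL-style pipeline with two adapters already loaded as "pixel" and "toy". A scalar weight is broadcast to every LoRA-loadable component, while a dict selects per-component weights.

```py
# One scalar weight per adapter, applied to every component that holds its LoRA layers.
pipeline.set_adapters(["pixel", "toy"], adapter_weights=[0.8, 0.6])

# Per-component weights; keys that are not LoRA-loadable components are warned about and dropped.
pipeline.set_adapters(
    ["pixel", "toy"],
    adapter_weights=[{"unet": 0.8, "text_encoder": 0.5}, {"unet": 0.6}],
)
```
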
|
||||
|
||||
def disable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.disable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
disable_lora_for_text_encoder(model)
|
||||
|
||||
def enable_lora(self):
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.enable_lora()
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
enable_lora_for_text_encoder(model)
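A brief sketch of toggling the loaded adapters on and off, assuming `pipeline` already has LoRA weights loaded; both calls require the PEFT backend.

```py
# Temporarily run the base model without any LoRA influence.
pipeline.disable_lora()
base_image = pipeline("a photo of a corgi").images[0]

# Re-activate the previously loaded adapters.
pipeline.enable_lora()
lora_image = pipeline("a photo of a corgi").images[0]
```
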
|
||||
|
||||
def delete_adapters(self, adapter_names: Union[List[str], str]):
"""
Deletes the LoRA layers of `adapter_name` for the unet and text-encoder(s).

Args:
adapter_names (`Union[List[str], str]`):
The names of the adapters to delete. Can be a single string or a list of strings.
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
if isinstance(adapter_names, str):
|
||||
adapter_names = [adapter_names]
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
if issubclass(model.__class__, ModelMixin):
|
||||
model.delete_adapters(adapter_names)
|
||||
elif issubclass(model.__class__, PreTrainedModel):
|
||||
for adapter_name in adapter_names:
|
||||
delete_adapter_layers(model, adapter_name)
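A small, hedged example of removing an adapter entirely (as opposed to merely disabling it); the checkpoint and adapter name are the ones used in the other examples in this class.

```py
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")

# Accepts a single adapter name or a list of names; the LoRA layers are deleted
# from every component that holds them.
pipeline.delete_adapters("toy")
```
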
|
||||
|
||||
def get_active_adapters(self) -> List[str]:
|
||||
"""
|
||||
Gets the list of the current active adapters.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
|
||||
pipeline.get_active_adapters()
|
||||
```
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
active_adapters = []
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None and issubclass(model.__class__, ModelMixin):
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
active_adapters = module.active_adapters
|
||||
break
|
||||
|
||||
return active_adapters
|
||||
|
||||
def get_list_adapters(self) -> Dict[str, List[str]]:
|
||||
"""
|
||||
Gets the current list of all available adapters in the pipeline.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError(
|
||||
"PEFT backend is required for this method. Please install the latest version of PEFT `pip install -U peft`"
|
||||
)
|
||||
|
||||
set_adapters = {}
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if (
|
||||
model is not None
|
||||
and issubclass(model.__class__, (ModelMixin, PreTrainedModel))
|
||||
and hasattr(model, "peft_config")
|
||||
):
|
||||
set_adapters[component] = list(model.peft_config.keys())
|
||||
|
||||
return set_adapters
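A hedged illustration of the returned mapping, assuming two adapters loaded into an SDXL-style pipeline; the exact keys depend on which components each checkpoint actually targets.

```py
pipeline.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy")
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")

pipeline.get_list_adapters()
# e.g. {"unet": ["toy", "pixel"], "text_encoder": ["toy", "pixel"]}
```
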
|
||||
|
||||
def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, str, int]) -> None:
|
||||
"""
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send to the target device.
|
||||
device (`Union[torch.device, str, int]`):
|
||||
Device to send the adapters to. Can be either a torch device, a str or an integer.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
for component in self._lora_loadable_modules:
|
||||
model = getattr(self, component, None)
|
||||
if model is not None:
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
module.lora_A[adapter_name].to(device)
|
||||
module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
if hasattr(module, "lora_magnitude_vector") and module.lora_magnitude_vector is not None:
|
||||
if adapter_name in module.lora_magnitude_vector:
|
||||
module.lora_magnitude_vector[adapter_name] = module.lora_magnitude_vector[
|
||||
adapter_name
|
||||
].to(device)
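A short sketch of parking an unused adapter on the CPU and bringing it back before use, assuming the adapters from the previous examples.

```py
# Free GPU memory held by the "toy" adapter while keeping it loaded.
pipeline.set_lora_device(["toy"], device="cpu")

# Move it back and activate it again when needed.
pipeline.set_lora_device(["toy"], device="cuda")
pipeline.set_adapters(["toy"])
```
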
|
||||
|
||||
@staticmethod
|
||||
def pack_weights(layers, prefix):
|
||||
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
|
||||
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
|
||||
return layers_state_dict
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
is_main_process: bool,
|
||||
weight_name: str,
|
||||
save_function: Callable,
|
||||
safe_serialization: bool,
|
||||
):
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
if save_function is None:
|
||||
if safe_serialization:
|
||||
|
||||
def save_function(weights, filename):
|
||||
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
|
||||
|
||||
else:
|
||||
save_function = torch.save
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
|
||||
save_path = Path(save_directory, weight_name).as_posix()
|
||||
save_function(state_dict, save_path)
|
||||
logger.info(f"Model weights saved in {save_path}")
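A hedged sketch of how `pack_weights` and `write_lora_layers` cooperate when a concrete `save_lora_weights()` implementation assembles a checkpoint. `unet_lora_layers` and `text_encoder_lora_layers` are placeholder dicts (or modules) of LoRA layers, and the output directory is arbitrary.

```py
# Prefix each component's LoRA layers, merge everything into one flat state dict,
# then serialize it (safetensors when `safe_serialization=True`).
state_dict = {}
state_dict.update(LoraBaseMixin.pack_weights(unet_lora_layers, "unet"))
state_dict.update(LoraBaseMixin.pack_weights(text_encoder_lora_layers, "text_encoder"))

LoraBaseMixin.write_lora_layers(
    state_dict=state_dict,
    save_directory="./my_lora",
    is_main_process=True,
    weight_name=None,  # falls back to the default LoRA file name
    save_function=None,
    safe_serialization=True,
)
```
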
|
||||
|
||||
@property
|
||||
def lora_scale(self) -> float:
|
||||
# property function that returns the lora scale which can be set at run time by the pipeline.
|
||||
# if _lora_scale has not been set, return 1
|
||||
return self._lora_scale if hasattr(self, "_lora_scale") else 1.0
|
||||
|
||||
def enable_lora_hotswap(self, **kwargs) -> None:
|
||||
"""Enables the possibility to hotswap LoRA adapters.
|
||||
|
||||
Calling this method is only required when hotswapping adapters and if the model is compiled or if the ranks of
|
||||
the loaded adapters differ.
|
||||
|
||||
Args:
|
||||
target_rank (`int`):
|
||||
The highest rank among all the adapters that will be loaded.
|
||||
check_compiled (`str`, *optional*, defaults to `"error"`):
|
||||
How to handle the case when the model is already compiled, which should generally be avoided. The
|
||||
options are:
|
||||
- "error" (default): raise an error
|
||||
- "warn": issue a warning
|
||||
- "ignore": do nothing
|
||||
"""
|
||||
for key, component in self.components.items():
|
||||
if hasattr(component, "enable_lora_hotswap") and (key in self._lora_loadable_modules):
|
||||
component.enable_lora_hotswap(**kwargs)
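A hedged sketch of the hotswap workflow this method enables, assuming two LoRA checkpoints with possibly different ranks and a compiled denoiser; the repo ids are placeholders, and a `hotswap=True` argument on `load_lora_weights()` is assumed to be available in this version.

```py
import torch

# Must be called before the first adapter is loaded and before compiling.
pipeline.enable_lora_hotswap(target_rank=64)

pipeline.load_lora_weights("user/lora-a", adapter_name="default_0")  # placeholder repo id
pipeline.unet = torch.compile(pipeline.unet)

# Later: swap in new weights in place, without triggering recompilation.
pipeline.load_lora_weights("user/lora-b", adapter_name="default_0", hotswap=True)
```
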
|
||||
class LoraBaseMixin(LoraBaseMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `LoraBaseMixin` from diffusers.loaders.lora_base has been deprecated. Please use `from diffusers.loaders.lora.lora_base import LoraBaseMixin` instead."
|
||||
deprecate("diffusers.loaders.lora_base.LoraBaseMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
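The shim above keeps the old import location working while pointing users to the new one; a minimal illustration of the two paths after the folderizing change.

```py
# Old location: still importable, but instantiating the class emits a deprecation warning.
from diffusers.loaders.lora_base import LoraBaseMixin

# New location:
from diffusers.loaders.lora.lora_base import LoraBaseMixin
```
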
|
||||
|
||||
File diff suppressed because it is too large
@@ -35,8 +35,8 @@ from ..utils import (
|
||||
set_adapter_layers,
|
||||
set_weights_and_activate_adapters,
|
||||
)
|
||||
from .lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
||||
from .unet_loader_utils import _maybe_expand_lora_scales
|
||||
from .lora.lora_base import _fetch_state_dict, _func_optionally_disable_offloading
|
||||
from .unet.unet_loader_utils import _maybe_expand_lora_scales
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -99,7 +99,7 @@ class PeftAdapterMixin:
|
||||
_prepare_lora_hotswap_kwargs: Optional[dict] = None
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
# Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
def _optionally_disable_offloading(cls, _pipeline):
|
||||
"""
|
||||
Optionally removes offloading in case the pipeline has already been sequentially offloaded to CPU.
|
||||
|
||||
@@ -11,42 +11,8 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..utils import deprecate, is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
_is_legacy_scheduler_kwargs,
|
||||
_is_model_weights_in_cached_folder,
|
||||
_legacy_load_clip_tokenizer,
|
||||
_legacy_load_safety_checker,
|
||||
_legacy_load_scheduler,
|
||||
create_diffusers_clip_model_from_ldm,
|
||||
create_diffusers_t5_model_from_checkpoint,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
is_clip_model_in_single_file,
|
||||
is_t5_in_single_file,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
from ..utils import deprecate
|
||||
from .single_file.single_file import FromSingleFileMixin
|
||||
|
||||
|
||||
def load_single_file_sub_model(
|
||||
@@ -64,502 +30,30 @@ def load_single_file_sub_model(
|
||||
disable_mmap=False,
|
||||
**kwargs,
|
||||
):
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
from .single_file.single_file import load_single_file_sub_model
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
deprecation_message = "Importing `load_single_file_sub_model()` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import load_single_file_sub_model` instead."
|
||||
deprecate("diffusers.loaders.single_file.load_single_file_sub_model", "0.36", deprecation_message)
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
return load_single_file_sub_model(
|
||||
library_name,
|
||||
class_name,
|
||||
name,
|
||||
checkpoint,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
cached_model_config_path,
|
||||
original_config,
|
||||
local_files_only,
|
||||
torch_dtype,
|
||||
is_legacy_loading,
|
||||
disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
|
||||
|
||||
if is_diffusers_single_file_model:
|
||||
load_method = getattr(class_obj, "from_single_file")
|
||||
|
||||
# We cannot provide two different config options to the `from_single_file` method
|
||||
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
|
||||
if original_config:
|
||||
cached_model_config_path = None
|
||||
|
||||
loaded_sub_model = load_method(
|
||||
pretrained_model_link_or_path_or_dict=checkpoint,
|
||||
original_config=original_config,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
|
||||
loaded_sub_model = create_diffusers_clip_model_from_ldm(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
elif is_tokenizer and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
||||
loaded_sub_model = _legacy_load_scheduler(
|
||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if not hasattr(class_obj, "from_pretrained"):
|
||||
raise ValueError(
|
||||
(
|
||||
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
|
||||
" a supported loading method."
|
||||
)
|
||||
)
|
||||
|
||||
loading_kwargs = {}
|
||||
loading_kwargs.update(
|
||||
{
|
||||
"pretrained_model_name_or_path": cached_model_config_path,
|
||||
"subfolder": name,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
)
|
||||
|
||||
# Schedulers and Tokenizers don't make use of torch_dtype
|
||||
# Skip passing it to those objects
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs.update({"torch_dtype": torch_dtype})
|
||||
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, "from_pretrained")
|
||||
loaded_sub_model = load_method(**loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _map_component_types_to_config_dict(component_types):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
config_dict = {}
|
||||
component_types.pop("self", None)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
for component_name, component_value in component_types.items():
|
||||
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
|
||||
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
|
||||
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_transformers_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif is_scheduler_enum or is_scheduler:
|
||||
if is_scheduler_enum:
|
||||
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
|
||||
# if the type hint is a KarrasDiffusionSchedulers enum
|
||||
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
|
||||
|
||||
elif is_scheduler:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif (
|
||||
is_transformers_model or is_transformers_tokenizer
|
||||
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["transformers", component_value[0].__name__]
|
||||
|
||||
else:
|
||||
config_dict[component_name] = [None, None]
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _infer_pipeline_config_dict(pipeline_class):
|
||||
parameters = inspect.signature(pipeline_class.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
component_types = pipeline_class._get_signature_types()
|
||||
|
||||
# Ignore parameters that are not required for the pipeline
|
||||
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
|
||||
config_dict = _map_component_types_to_config_dict(component_types)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _download_diffusers_model_config_from_hub(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir,
|
||||
revision,
|
||||
proxies,
|
||||
force_download=None,
|
||||
local_files_only=None,
|
||||
token=None,
|
||||
):
|
||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
|
||||
cached_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
|
||||
return cached_model_path
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
||||
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
original_config_file (`str`, *optional*):
|
||||
The path to the original config file that was used to train the model. If not provided, the config file
|
||||
will be inferred from the checkpoint file.
|
||||
config (`str`, *optional*):
|
||||
Can be either:
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
||||
component configs in Diffusers format.
|
||||
disable_mmap (`bool`, *optional*, defaults to `False`):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
... )
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> pipeline.to("cuda")
|
||||
```
|
||||
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if original_config_file is not None:
|
||||
deprecation_message = (
"`original_config_file` argument is deprecated and will be removed in future versions. "
"Please use the `original_config` argument instead."
|
||||
)
|
||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||
original_config = original_config_file
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
# We shouldn't allow configuring individual model components through a Pipeline creation method
|
||||
# These model kwargs should be deprecated
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
if scaling_factor is not None:
|
||||
deprecation_message = (
"Passing the `scaling_factor` argument to `from_single_file` is deprecated "
|
||||
"and will be ignored in future versions."
|
||||
)
|
||||
deprecate("scaling_factor", "1.0.0", deprecation_message)
|
||||
|
||||
if original_config is not None:
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(cls, config=None)
|
||||
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
else:
|
||||
default_pretrained_model_config_name = config
|
||||
|
||||
if not os.path.isdir(default_pretrained_model_config_name):
|
||||
# Provided config is a repo_id
|
||||
if default_pretrained_model_config_name.count("/") > 1:
|
||||
raise ValueError(
|
||||
f'The provided config "{config}"'
|
||||
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||
)
|
||||
try:
|
||||
# Attempt to download the config files for the pipeline
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
except LocalEntryNotFoundError:
|
||||
# `local_files_only=True` but a local diffusers format model config is not available in the cache
|
||||
# If `original_config` is not provided, we need override `local_files_only` to False
|
||||
# to fetch the config files from the hub so that we have a way
|
||||
# to configure the pipeline components.
|
||||
|
||||
if original_config is None:
|
||||
logger.warning(
|
||||
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
|
||||
"Attempting to download the necessary config files for this pipeline.\n"
|
||||
)
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=False,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
else:
|
||||
# For backwards compatibility
|
||||
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
|
||||
logger.warning(
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
"This may lead to errors if the model components are not correctly inferred.\n"
"To avoid this warning, please explicitly pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>)` \n"
|
||||
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
||||
"the necessary config files.\n"
|
||||
)
|
||||
is_legacy_loading = True
|
||||
cached_model_config_path = None
|
||||
|
||||
config_dict = _infer_pipeline_config_dict(pipeline_class)
|
||||
config_dict["_class_name"] = pipeline_class.__name__
|
||||
|
||||
else:
|
||||
# Provided config is a path to a local directory attempt to load directly.
|
||||
cached_model_config_path = default_pretrained_model_config_name
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
# pop out "_ignore_files" as it is only needed for download
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
from diffusers import pipelines
|
||||
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for name, (library_name, class_name) in logging.tqdm(
|
||||
sorted(init_dict.items()), desc="Loading pipeline components..."
|
||||
):
|
||||
loaded_sub_model = None
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
|
||||
if name in passed_class_obj:
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
|
||||
else:
|
||||
try:
|
||||
loaded_sub_model = load_single_file_sub_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
name=name,
|
||||
checkpoint=checkpoint,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
cached_model_config_path=cached_model_config_path,
|
||||
pipelines=pipelines,
|
||||
torch_dtype=torch_dtype,
|
||||
original_config=original_config,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
except SingleFileComponentError as e:
|
||||
raise SingleFileComponentError(
|
||||
(
|
||||
f"{e.message}\n"
|
||||
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
|
||||
f"\n"
|
||||
f"{name} = {class_name}.from_pretrained('...')\n"
|
||||
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
|
||||
f"\n"
|
||||
)
|
||||
)
|
||||
|
||||
init_kwargs[name] = loaded_sub_model
|
||||
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# deprecated kwargs
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", None)
|
||||
if load_safety_checker is not None:
|
||||
deprecation_message = (
|
||||
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
|
||||
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
|
||||
)
|
||||
deprecate("load_safety_checker", "1.0.0", deprecation_message)
|
||||
|
||||
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
|
||||
init_kwargs.update(safety_checker_components)
|
||||
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
return pipe
|
||||
class FromSingleFileMixin(FromSingleFileMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FromSingleFileMixin` from diffusers.loaders.single_file has been deprecated. Please use `from diffusers.loaders.single_file.single_file import FromSingleFileMixin` instead."
|
||||
deprecate("diffusers.loaders.single_file.FromSingleFileMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
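Analogous to the LoRA shim earlier, the old single-file import path keeps working through this subclass; a small sketch of the old and new locations, mirroring the deprecation message above.

```py
# Old location (kept as a deprecation shim):
from diffusers.loaders.single_file import FromSingleFileMixin

# New location after folderizing:
from diffusers.loaders.single_file.single_file import FromSingleFileMixin
```
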
|
||||
|
||||
src/diffusers/loaders/single_file/__init__.py (new file, 8 lines)
@@ -0,0 +1,8 @@
|
||||
from ...utils import is_torch_available, is_transformers_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .single_file_model import FromOriginalModelMixin
|
||||
|
||||
if is_transformers_available():
|
||||
from .single_file import FromSingleFileMixin
|
||||
src/diffusers/loaders/single_file/single_file.py (new file, 565 lines)
@@ -0,0 +1,565 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
|
||||
from packaging import version
|
||||
from typing_extensions import Self
|
||||
|
||||
from ...utils import deprecate, is_transformers_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
_is_legacy_scheduler_kwargs,
|
||||
_is_model_weights_in_cached_folder,
|
||||
_legacy_load_clip_tokenizer,
|
||||
_legacy_load_safety_checker,
|
||||
_legacy_load_scheduler,
|
||||
create_diffusers_clip_model_from_ldm,
|
||||
create_diffusers_t5_model_from_checkpoint,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
is_clip_model_in_single_file,
|
||||
is_t5_in_single_file,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# Legacy behaviour. `from_single_file` does not load the safety checker unless explicitly provided
|
||||
SINGLE_FILE_OPTIONAL_COMPONENTS = ["safety_checker"]
|
||||
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
|
||||
def load_single_file_sub_model(
|
||||
library_name,
|
||||
class_name,
|
||||
name,
|
||||
checkpoint,
|
||||
pipelines,
|
||||
is_pipeline_module,
|
||||
cached_model_config_path,
|
||||
original_config=None,
|
||||
local_files_only=False,
|
||||
torch_dtype=None,
|
||||
is_legacy_loading=False,
|
||||
disable_mmap=False,
|
||||
**kwargs,
|
||||
):
|
||||
if is_pipeline_module:
|
||||
pipeline_module = getattr(pipelines, library_name)
|
||||
class_obj = getattr(pipeline_module, class_name)
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(class_obj, PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
is_diffusers_single_file_model = issubclass(class_obj, diffusers_module.FromOriginalModelMixin)
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
is_diffusers_scheduler = issubclass(class_obj, diffusers_module.SchedulerMixin)
|
||||
|
||||
if is_diffusers_single_file_model:
|
||||
load_method = getattr(class_obj, "from_single_file")
|
||||
|
||||
# We cannot provide two different config options to the `from_single_file` method
|
||||
# Here we have to ignore loading the config from `cached_model_config_path` if `original_config` is provided
|
||||
if original_config:
|
||||
cached_model_config_path = None
|
||||
|
||||
loaded_sub_model = load_method(
|
||||
pretrained_model_link_or_path_or_dict=checkpoint,
|
||||
original_config=original_config,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_clip_model_in_single_file(class_obj, checkpoint):
|
||||
loaded_sub_model = create_diffusers_clip_model_from_ldm(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
)
|
||||
|
||||
elif is_transformers_model and is_t5_in_single_file(checkpoint):
|
||||
loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
|
||||
class_obj,
|
||||
checkpoint=checkpoint,
|
||||
config=cached_model_config_path,
|
||||
subfolder=name,
|
||||
torch_dtype=torch_dtype,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
elif is_tokenizer and is_legacy_loading:
|
||||
loaded_sub_model = _legacy_load_clip_tokenizer(
|
||||
class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
|
||||
)
|
||||
|
||||
elif is_diffusers_scheduler and (is_legacy_loading or _is_legacy_scheduler_kwargs(kwargs)):
|
||||
loaded_sub_model = _legacy_load_scheduler(
|
||||
class_obj, checkpoint=checkpoint, component_name=name, original_config=original_config, **kwargs
|
||||
)
|
||||
|
||||
else:
|
||||
if not hasattr(class_obj, "from_pretrained"):
|
||||
raise ValueError(
|
||||
(
|
||||
f"The component {class_obj.__name__} cannot be loaded as it does not seem to have"
|
||||
" a supported loading method."
|
||||
)
|
||||
)
|
||||
|
||||
loading_kwargs = {}
|
||||
loading_kwargs.update(
|
||||
{
|
||||
"pretrained_model_name_or_path": cached_model_config_path,
|
||||
"subfolder": name,
|
||||
"local_files_only": local_files_only,
|
||||
}
|
||||
)
|
||||
|
||||
# Schedulers and Tokenizers don't make use of torch_dtype
|
||||
# Skip passing it to those objects
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs.update({"torch_dtype": torch_dtype})
|
||||
|
||||
if is_diffusers_model or is_transformers_model:
|
||||
if not _is_model_weights_in_cached_folder(cached_model_config_path, name):
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
load_method = getattr(class_obj, "from_pretrained")
|
||||
loaded_sub_model = load_method(**loading_kwargs)
|
||||
|
||||
return loaded_sub_model
|
||||
|
||||
|
||||
def _map_component_types_to_config_dict(component_types):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
config_dict = {}
|
||||
component_types.pop("self", None)
|
||||
|
||||
if is_transformers_available():
|
||||
transformers_version = version.parse(version.parse(transformers.__version__).base_version)
|
||||
else:
|
||||
transformers_version = "N/A"
|
||||
|
||||
for component_name, component_value in component_types.items():
|
||||
is_diffusers_model = issubclass(component_value[0], diffusers_module.ModelMixin)
|
||||
is_scheduler_enum = component_value[0].__name__ == "KarrasDiffusionSchedulers"
|
||||
is_scheduler = issubclass(component_value[0], diffusers_module.SchedulerMixin)
|
||||
|
||||
is_transformers_model = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedModel)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
is_transformers_tokenizer = (
|
||||
is_transformers_available()
|
||||
and issubclass(component_value[0], PreTrainedTokenizer)
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
if is_diffusers_model and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif is_scheduler_enum or is_scheduler:
|
||||
if is_scheduler_enum:
|
||||
# Since we cannot fetch a scheduler config from the hub, we default to DDIMScheduler
|
||||
# if the type hint is a KarrasDiffusionSchedulers enum
|
||||
config_dict[component_name] = ["diffusers", "DDIMScheduler"]
|
||||
|
||||
elif is_scheduler:
|
||||
config_dict[component_name] = ["diffusers", component_value[0].__name__]
|
||||
|
||||
elif (
|
||||
is_transformers_model or is_transformers_tokenizer
|
||||
) and component_name not in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
config_dict[component_name] = ["transformers", component_value[0].__name__]
|
||||
|
||||
else:
|
||||
config_dict[component_name] = [None, None]
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _infer_pipeline_config_dict(pipeline_class):
|
||||
parameters = inspect.signature(pipeline_class.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
component_types = pipeline_class._get_signature_types()
|
||||
|
||||
# Ignore parameters that are not required for the pipeline
|
||||
component_types = {k: v for k, v in component_types.items() if k in required_parameters}
|
||||
config_dict = _map_component_types_to_config_dict(component_types)
|
||||
|
||||
return config_dict
|
||||
|
||||
|
||||
def _download_diffusers_model_config_from_hub(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir,
|
||||
revision,
|
||||
proxies,
|
||||
force_download=None,
|
||||
local_files_only=None,
|
||||
token=None,
|
||||
):
|
||||
allow_patterns = ["**/*.json", "*.json", "*.txt", "**/*.txt", "**/*.model"]
|
||||
cached_model_path = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
|
||||
return cached_model_path
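A hedged illustration of what the `allow_patterns` filter above fetches: only the lightweight config and tokenizer files of a diffusers-format repo, never the weight files, since the weights come from the single-file checkpoint itself. The repo id is only an example.

```py
# Downloads *.json, *.txt and tokenizer *.model files; *.safetensors / *.bin are skipped.
cached_model_path = _download_diffusers_model_config_from_hub(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    cache_dir=None,
    revision=None,
    proxies=None,
)
```
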
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
|
||||
format. The pipeline is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.ckpt"`) on the Hub.
|
||||
- A path to a *file* containing all pipeline weights.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
original_config_file (`str`, *optional*):
|
||||
The path to the original config file that was used to train the model. If not provided, the config file
|
||||
will be inferred from the checkpoint file.
|
||||
config (`str`, *optional*):
|
||||
Can be either:
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline
|
||||
hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline
|
||||
component configs in Diffusers format.
|
||||
disable_mmap (`bool`, *optional*, defaults to `False`):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (the pipeline components of the specific pipeline
|
||||
class). The overwritten components are passed directly to the pipelines `__init__` method. See example
|
||||
below for more information.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableDiffusionPipeline
|
||||
|
||||
>>> # Download pipeline from huggingface.co and cache.
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/Models/AbyssOrangeMix/AbyssOrangeMix.safetensors"
|
||||
... )
|
||||
|
||||
>>> # Download pipeline from local file
|
||||
>>> # file is downloaded under ./v1-5-pruned-emaonly.ckpt
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.ckpt")
|
||||
|
||||
>>> # Enable float16 and move to GPU
|
||||
>>> pipeline = StableDiffusionPipeline.from_single_file(
|
||||
... "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.ckpt",
|
||||
... torch_dtype=torch.float16,
|
||||
... )
|
||||
>>> pipeline.to("cuda")
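>>> # Illustrative sketch (not from the original docstring): any pipeline component can also be
>>> # preloaded and passed in as a keyword argument; names below are assumptions.
>>> from transformers import CLIPTextModel

>>> text_encoder = CLIPTextModel.from_pretrained(
...     "stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="text_encoder"
... )
>>> pipeline = StableDiffusionPipeline.from_single_file(
...     "./v1-5-pruned-emaonly.ckpt", text_encoder=text_encoder
... )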
|
||||
```
|
||||
|
||||
"""
|
||||
original_config_file = kwargs.pop("original_config_file", None)
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if original_config_file is not None:
|
||||
deprecation_message = (
|
||||
"`original_config_file` argument is deprecated and will be removed in future versions."
|
||||
"please use the `original_config` argument instead."
|
||||
)
|
||||
deprecate("original_config_file", "1.0.0", deprecation_message)
|
||||
original_config = original_config_file
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
# We shouldn't allow configuring individual models components through a Pipeline creation method
|
||||
# These model kwargs should be deprecated
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
if scaling_factor is not None:
|
||||
deprecation_message = (
|
||||
"Passing the `scaling_factor` argument to `from_single_file is deprecated "
|
||||
"and will be ignored in future versions."
|
||||
)
|
||||
deprecate("scaling_factor", "1.0.0", deprecation_message)
|
||||
|
||||
if original_config is not None:
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
from ..pipelines.pipeline_utils import _get_pipeline_class
|
||||
|
||||
pipeline_class = _get_pipeline_class(cls, config=None)
|
||||
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
if config is None:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
else:
|
||||
default_pretrained_model_config_name = config
|
||||
|
||||
if not os.path.isdir(default_pretrained_model_config_name):
|
||||
# Provided config is a repo_id
|
||||
if default_pretrained_model_config_name.count("/") > 1:
|
||||
raise ValueError(
|
||||
f'The provided config "{config}"'
|
||||
" is neither a valid local path nor a valid repo id. Please check the parameter."
|
||||
)
|
||||
try:
|
||||
# Attempt to download the config files for the pipeline
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
except LocalEntryNotFoundError:
|
||||
# `local_files_only=True` but a local diffusers format model config is not available in the cache
|
||||
# If `original_config` is not provided, we need to override `local_files_only` to False
|
||||
# to fetch the config files from the hub so that we have a way
|
||||
# to configure the pipeline components.
|
||||
|
||||
if original_config is None:
|
||||
logger.warning(
|
||||
"`local_files_only` is True but no local configs were found for this checkpoint.\n"
|
||||
"Attempting to download the necessary config files for this pipeline.\n"
|
||||
)
|
||||
cached_model_config_path = _download_diffusers_model_config_from_hub(
|
||||
default_pretrained_model_config_name,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
proxies=proxies,
|
||||
force_download=force_download,
|
||||
local_files_only=False,
|
||||
token=token,
|
||||
)
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
else:
|
||||
# For backwards compatibility
|
||||
# If `original_config` is provided, then we need to assume we are using legacy loading for pipeline components
|
||||
logger.warning(
|
||||
"Detected legacy `from_single_file` loading behavior. Attempting to create the pipeline based on inferred components.\n"
|
||||
"This may lead to errors if the model components are not correctly inferred. \n"
|
||||
"To avoid this warning, please explicity pass the `config` argument to `from_single_file` with a path to a local diffusers model repo \n"
|
||||
"e.g. `from_single_file(<my model checkpoint path>, config=<path to local diffusers model repo>) \n"
|
||||
"or run `from_single_file` with `local_files_only=False` first to update the local cache directory with "
|
||||
"the necessary config files.\n"
|
||||
)
|
||||
is_legacy_loading = True
|
||||
cached_model_config_path = None
|
||||
|
||||
config_dict = _infer_pipeline_config_dict(pipeline_class)
|
||||
config_dict["_class_name"] = pipeline_class.__name__
|
||||
|
||||
else:
|
||||
# Provided config is a path to a local directory attempt to load directly.
|
||||
cached_model_config_path = default_pretrained_model_config_name
|
||||
config_dict = pipeline_class.load_config(cached_model_config_path)
|
||||
|
||||
# pop out "_ignore_files" as it is only needed for download
|
||||
config_dict.pop("_ignore_files", None)
|
||||
|
||||
expected_modules, optional_kwargs = pipeline_class._get_signature_keys(cls)
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
from diffusers import pipelines
|
||||
|
||||
# remove `null` components
|
||||
def load_module(name, value):
|
||||
if value[0] is None:
|
||||
return False
|
||||
if name in passed_class_obj and passed_class_obj[name] is None:
|
||||
return False
|
||||
if name in SINGLE_FILE_OPTIONAL_COMPONENTS:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for name, (library_name, class_name) in logging.tqdm(
|
||||
sorted(init_dict.items()), desc="Loading pipeline components..."
|
||||
):
|
||||
loaded_sub_model = None
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
|
||||
if name in passed_class_obj:
|
||||
loaded_sub_model = passed_class_obj[name]
|
||||
|
||||
else:
|
||||
try:
|
||||
loaded_sub_model = load_single_file_sub_model(
|
||||
library_name=library_name,
|
||||
class_name=class_name,
|
||||
name=name,
|
||||
checkpoint=checkpoint,
|
||||
is_pipeline_module=is_pipeline_module,
|
||||
cached_model_config_path=cached_model_config_path,
|
||||
pipelines=pipelines,
|
||||
torch_dtype=torch_dtype,
|
||||
original_config=original_config,
|
||||
local_files_only=local_files_only,
|
||||
is_legacy_loading=is_legacy_loading,
|
||||
disable_mmap=disable_mmap,
|
||||
**kwargs,
|
||||
)
|
||||
except SingleFileComponentError as e:
|
||||
raise SingleFileComponentError(
|
||||
(
|
||||
f"{e.message}\n"
|
||||
f"Please load the component before passing it in as an argument to `from_single_file`.\n"
|
||||
f"\n"
|
||||
f"{name} = {class_name}.from_pretrained('...')\n"
|
||||
f"pipe = {pipeline_class.__name__}.from_single_file(<checkpoint path>, {name}={name})\n"
|
||||
f"\n"
|
||||
)
|
||||
)
|
||||
|
||||
init_kwargs[name] = loaded_sub_model
|
||||
|
||||
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
||||
passed_modules = list(passed_class_obj.keys())
|
||||
optional_modules = pipeline_class._optional_components
|
||||
|
||||
if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
|
||||
for module in missing_modules:
|
||||
init_kwargs[module] = passed_class_obj.get(module, None)
|
||||
elif len(missing_modules) > 0:
|
||||
passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
|
||||
raise ValueError(
|
||||
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
||||
)
|
||||
|
||||
# deprecated kwargs
|
||||
load_safety_checker = kwargs.pop("load_safety_checker", None)
|
||||
if load_safety_checker is not None:
|
||||
deprecation_message = (
|
||||
"Please pass instances of `StableDiffusionSafetyChecker` and `AutoImageProcessor`"
|
||||
"using the `safety_checker` and `feature_extractor` arguments in `from_single_file`"
|
||||
)
|
||||
deprecate("load_safety_checker", "1.0.0", deprecation_message)
|
||||
|
||||
safety_checker_components = _legacy_load_safety_checker(local_files_only, torch_dtype)
|
||||
init_kwargs.update(safety_checker_components)
|
||||
|
||||
pipe = pipeline_class(**init_kwargs)
|
||||
|
||||
return pipe
|
src/diffusers/loaders/single_file/single_file_model.py (new file, 440 lines)
@@ -0,0 +1,440 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from ... import __version__
|
||||
from ...quantizers import DiffusersAutoQuantizer
|
||||
from ...utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
create_vae_diffusers_config_from_ldm,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
load_single_file_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, init_empty_weights
|
||||
|
||||
from ...models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
|
||||
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"StableCascadeUNet": {
|
||||
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
},
|
||||
"UNet2DConditionModel": {
|
||||
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
||||
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
||||
"default_subfolder": "unet",
|
||||
"legacy_kwargs": {
|
||||
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
||||
},
|
||||
},
|
||||
"AutoencoderKL": {
|
||||
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
||||
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"ControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
||||
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
||||
},
|
||||
"SD3Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"SparseControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"FluxTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"LTXVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLLTXVideo": {
|
||||
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||
"MochiTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AuraFlowTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Lumina2Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"SanaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLWan": {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
}
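# Illustrative note (not part of the original source): the mapping above is what lets a call like
#   FluxTransformer2DModel.from_single_file("flux1-dev.safetensors", torch_dtype=torch.bfloat16)
# resolve `convert_flux_transformer_checkpoint_to_diffusers` as its checkpoint mapping function and
# fall back to the "transformer" subfolder when fetching the Diffusers config. File name and dtype
# here are assumptions for illustration only.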
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
loadable_class = getattr(diffusers_module, loadable_class_str)
|
||||
|
||||
if issubclass(cls, loadable_class):
|
||||
return loadable_class_str
|
||||
|
||||
return None
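# Illustrative note (not part of the original source): a subclass of one of the listed classes,
# e.g. a custom derivative of `UNet2DConditionModel`, resolves to the string "UNet2DConditionModel"
# here and therefore reuses that entry of `SINGLE_FILE_LOADABLE_CLASSES`.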
|
||||
|
||||
|
||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
parameters = inspect.signature(mapping_fn).parameters
|
||||
|
||||
mapping_kwargs = {}
|
||||
for parameter in parameters:
|
||||
if parameter in kwargs:
|
||||
mapping_kwargs[parameter] = kwargs[parameter]
|
||||
|
||||
return mapping_kwargs
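# Illustrative note (not part of the original source): if `mapping_fn` accepts an `image_size`
# parameter and the caller passed `image_size=512` among unrelated kwargs, only
# {"image_size": 512} is forwarded to the mapping function.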
|
||||
|
||||
|
||||
class FromOriginalModelMixin:
|
||||
"""
|
||||
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path_or_dict (`str`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.safetensors` or `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
|
||||
- A path to a local *file* containing the weights of the component model.
|
||||
- A state dict containing the component model weights.
|
||||
config (`str`, *optional*):
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
|
||||
on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
|
||||
configs in Diffusers format.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
original_config (`str`, *optional*):
|
||||
Dict or path to a yaml file containing the configuration for the model in its original format.
|
||||
If a dict is provided, it will be used to initialize the model configuration.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
disable_mmap (`bool`, *optional*, defaults to `False`):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle mmap's seek-heavy access pattern very well.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite loadable and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are passed directly to the pipeline's `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableCascadeUNet
|
||||
|
||||
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
||||
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
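>>> # Illustrative sketch (repo id and subfolder are assumptions): point `config` at a Diffusers
>>> # model repo and set a dtype.
>>> import torch

>>> model = StableCascadeUNet.from_single_file(
...     ckpt_path, config="stabilityai/stable-cascade", subfolder="decoder", torch_dtype=torch.bfloat16
... )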
|
||||
```
|
||||
"""
|
||||
|
||||
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
||||
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
if mapping_class_name is None:
|
||||
raise ValueError(
|
||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||
)
|
||||
|
||||
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
|
||||
if pretrained_model_link_or_path is not None:
|
||||
deprecation_message = (
|
||||
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
|
||||
)
|
||||
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
|
||||
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if config is not None and original_config is not None:
|
||||
raise ValueError(
|
||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||
)
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||
# Included in the user agent to help ensure popular quantization methods are supported. Can be disabled with `disable_telemetry`.
|
||||
if quantization_config is not None:
|
||||
user_agent["quant"] = quantization_config.quant_method.value
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
else:
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path_or_dict,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config is not None:
|
||||
if "config_mapping_fn" in mapping_functions:
|
||||
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
||||
else:
|
||||
config_mapping_fn = None
|
||||
|
||||
if config_mapping_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
||||
"was found to convert the original config to a Diffusers config in"
|
||||
"`diffusers.loaders.single_file_utils`"
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(original_config, str):
|
||||
# If original_config is a URL or filepath fetch the original_config dict
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
|
||||
diffusers_model_config = config_mapping_fn(
|
||||
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
if config is not None:
|
||||
if isinstance(config, str):
|
||||
default_pretrained_model_config_name = config
|
||||
else:
|
||||
raise ValueError(
|
||||
(
|
||||
"Invalid `config` argument. Please provide a string representing a repo id"
|
||||
"or path to a local Diffusers model repo."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
|
||||
if "default_subfolder" in mapping_functions:
|
||||
subfolder = mapping_functions["default_subfolder"]
|
||||
|
||||
subfolder = subfolder or config.pop(
|
||||
"subfolder", None
|
||||
) # some configs contain a subfolder key, e.g. StableCascadeUNet
|
||||
|
||||
diffusers_model_config = cls.load_config(
|
||||
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=config_revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
# Map legacy kwargs to new kwargs
|
||||
if "legacy_kwargs" in mapping_functions:
|
||||
legacy_kwargs = mapping_functions["legacy_kwargs"]
|
||||
for legacy_key, new_key in legacy_kwargs.items():
|
||||
if legacy_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(legacy_key)
|
||||
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=None,
|
||||
state_dict=diffusers_format_checkpoint,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
device_map = None
|
||||
if is_accelerate_available():
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
empty_state_dict = model.state_dict()
|
||||
unexpected_keys = [
|
||||
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
|
||||
]
|
||||
device_map = {"": param_device}
|
||||
load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and hf_quantizer is None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
if device_map is not None:
|
||||
device_map_kwargs = {"device_map": device_map}
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
return model
|
src/diffusers/loaders/single_file/single_file_utils.py (new file, 3295 lines; diff suppressed because the file is too large)
@@ -11,430 +11,17 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
from contextlib import nullcontext
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from .. import __version__
|
||||
from ..quantizers import DiffusersAutoQuantizer
|
||||
from ..utils import deprecate, is_accelerate_available, logging
|
||||
from .single_file_utils import (
|
||||
SingleFileComponentError,
|
||||
convert_animatediff_checkpoint_to_diffusers,
|
||||
convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
convert_ldm_unet_checkpoint,
|
||||
convert_ldm_vae_checkpoint,
|
||||
convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
convert_ltx_vae_checkpoint_to_diffusers,
|
||||
convert_lumina2_to_diffusers,
|
||||
convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
convert_sana_transformer_to_diffusers,
|
||||
convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
create_vae_diffusers_config_from_ldm,
|
||||
fetch_diffusers_config,
|
||||
fetch_original_config,
|
||||
load_single_file_checkpoint,
|
||||
from ..utils import deprecate
|
||||
from .single_file.single_file_model import (
|
||||
SINGLE_FILE_LOADABLE_CLASSES, # noqa: F401
|
||||
FromOriginalModelMixin,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import dispatch_model, init_empty_weights
|
||||
|
||||
from ..models.modeling_utils import load_model_dict_into_meta
|
||||
|
||||
|
||||
SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"StableCascadeUNet": {
|
||||
"checkpoint_mapping_fn": convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
},
|
||||
"UNet2DConditionModel": {
|
||||
"checkpoint_mapping_fn": convert_ldm_unet_checkpoint,
|
||||
"config_mapping_fn": create_unet_diffusers_config_from_ldm,
|
||||
"default_subfolder": "unet",
|
||||
"legacy_kwargs": {
|
||||
"num_in_channels": "in_channels", # Legacy kwargs supported by `from_single_file` mapped to new args
|
||||
},
|
||||
},
|
||||
"AutoencoderKL": {
|
||||
"checkpoint_mapping_fn": convert_ldm_vae_checkpoint,
|
||||
"config_mapping_fn": create_vae_diffusers_config_from_ldm,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"ControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_controlnet_checkpoint,
|
||||
"config_mapping_fn": create_controlnet_diffusers_config_from_ldm,
|
||||
},
|
||||
"SD3Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sd3_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"MotionAdapter": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"SparseControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_animatediff_checkpoint_to_diffusers,
|
||||
},
|
||||
"FluxTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"LTXVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_ltx_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLLTXVideo": {
|
||||
"checkpoint_mapping_fn": convert_ltx_vae_checkpoint_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
"AutoencoderDC": {"checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers},
|
||||
"MochiTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_mochi_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"HunyuanVideoTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_hunyuan_video_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AuraFlowTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Lumina2Transformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_lumina2_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"SanaTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": convert_sana_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLWan": {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
for loadable_class_str in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
loadable_class = getattr(diffusers_module, loadable_class_str)
|
||||
|
||||
if issubclass(cls, loadable_class):
|
||||
return loadable_class_str
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _get_mapping_function_kwargs(mapping_fn, **kwargs):
|
||||
parameters = inspect.signature(mapping_fn).parameters
|
||||
|
||||
mapping_kwargs = {}
|
||||
for parameter in parameters:
|
||||
if parameter in kwargs:
|
||||
mapping_kwargs[parameter] = kwargs[parameter]
|
||||
|
||||
return mapping_kwargs
|
||||
|
||||
|
||||
class FromOriginalModelMixin:
|
||||
"""
|
||||
Load pretrained weights saved in the `.ckpt` or `.safetensors` format into a model.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
|
||||
r"""
|
||||
Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
|
||||
is set in evaluation mode (`model.eval()`) by default.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_link_or_path_or_dict (`str`, *optional*):
|
||||
Can be either:
|
||||
- A link to the `.safetensors` or `.ckpt` file (for example
|
||||
`"https://huggingface.co/<repo_id>/blob/main/<path_to_file>.safetensors"`) on the Hub.
|
||||
- A path to a local *file* containing the weights of the component model.
|
||||
- A state dict containing the component model weights.
|
||||
config (`str`, *optional*):
|
||||
- A string, the *repo id* (for example `CompVis/ldm-text2im-large-256`) of a pretrained pipeline hosted
|
||||
on the Hub.
|
||||
- A path to a *directory* (for example `./my_pipeline_directory/`) containing the pipeline component
|
||||
configs in Diffusers format.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
original_config (`str`, *optional*):
|
||||
Dict or path to a yaml file containing the configuration for the model in its original format.
|
||||
If a dict is provided, it will be used to initialize the model configuration.
|
||||
torch_dtype (`str` or `torch.dtype`, *optional*):
|
||||
Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
|
||||
dtype is automatically derived from the model's weights.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to True, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
disable_mmap ('bool', *optional*, defaults to 'False'):
|
||||
Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
|
||||
is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to overwrite load and saveable variables (for example the pipeline components of the
|
||||
specific pipeline class). The overwritten components are directly passed to the pipelines `__init__`
|
||||
method. See example below for more information.
|
||||
|
||||
```py
|
||||
>>> from diffusers import StableCascadeUNet
|
||||
|
||||
>>> ckpt_path = "https://huggingface.co/stabilityai/stable-cascade/blob/main/stage_b_lite.safetensors"
|
||||
>>> model = StableCascadeUNet.from_single_file(ckpt_path)
|
||||
```
|
||||
"""
|
||||
|
||||
mapping_class_name = _get_single_file_loadable_mapping_class(cls)
|
||||
# if class_name not in SINGLE_FILE_LOADABLE_CLASSES:
|
||||
if mapping_class_name is None:
|
||||
raise ValueError(
|
||||
f"FromOriginalModelMixin is currently only compatible with {', '.join(SINGLE_FILE_LOADABLE_CLASSES.keys())}"
|
||||
)
|
||||
|
||||
pretrained_model_link_or_path = kwargs.get("pretrained_model_link_or_path", None)
|
||||
if pretrained_model_link_or_path is not None:
|
||||
deprecation_message = (
|
||||
"Please use `pretrained_model_link_or_path_or_dict` argument instead for model classes"
|
||||
)
|
||||
deprecate("pretrained_model_link_or_path", "1.0.0", deprecation_message)
|
||||
pretrained_model_link_or_path_or_dict = pretrained_model_link_or_path
|
||||
|
||||
config = kwargs.pop("config", None)
|
||||
original_config = kwargs.pop("original_config", None)
|
||||
|
||||
if config is not None and original_config is not None:
|
||||
raise ValueError(
|
||||
"`from_single_file` cannot accept both `config` and `original_config` arguments. Please provide only one of these arguments"
|
||||
)
|
||||
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
token = kwargs.pop("token", None)
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
|
||||
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
|
||||
if quantization_config is not None:
|
||||
user_agent["quant"] = quantization_config.quant_method.value
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
else:
|
||||
checkpoint = load_single_file_checkpoint(
|
||||
pretrained_model_link_or_path_or_dict,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
disable_mmap=disable_mmap,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if quantization_config is not None:
|
||||
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
|
||||
hf_quantizer.validate_environment()
|
||||
torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
|
||||
|
||||
else:
|
||||
hf_quantizer = None
|
||||
|
||||
mapping_functions = SINGLE_FILE_LOADABLE_CLASSES[mapping_class_name]
|
||||
|
||||
checkpoint_mapping_fn = mapping_functions["checkpoint_mapping_fn"]
|
||||
if original_config is not None:
|
||||
if "config_mapping_fn" in mapping_functions:
|
||||
config_mapping_fn = mapping_functions["config_mapping_fn"]
|
||||
else:
|
||||
config_mapping_fn = None
|
||||
|
||||
if config_mapping_fn is None:
|
||||
raise ValueError(
|
||||
(
|
||||
f"`original_config` has been provided for {mapping_class_name} but no mapping function"
|
||||
"was found to convert the original config to a Diffusers config in"
|
||||
"`diffusers.loaders.single_file_utils`"
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(original_config, str):
|
||||
# If original_config is a URL or filepath fetch the original_config dict
|
||||
original_config = fetch_original_config(original_config, local_files_only=local_files_only)
|
||||
|
||||
config_mapping_kwargs = _get_mapping_function_kwargs(config_mapping_fn, **kwargs)
|
||||
diffusers_model_config = config_mapping_fn(
|
||||
original_config=original_config, checkpoint=checkpoint, **config_mapping_kwargs
|
||||
)
|
||||
else:
|
||||
if config is not None:
|
||||
if isinstance(config, str):
|
||||
default_pretrained_model_config_name = config
|
||||
else:
|
||||
raise ValueError(
|
||||
(
|
||||
"Invalid `config` argument. Please provide a string representing a repo id"
|
||||
"or path to a local Diffusers model repo."
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
config = fetch_diffusers_config(checkpoint)
|
||||
default_pretrained_model_config_name = config["pretrained_model_name_or_path"]
|
||||
|
||||
if "default_subfolder" in mapping_functions:
|
||||
subfolder = mapping_functions["default_subfolder"]
|
||||
|
||||
subfolder = subfolder or config.pop(
|
||||
"subfolder", None
|
||||
) # some configs contain a subfolder key, e.g. StableCascadeUNet
|
||||
|
||||
diffusers_model_config = cls.load_config(
|
||||
pretrained_model_name_or_path=default_pretrained_model_config_name,
|
||||
subfolder=subfolder,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=config_revision,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = cls._get_signature_keys(cls)
|
||||
|
||||
# Map legacy kwargs to new kwargs
|
||||
if "legacy_kwargs" in mapping_functions:
|
||||
legacy_kwargs = mapping_functions["legacy_kwargs"]
|
||||
for legacy_key, new_key in legacy_kwargs.items():
|
||||
if legacy_key in kwargs:
|
||||
kwargs[new_key] = kwargs.pop(legacy_key)
|
||||
|
||||
model_kwargs = {k: kwargs.get(k) for k in kwargs if k in expected_kwargs or k in optional_kwargs}
|
||||
diffusers_model_config.update(model_kwargs)
|
||||
|
||||
checkpoint_mapping_kwargs = _get_mapping_function_kwargs(checkpoint_mapping_fn, **kwargs)
|
||||
diffusers_format_checkpoint = checkpoint_mapping_fn(
|
||||
config=diffusers_model_config, checkpoint=checkpoint, **checkpoint_mapping_kwargs
|
||||
)
|
||||
if not diffusers_format_checkpoint:
|
||||
raise SingleFileComponentError(
|
||||
f"Failed to load {mapping_class_name}. Weights for this component appear to be missing in the checkpoint."
|
||||
)
|
||||
|
||||
ctx = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
with ctx():
|
||||
model = cls.from_config(diffusers_model_config)
|
||||
|
||||
# Check if `_keep_in_fp32_modules` is not None
|
||||
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
|
||||
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
|
||||
)
|
||||
if use_keep_in_fp32_modules:
|
||||
keep_in_fp32_modules = cls._keep_in_fp32_modules
|
||||
if not isinstance(keep_in_fp32_modules, list):
|
||||
keep_in_fp32_modules = [keep_in_fp32_modules]
|
||||
|
||||
else:
|
||||
keep_in_fp32_modules = []
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.preprocess_model(
|
||||
model=model,
|
||||
device_map=None,
|
||||
state_dict=diffusers_format_checkpoint,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
)
|
||||
|
||||
device_map = None
|
||||
if is_accelerate_available():
|
||||
param_device = torch.device(device) if device else torch.device("cpu")
|
||||
empty_state_dict = model.state_dict()
|
||||
unexpected_keys = [
|
||||
param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
|
||||
]
|
||||
device_map = {"": param_device}
|
||||
load_model_dict_into_meta(
|
||||
model,
|
||||
diffusers_format_checkpoint,
|
||||
dtype=torch_dtype,
|
||||
device_map=device_map,
|
||||
hf_quantizer=hf_quantizer,
|
||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||
unexpected_keys=unexpected_keys,
|
||||
)
|
||||
else:
|
||||
_, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
||||
|
||||
if model._keys_to_ignore_on_load_unexpected is not None:
|
||||
for pat in model._keys_to_ignore_on_load_unexpected:
|
||||
unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
|
||||
)
|
||||
|
||||
if hf_quantizer is not None:
|
||||
hf_quantizer.postprocess_model(model)
|
||||
model.hf_quantizer = hf_quantizer
|
||||
|
||||
if torch_dtype is not None and hf_quantizer is None:
|
||||
model.to(torch_dtype)
|
||||
|
||||
model.eval()
|
||||
|
||||
if device_map is not None:
|
||||
device_map_kwargs = {"device_map": device_map}
|
||||
dispatch_model(model, **device_map_kwargs)
|
||||
|
||||
return model
|
||||
class FromOriginalModelMixin(FromOriginalModelMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FromOriginalModelMixin` from diffusers.loaders.single_file_model has been deprecated. Please use `from diffusers.loaders.single_file.single_file_model import FromOriginalModelMixin` instead."
|
||||
deprecate("diffusers.loaders.single_file_model.FromOriginalModelMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
File diff suppressed because it is too large
@@ -11,170 +11,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from contextlib import nullcontext
|
||||
|
||||
from ..models.embeddings import (
|
||||
ImageProjection,
|
||||
MultiIPAdapterImageProjection,
|
||||
)
|
||||
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils import deprecate
|
||||
from .ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin
|
||||
|
||||
|
||||
if is_accelerate_available():
|
||||
pass
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class FluxTransformer2DLoadersMixin:
|
||||
"""
|
||||
Load layers into a [`FluxTransformer2DModel`].
|
||||
"""
|
||||
|
||||
def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
updated_state_dict = {}
|
||||
image_projection = None
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
|
||||
if "proj.weight" in state_dict:
|
||||
# IP-Adapter
|
||||
num_image_text_embeds = 4
|
||||
if state_dict["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embeds = 16
|
||||
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
|
||||
cross_attention_dim = state_dict["proj.weight"].shape[0] // num_image_text_embeds
|
||||
|
||||
with init_context():
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
image_embed_dim=clip_embeddings_dim,
|
||||
num_image_text_embeds=num_image_text_embeds,
|
||||
)
|
||||
|
||||
for key, value in state_dict.items():
|
||||
diffusers_name = key.replace("proj", "image_embeds")
|
||||
updated_state_dict[diffusers_name] = value
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
image_projection.load_state_dict(updated_state_dict, strict=True)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
||||
|
||||
return image_projection
|
||||
|
||||
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
from ..models.attention_processor import (
|
||||
FluxIPAdapterJointAttnProcessor2_0,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
else:
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
# set ip-adapter cross-attention processors & load state_dict
|
||||
attn_procs = {}
|
||||
key_id = 0
|
||||
init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
|
||||
for name in self.attn_processors.keys():
|
||||
if name.startswith("single_transformer_blocks"):
|
||||
attn_processor_class = self.attn_processors[name].__class__
|
||||
attn_procs[name] = attn_processor_class()
|
||||
else:
|
||||
cross_attention_dim = self.config.joint_attention_dim
|
||||
hidden_size = self.inner_dim
|
||||
attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
|
||||
num_image_text_embeds = []
|
||||
for state_dict in state_dicts:
|
||||
if "proj.weight" in state_dict["image_proj"]:
|
||||
num_image_text_embed = 4
|
||||
if state_dict["image_proj"]["proj.weight"].shape[0] == 65536:
|
||||
num_image_text_embed = 16
|
||||
# IP-Adapter
|
||||
num_image_text_embeds += [num_image_text_embed]
|
||||
|
||||
with init_context():
|
||||
attn_procs[name] = attn_processor_class(
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
num_tokens=num_image_text_embeds,
|
||||
dtype=self.dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
value_dict = {}
|
||||
for i, state_dict in enumerate(state_dicts):
|
||||
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
|
||||
value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
|
||||
value_dict.update({f"to_k_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_k_ip.bias"]})
|
||||
value_dict.update({f"to_v_ip.{i}.bias": state_dict["ip_adapter"][f"{key_id}.to_v_ip.bias"]})
|
||||
|
||||
if not low_cpu_mem_usage:
|
||||
attn_procs[name].load_state_dict(value_dict)
|
||||
else:
|
||||
device_map = {"": self.device}
|
||||
dtype = self.dtype
|
||||
load_model_dict_into_meta(attn_procs[name], value_dict, device_map=device_map, dtype=dtype)
|
||||
|
||||
key_id += 1
|
||||
|
||||
return attn_procs
|
||||
|
||||
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
||||
if not isinstance(state_dicts, list):
|
||||
state_dicts = [state_dicts]
|
||||
|
||||
self.encoder_hid_proj = None
|
||||
|
||||
attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
self.set_attn_processor(attn_procs)
|
||||
|
||||
image_projection_layers = []
|
||||
for state_dict in state_dicts:
|
||||
image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers(
|
||||
state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage
|
||||
)
|
||||
image_projection_layers.append(image_projection_layer)
|
||||
|
||||
self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
|
||||
self.config.encoder_hid_dim_type = "ip_image_proj"
|
||||
class FluxTransformer2DLoadersMixin(FluxTransformer2DLoadersMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxTransformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_flux import FluxTransformer2DLoadersMixin` instead."
|
||||
deprecate("diffusers.loaders.ip_adapter.FluxTransformer2DLoadersMixin", "0.36", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@@ -11,160 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
from typing import Dict

from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
from ..models.embeddings import IPAdapterTimeImageProjection
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..utils import is_accelerate_available, is_torch_version, logging
from ..utils import deprecate
from .ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin


logger = logging.get_logger(__name__)


class SD3Transformer2DLoadersMixin:
    """Load IP-Adapters and LoRA layers into a `[SD3Transformer2DModel]`."""

    def _convert_ip_adapter_attn_to_diffusers(
        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
    ) -> Dict:
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        # IP-Adapter cross attention parameters
        hidden_size = self.config.attention_head_dim * self.config.num_attention_heads
        ip_hidden_states_dim = self.config.attention_head_dim * self.config.num_attention_heads
        timesteps_emb_dim = state_dict["0.norm_ip.linear.weight"].shape[1]

        # Dict where key is transformer layer index, value is attention processor's state dict
        # ip_adapter state dict keys example: "0.norm_ip.linear.weight"
        layer_state_dict = {idx: {} for idx in range(len(self.attn_processors))}
        for key, weights in state_dict.items():
            idx, name = key.split(".", maxsplit=1)
            layer_state_dict[int(idx)][name] = weights

        # Create IP-Adapter attention processor & load state_dict
        attn_procs = {}
        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext
        for idx, name in enumerate(self.attn_processors.keys()):
            with init_context():
                attn_procs[name] = SD3IPAdapterJointAttnProcessor2_0(
                    hidden_size=hidden_size,
                    ip_hidden_states_dim=ip_hidden_states_dim,
                    head_dim=self.config.attention_head_dim,
                    timesteps_emb_dim=timesteps_emb_dim,
                )

            if not low_cpu_mem_usage:
                attn_procs[name].load_state_dict(layer_state_dict[idx], strict=True)
            else:
                device_map = {"": self.device}
                load_model_dict_into_meta(
                    attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
                )

        return attn_procs

    def _convert_ip_adapter_image_proj_to_diffusers(
        self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT
    ) -> IPAdapterTimeImageProjection:
        if low_cpu_mem_usage:
            if is_accelerate_available():
                from accelerate import init_empty_weights

            else:
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
            raise NotImplementedError(
                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                " `low_cpu_mem_usage=False`."
            )

        init_context = init_empty_weights if low_cpu_mem_usage else nullcontext

        # Convert to diffusers
        updated_state_dict = {}
        for key, value in state_dict.items():
            # InstantX/SD3.5-Large-IP-Adapter
            if key.startswith("layers."):
                idx = key.split(".")[1]
                key = key.replace(f"layers.{idx}.0.norm1", f"layers.{idx}.ln0")
                key = key.replace(f"layers.{idx}.0.norm2", f"layers.{idx}.ln1")
                key = key.replace(f"layers.{idx}.0.to_q", f"layers.{idx}.attn.to_q")
                key = key.replace(f"layers.{idx}.0.to_kv", f"layers.{idx}.attn.to_kv")
                key = key.replace(f"layers.{idx}.0.to_out", f"layers.{idx}.attn.to_out.0")
                key = key.replace(f"layers.{idx}.1.0", f"layers.{idx}.adaln_norm")
                key = key.replace(f"layers.{idx}.1.1", f"layers.{idx}.ff.net.0.proj")
                key = key.replace(f"layers.{idx}.1.3", f"layers.{idx}.ff.net.2")
                key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
            updated_state_dict[key] = value

        # Image projection parameters
        embed_dim = updated_state_dict["proj_in.weight"].shape[1]
        output_dim = updated_state_dict["proj_out.weight"].shape[0]
        hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
        heads = updated_state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
        num_queries = updated_state_dict["latents"].shape[1]
        timestep_in_dim = updated_state_dict["time_embedding.linear_1.weight"].shape[1]

        # Image projection
        with init_context():
            image_proj = IPAdapterTimeImageProjection(
                embed_dim=embed_dim,
                output_dim=output_dim,
                hidden_dim=hidden_dim,
                heads=heads,
                num_queries=num_queries,
                timestep_in_dim=timestep_in_dim,
            )

        if not low_cpu_mem_usage:
            image_proj.load_state_dict(updated_state_dict, strict=True)
        else:
            device_map = {"": self.device}
            load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)

        return image_proj

    def _load_ip_adapter_weights(self, state_dict: Dict, low_cpu_mem_usage: bool = _LOW_CPU_MEM_USAGE_DEFAULT) -> None:
        """Sets IP-Adapter attention processors, image projection, and loads state_dict.

        Args:
            state_dict (`Dict`):
                State dict with keys "ip_adapter", which contains parameters for attention processors, and
                "image_proj", which contains parameters for image projection net.
            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
                Speed up model loading by only loading the pretrained weights and not initializing the weights. This
                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
                model. Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
                argument to `True` will raise an error.
        """

        attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dict["ip_adapter"], low_cpu_mem_usage)
        self.set_attn_processor(attn_procs)

        self.image_proj = self._convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"], low_cpu_mem_usage)
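As a hedged illustration of the contract documented above (a sketch, not an official example): the loader expects a dict with "image_proj" and "ip_adapter" sub-dicts, which a caller might assemble from a flat checkpoint roughly like this. The file name below is a placeholder, and `transformer` is assumed to be an `SD3Transformer2DModel` instance (which inherits this mixin).

```python
from safetensors.torch import load_file

flat = load_file("ip_adapter_sd3.safetensors")  # placeholder file name, not a real asset
state_dict = {
    "image_proj": {k[len("image_proj."):]: v for k, v in flat.items() if k.startswith("image_proj.")},
    "ip_adapter": {k[len("ip_adapter."):]: v for k, v in flat.items() if k.startswith("ip_adapter.")},
}
# `transformer` is assumed to exist already, e.g. an SD3Transformer2DModel loaded elsewhere.
transformer._load_ip_adapter_weights(state_dict, low_cpu_mem_usage=True)
```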


class SD3Transformer2DLoadersMixin(SD3Transformer2DLoadersMixin):
    def __init__(self, *args, **kwargs):
        deprecation_message = "Importing `SD3Transformer2DLoadersMixin` from diffusers.loaders.ip_adapter has been deprecated. Please use `from diffusers.loaders.ip_adapter.transformer_sd3 import SD3Transformer2DLoadersMixin` instead."
        deprecate("diffusers.loaders.ip_adapter.SD3Transformer2DLoadersMixin", "0.36", deprecation_message)
        super().__init__(*args, **kwargs)

src/diffusers/loaders/unet/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
from ...utils import is_torch_available


if is_torch_available():
    from .unet import UNet2DConditionLoadersMixin
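A short sketch of how this re-export is expected to surface to callers once the loaders are folderized (an assumption based on the `__init__.py` above, not an official example):

```python
# Both spellings should resolve to the same class after this change:
from diffusers.loaders.unet import UNet2DConditionLoadersMixin  # new subpackage path
from diffusers.loaders import UNet2DConditionLoadersMixin  # noqa: F811  # existing top-level path
```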
@@ -22,7 +22,7 @@ import torch
import torch.nn.functional as F
from huggingface_hub.utils import validate_hf_hub_args

from ..models.embeddings import (
from ...models.embeddings import (
    ImageProjection,
    IPAdapterFaceIDImageProjection,
    IPAdapterFaceIDPlusImageProjection,
@@ -30,8 +30,8 @@ from ..models.embeddings import (
    IPAdapterPlusImageProjection,
    MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import (
from ...models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ...utils import (
    USE_PEFT_BACKEND,
    _get_model_file,
    convert_unet_state_dict_to_peft,
@@ -43,9 +43,9 @@ from ..utils import (
    is_torch_version,
    logging,
)
from .lora_base import _func_optionally_disable_offloading
from .lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from .utils import AttnProcsLayers
from ..lora.lora_base import _func_optionally_disable_offloading
from ..lora.lora_pipeline import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
from ..utils import AttnProcsLayers


logger = logging.get_logger(__name__)
@@ -247,7 +247,7 @@ class UNet2DConditionLoadersMixin:
        # Unsafe code />

    def _process_custom_diffusion(self, state_dict):
        from ..models.attention_processor import CustomDiffusionAttnProcessor
        from ...models.attention_processor import CustomDiffusionAttnProcessor

        attn_processors = {}
        custom_diffusion_grouped_dict = defaultdict(dict)
@@ -395,7 +395,7 @@ class UNet2DConditionLoadersMixin:
        return is_model_cpu_offload, is_sequential_cpu_offload

    @classmethod
    # Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
    # Copied from diffusers.loaders.lora.lora_base.LoraBaseMixin._optionally_disable_offloading
    def _optionally_disable_offloading(cls, _pipeline):
        """
        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.
@@ -451,7 +451,7 @@ class UNet2DConditionLoadersMixin:
        pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
        ```
        """
        from ..models.attention_processor import (
        from ...models.attention_processor import (
            CustomDiffusionAttnProcessor,
            CustomDiffusionAttnProcessor2_0,
            CustomDiffusionXFormersAttnProcessor,
@@ -513,7 +513,7 @@ class UNet2DConditionLoadersMixin:
        logger.info(f"Model weights saved in {save_path}")

    def _get_custom_diffusion_state_dict(self):
        from ..models.attention_processor import (
        from ...models.attention_processor import (
            CustomDiffusionAttnProcessor,
            CustomDiffusionAttnProcessor2_0,
            CustomDiffusionXFormersAttnProcessor,
@@ -759,7 +759,7 @@ class UNet2DConditionLoadersMixin:
        return image_projection

    def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
        from ..models.attention_processor import (
        from ...models.attention_processor import (
            IPAdapterAttnProcessor,
            IPAdapterAttnProcessor2_0,
            IPAdapterXFormersAttnProcessor,
@@ -14,12 +14,12 @@
import copy
from typing import TYPE_CHECKING, Dict, List, Union

from ..utils import logging
from ...utils import logging


if TYPE_CHECKING:
    # import here to avoid circular imports
    from ..models import UNet2DConditionModel
    from ...models import UNet2DConditionModel

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -17,8 +17,7 @@ import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import deprecate
from ...utils.accelerate_utils import apply_forward_hook
from ..attention_processor import (

@@ -21,7 +21,7 @@ import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation

@@ -19,7 +19,7 @@ from torch import nn
from torch.nn import functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin
from ...utils import BaseOutput, logging
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,

@@ -17,7 +17,7 @@ import torch
from torch import nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin
from ...utils import logging
from ..attention_processor import (
    ADDED_KV_ATTENTION_PROCESSORS,

@@ -20,8 +20,7 @@ import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import LuminaFeedForward
from ..attention_processor import Attention

@@ -19,8 +19,7 @@ import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward

@@ -19,8 +19,7 @@ import torch.nn as nn
import torch.utils.checkpoint

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ..activations import get_activation
from ..attention_processor import (

@@ -352,7 +352,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.INFO)

        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]
@@ -403,7 +403,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        _, _, inputs = self.get_dummy_inputs(with_generator=False)
        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        # Change the transformer config to mimic a real use case.
@@ -486,7 +486,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -541,7 +541,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -590,7 +590,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        out_features, in_features = pipe.transformer.x_embedder.weight.shape
@@ -653,7 +653,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
            "transformer.x_embedder.lora_B.weight": normal_lora_B.weight,
        }

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.INFO)
        with CaptureLogger(logger) as cap_logger:
            pipe.load_lora_weights(lora_state_dict, "adapter-1")
@@ -668,7 +668,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_lora_unload_with_parameter_expanded_shapes(self):
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        # Change the transformer config to mimic a real use case.
@@ -734,7 +734,7 @@ class FluxControlLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    def test_lora_unload_with_parameter_expanded_shapes_and_no_reset(self):
        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.DEBUG)

        # Change the transformer config to mimic a real use case.

@@ -1017,7 +1017,7 @@ class PeftLoraLoaderMixinTests:
        )

        scale_with_wrong_components = {"foo": 0.0, "bar": 0.0, "tik": 0.0}
        logger = logging.get_logger("diffusers.loaders.lora_base")
        logger = logging.get_logger("diffusers.loaders.lora.lora_base")
        logger.setLevel(30)
        with CaptureLogger(logger) as cap_logger:
            pipe.set_adapters(adapter_name, adapter_weights=scale_with_wrong_components)
@@ -1824,7 +1824,7 @@ class PeftLoraLoaderMixinTests:
            elif lora_module == "text_encoder_2":
                prefix = "text_encoder_2"

            logger = logging.get_logger("diffusers.loaders.lora_base")
            logger = logging.get_logger("diffusers.loaders.lora.lora_base")
            logger.setLevel(logging.WARNING)

            with CaptureLogger(logger) as cap_logger:
@@ -1925,7 +1925,7 @@ class PeftLoraLoaderMixinTests:

        _, _, inputs = self.get_dummy_inputs(with_generator=False)

        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
        logger = logging.get_logger("diffusers.loaders.lora.lora_pipeline")
        logger.setLevel(logging.INFO)

        original_output = pipe(**inputs, generator=torch.manual_seed(0))[0]

@@ -5,7 +5,7 @@ import requests
import torch
from huggingface_hub import hf_hub_download, snapshot_download

from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.testing_utils import (
    numpy_cosine_similarity_distance,

@@ -18,9 +18,7 @@ import unittest

import torch

from diffusers import (
    AutoencoderKL,
)
from diffusers import AutoencoderKL
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,

@@ -16,9 +16,7 @@
import gc
import unittest

from diffusers import (
    AutoencoderKLWan,
)
from diffusers import AutoencoderKLWan
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,

@@ -18,9 +18,7 @@ import unittest

import torch

from diffusers import (
    WanTransformer3DModel,
)
from diffusers import WanTransformer3DModel
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,

@@ -3,9 +3,7 @@ import unittest

import torch

from diffusers import (
    SanaTransformer2DModel,
)
from diffusers import SanaTransformer2DModel
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,

@@ -5,7 +5,7 @@ import unittest
import torch

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    backend_empty_cache,

@@ -5,7 +5,7 @@ import unittest
import torch

from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    backend_empty_cache,

@@ -5,7 +5,7 @@ import unittest
import torch

from diffusers import ControlNetModel, StableDiffusionControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    backend_empty_cache,

@@ -5,7 +5,7 @@ import unittest
import torch

from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    backend_empty_cache,

@@ -8,7 +8,7 @@ from diffusers import (
    StableDiffusionXLAdapterPipeline,
    T2IAdapter,
)
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    backend_empty_cache,

@@ -5,7 +5,7 @@ import unittest
import torch

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline
from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.loaders.single_file.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.utils import load_image
from diffusers.utils.testing_utils import (
    backend_empty_cache,

@@ -98,7 +98,7 @@ if __name__ == "__main__":
        },
        "LoRA Mixins": {
            "doc_path": "docs/source/en/api/loaders/lora.md",
            "src_path": "src/diffusers/loaders/lora_pipeline.py",
            "src_path": "src/diffusers/loaders/lora/lora_pipeline.py",
            "doc_regex": r"\[\[autodoc\]\]\s([^\n]+)",
            "src_regex": r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]",
        },
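For context on what this check enforces, here is a quick, self-contained sanity check of the `src_regex` above against a representative class line (the sample line is illustrative, not copied from the repository):

```python
import re

src_regex = r"class\s+(\w+LoraLoaderMixin(?:\d*_?\d*))[:(]"
sample = "class StableDiffusionLoraLoaderMixin(LoraBaseMixin):"
print(re.findall(src_regex, sample))  # prints: ['StableDiffusionLoraLoaderMixin']
```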