Compare commits

...

18 Commits

Author SHA1 Message Date
Patrick von Platen
f5942649f5 Release: v.0.21.4-patch 2023-09-29 15:39:55 +00:00
Patrick von Platen
edea57749e [Lora] fix lora fuse unfuse (#5003)
* fix lora fuse unfuse

* add same changes to loaders.py

* add test

---------

Co-authored-by: multimodalart <joaopaulo.passos+multimodal@gmail.com>
2023-09-29 15:30:16 +00:00
Sayak Paul
c37c840b1b update 2023-09-27 22:13:19 +05:30
Sayak Paul
9858053bfe Release: v0.21.3 2023-09-27 22:10:50 +05:30
Patrick von Platen
6a3301fe34 resolve conflicts. 2023-09-27 22:09:04 +05:30
Patrick von Platen
813a1b2ee0 Fix one more 2023-09-19 00:09:30 +02:00
Patrick von Platen
a43b8574a9 Patch release: v0.21.2 2023-09-19 00:06:16 +02:00
Sayak Paul
a2f0db52e3 [LoRA] don't break offloading for incompatible lora ckpts. (#5085)
* don't break offloading for incompatible lora ckpts.

* debugging

* better condition.

* fix

* fix

* fix

* fix

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-09-19 00:05:34 +02:00
Will Berman
92f6693b37 remove unused adapter weights in constructor (#5088)
remove adapter weights in MultiAdapter constructor
2023-09-19 00:05:28 +02:00
Will Berman
932897afa8 t2i Adapter community member fix (#5090)
* convert tensorrt controlnet

* Fix code quality

* Fix code quality

* Fix code quality

* Fix code quality

* Fix code quality

* Fix code quality

* Fix number controlnet condition

* Add convert SD XL to onnx

* Add convert SD XL to tensorrt

* Add convert SD XL to tensorrt

* Add examples in comments

* Add examples in comments

* Add test onnx controlnet

* Add tensorrt test

* Remove copied

* Move file test to examples/community

* Remove script

* Remove script

* Remove text

* Fix import

* Fix T2I MultiAdapter

* fix tests

---------

Co-authored-by: dotieuthien <thien.do@mservice.com.vn>
Co-authored-by: dotieuthien <dotieuthien9997@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: dotieuthien <hades@cinnamon.is>
2023-09-19 00:05:20 +02:00
Patrick von Platen
c2940434d0 [Textual inversion] Refactor textual inversion to make it cleaner (#5076)
* [Textual inversion] Clean loading

* [Textual inversion] Clean loading

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up
2023-09-19 00:04:53 +02:00
Patrick von Platen
60ab8fad16 Patch release: v0.21.1 2023-09-14 13:06:57 +02:00
Patrick von Platen
d17240457f [Import] Add missing settings / Correct some dummy imports (#5036)
* [Import] Add missing settings

* up

* up

* up
2023-09-14 12:47:55 +02:00
Vladimir Mandic
7512fc4df5 allow loading of sd models from safetensors without online lookups using local config files (#5019)
finish config_files implementation
2023-09-14 12:47:41 +02:00
Patrick von Platen
0c2f1ccc97 [Import] Don't force transformers to be installed (#5035)
* [Import] Don't force transformers to be installed

* make style
2023-09-14 12:47:34 +02:00
Dhruv Nair
47f2d2c7be Fix model offload bug when key isn't present (#5030)
* fix model offload bug when key isn't present

* make style
2023-09-14 12:47:25 +02:00
Patrick von Platen
af85591593 Patch release: v0.21.1 2023-09-14 12:46:39 +02:00
Patrick von Platen
29f15673ed Release: v0.21.0 2023-09-13 15:58:24 +02:00
42 changed files with 618 additions and 299 deletions

View File

@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)
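
An aside on the pattern repeated in the example-script hunks below: every training script pins the minimum diffusers version at import time. A minimal sketch of what the guard does (the version string mirrors this release):

```python
# check_min_version raises ImportError when the installed diffusers is older than
# the requested version, so the example scripts fail fast instead of mid-training.
from diffusers.utils import check_min_version

check_min_version("0.21.0")  # passes on diffusers >= 0.21.0, raises ImportError otherwise
```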

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)

View File

@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)

View File

@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

View File

@@ -154,6 +154,7 @@ if __name__ == "__main__":
pipe = download_from_original_stable_diffusion_ckpt(
checkpoint_path_or_dict=args.checkpoint_path,
original_config_file=args.original_config_file,
+config_files=args.config_files,
image_size=args.image_size,
prediction_type=args.prediction_type,
model_type=args.pipeline_type,

View File

@@ -244,7 +244,7 @@ install_requires = [
setup(
name="diffusers",
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.21.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
__version__ = "0.21.0.dev0"
__version__ = "0.21.4"
from typing import TYPE_CHECKING

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import os
import re
-import warnings
from collections import defaultdict
from contextlib import nullcontext
from io import BytesIO
@@ -41,7 +40,7 @@ from .utils.import_utils import BACKENDS_MAPPING
if is_transformers_available():
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -120,7 +119,7 @@ class PatchedLoraProjection(nn.Module):
self.lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.regular_linear_layer.weight.data
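
The guard change above (repeated for LoRACompatibleConv/Linear further down) matters when w_up/w_down exist on the module but hold None, e.g. after a previous unfuse; a bare hasattr() check stays True in that case. A toy illustration:

```python
# After a fuse/unfuse round trip the attributes can be present but hold None, so
# only the getattr(...) is not None form correctly skips a second unfuse.
class Layer:
    pass

layer = Layer()
layer.w_up = None  # attribute present, but no fused weights stored

print(hasattr(layer, "w_up"))                    # True  -> old check would proceed
print(getattr(layer, "w_up", None) is not None)  # False -> new check returns early
```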
@@ -307,6 +306,9 @@ class UNet2DConditionLoadersMixin:
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
+_pipeline = kwargs.pop("_pipeline", None)
is_network_alphas_none = network_alphas is None
allow_pickle = False
@@ -460,6 +462,7 @@ class UNet2DConditionLoadersMixin:
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ class UNet2DConditionLoadersMixin:
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
-self.set_attn_processor(attn_processors)
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# only custom diffusion needs to set attn processors
if is_custom_diffusion:
self.set_attn_processor(attn_processors)
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
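
A condensed sketch of the hook-handling block added above (the same <Unsafe code> pattern recurs in the text-encoder and textual-inversion loaders below). The function name is mine; the accelerate imports are the ones the real code uses:

```python
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

def detach_offload_hooks(pipeline):
    # CpuOffload marks model cpu offload, AlignDevicesHook sequential cpu offload;
    # hooks must come off before swapping attention processors / LoRA layers in place.
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipeline.components.values():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload = isinstance(component._hf_hook, AlignDevicesHook)
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    # The caller re-enables the matching offload mode once loading is done.
    return is_model_cpu_offload, is_sequential_cpu_offload
```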
@@ -622,12 +650,87 @@ class UNet2DConditionLoadersMixin:
module._unfuse_lora()
+def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
+cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+force_download = kwargs.pop("force_download", False)
+resume_download = kwargs.pop("resume_download", False)
+proxies = kwargs.pop("proxies", None)
+local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+use_auth_token = kwargs.pop("use_auth_token", None)
+revision = kwargs.pop("revision", None)
+subfolder = kwargs.pop("subfolder", None)
+weight_name = kwargs.pop("weight_name", None)
+use_safetensors = kwargs.pop("use_safetensors", None)
+allow_pickle = False
+if use_safetensors is None:
+use_safetensors = True
+allow_pickle = True
+user_agent = {
+"file_type": "text_inversion",
+"framework": "pytorch",
+}
+state_dicts = []
+for pretrained_model_name_or_path in pretrained_model_name_or_paths:
+if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
+# 3.1. Load textual inversion file
+model_file = None
+# Let's first try to load .safetensors weights
+if (use_safetensors and weight_name is None) or (
+weight_name is not None and weight_name.endswith(".safetensors")
+):
+try:
+model_file = _get_model_file(
+pretrained_model_name_or_path,
+weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+cache_dir=cache_dir,
+force_download=force_download,
+resume_download=resume_download,
+proxies=proxies,
+local_files_only=local_files_only,
+use_auth_token=use_auth_token,
+revision=revision,
+subfolder=subfolder,
+user_agent=user_agent,
+)
+state_dict = safetensors.torch.load_file(model_file, device="cpu")
+except Exception as e:
+if not allow_pickle:
+raise e
+model_file = None
+if model_file is None:
+model_file = _get_model_file(
+pretrained_model_name_or_path,
+weights_name=weight_name or TEXT_INVERSION_NAME,
+cache_dir=cache_dir,
+force_download=force_download,
+resume_download=resume_download,
+proxies=proxies,
+local_files_only=local_files_only,
+use_auth_token=use_auth_token,
+revision=revision,
+subfolder=subfolder,
+user_agent=user_agent,
+)
+state_dict = torch.load(model_file, map_location="cpu")
+else:
+state_dict = pretrained_model_name_or_path
+state_dicts.append(state_dict)
+return state_dicts
class TextualInversionLoaderMixin:
r"""
Load textual inversion tokens and embeddings to the tokenizer and text encoder.
"""
-def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
+def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
@@ -654,7 +757,7 @@ class TextualInversionLoaderMixin:
return prompts
-def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
+def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
@@ -684,12 +787,103 @@ class TextualInversionLoaderMixin:
return prompt
+def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
+if tokenizer is None:
+raise ValueError(
+f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
+f" `{self.load_textual_inversion.__name__}`"
+)
+if text_encoder is None:
+raise ValueError(
+f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
+f" `{self.load_textual_inversion.__name__}`"
+)
+if len(pretrained_model_name_or_paths) != len(tokens):
+raise ValueError(
+f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
+f"Make sure both lists have the same length."
+)
+valid_tokens = [t for t in tokens if t is not None]
+if len(set(valid_tokens)) < len(valid_tokens):
+raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+@staticmethod
+def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
+all_tokens = []
+all_embeddings = []
+for state_dict, token in zip(state_dicts, tokens):
+if isinstance(state_dict, torch.Tensor):
+if token is None:
+raise ValueError(
+"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+)
+loaded_token = token
+embedding = state_dict
+elif len(state_dict) == 1:
+# diffusers
+loaded_token, embedding = next(iter(state_dict.items()))
+elif "string_to_param" in state_dict:
+# A1111
+loaded_token = state_dict["name"]
+embedding = state_dict["string_to_param"]["*"]
+else:
+raise ValueError(
+f"Loaded state dictonary is incorrect: {state_dict}. \n\n"
+"Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
+" input key."
+)
+if token is not None and loaded_token != token:
+logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+else:
+token = loaded_token
+if token in tokenizer.get_vocab():
+raise ValueError(
+f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+)
+all_tokens.append(token)
+all_embeddings.append(embedding)
+return all_tokens, all_embeddings
+@staticmethod
+def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
+all_tokens = []
+all_embeddings = []
+for embedding, token in zip(embeddings, tokens):
+if f"{token}_1" in tokenizer.get_vocab():
+multi_vector_tokens = [token]
+i = 1
+while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+multi_vector_tokens.append(f"{token}_{i}")
+i += 1
+raise ValueError(
+f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+)
+is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+if is_multi_vector:
+all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+all_embeddings += [e for e in embedding] # noqa: C416
+else:
+all_tokens += [token]
+all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+return all_tokens, all_embeddings
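
A toy run of the multi-vector expansion above: an embedding of shape (3, 768) is split into per-vector tokens "<tok>", "<tok>_1", "<tok>_2" before being added to the tokenizer.

```python
import torch

embedding = torch.randn(3, 768)  # hypothetical multi-vector textual inversion
token = "<style>"

tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
embeddings = [e for e in embedding]  # one 768-dim vector per token
print(tokens)  # ['<style>', '<style>_1', '<style>_2']
```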
def load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None,
-tokenizer: Optional[PreTrainedTokenizer] = None,
-text_encoder: Optional[PreTrainedModel] = None,
+tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
+text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
**kwargs,
):
r"""
@@ -789,25 +983,44 @@ class TextualInversionLoaderMixin:
```
"""
+# 1. Set correct tokenizer and text encoder
tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
-if tokenizer is None:
+# 2. Normalize inputs
+pretrained_model_name_or_paths = (
+[pretrained_model_name_or_path]
+if not isinstance(pretrained_model_name_or_path, list)
+else pretrained_model_name_or_path
+)
+tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
+# 3. Check inputs
+self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
+# 4. Load state dicts of textual embeddings
+state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
+# 4. Retrieve tokens and embeddings
+tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
+# 5. Extend tokens and embeddings for multi vector
+tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
+# 6. Make sure all embeddings have the correct size
+expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
+if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
raise ValueError(
-f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
-f" `{self.load_textual_inversion.__name__}`"
+"Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
+"to be of shape {input_embeddings.shape[-1]}, but are {embeddings.shape[-1]} "
)
-if text_encoder is None:
-raise ValueError(
-f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
-f" `{self.load_textual_inversion.__name__}`"
-)
+# 7. Now we can be sure that loading the embedding matrix works
+# < Unsafe code:
-# Remove any existing hooks.
+# 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
is_model_cpu_offload = False
is_sequential_cpu_offload = False
-recursive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
@@ -816,168 +1029,34 @@ class TextualInversionLoaderMixin:
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
-recursive = is_sequential_cpu_offload
-remove_hook_from_module(component, recurse=recursive)
+remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-force_download = kwargs.pop("force_download", False)
-resume_download = kwargs.pop("resume_download", False)
-proxies = kwargs.pop("proxies", None)
-local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-use_auth_token = kwargs.pop("use_auth_token", None)
-revision = kwargs.pop("revision", None)
-subfolder = kwargs.pop("subfolder", None)
-weight_name = kwargs.pop("weight_name", None)
-use_safetensors = kwargs.pop("use_safetensors", None)
+# 7.2 save expected device and dtype
+device = text_encoder.device
+dtype = text_encoder.dtype
-allow_pickle = False
-if use_safetensors is None:
-use_safetensors = True
-allow_pickle = True
-user_agent = {
-"file_type": "text_inversion",
-"framework": "pytorch",
-}
-if not isinstance(pretrained_model_name_or_path, list):
-pretrained_model_name_or_paths = [pretrained_model_name_or_path]
-else:
-pretrained_model_name_or_paths = pretrained_model_name_or_path
-if isinstance(token, str):
-tokens = [token]
-elif token is None:
-tokens = [None] * len(pretrained_model_name_or_paths)
-else:
-tokens = token
-if len(pretrained_model_name_or_paths) != len(tokens):
-raise ValueError(
-f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
-f"Make sure both lists have the same length."
-)
-valid_tokens = [t for t in tokens if t is not None]
-if len(set(valid_tokens)) < len(valid_tokens):
-raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
-token_ids_and_embeddings = []
-for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
-# 1. Load textual inversion file
-model_file = None
-# Let's first try to load .safetensors weights
-if (use_safetensors and weight_name is None) or (
-weight_name is not None and weight_name.endswith(".safetensors")
-):
-try:
-model_file = _get_model_file(
-pretrained_model_name_or_path,
-weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
-cache_dir=cache_dir,
-force_download=force_download,
-resume_download=resume_download,
-proxies=proxies,
-local_files_only=local_files_only,
-use_auth_token=use_auth_token,
-revision=revision,
-subfolder=subfolder,
-user_agent=user_agent,
-)
-state_dict = safetensors.torch.load_file(model_file, device="cpu")
-except Exception as e:
-if not allow_pickle:
-raise e
-model_file = None
-if model_file is None:
-model_file = _get_model_file(
-pretrained_model_name_or_path,
-weights_name=weight_name or TEXT_INVERSION_NAME,
-cache_dir=cache_dir,
-force_download=force_download,
-resume_download=resume_download,
-proxies=proxies,
-local_files_only=local_files_only,
-use_auth_token=use_auth_token,
-revision=revision,
-subfolder=subfolder,
-user_agent=user_agent,
-)
-state_dict = torch.load(model_file, map_location="cpu")
-else:
-state_dict = pretrained_model_name_or_path
-# 2. Load token and embedding correcly from file
-loaded_token = None
-if isinstance(state_dict, torch.Tensor):
-if token is None:
-raise ValueError(
-"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
-)
-embedding = state_dict
-elif len(state_dict) == 1:
-# diffusers
-loaded_token, embedding = next(iter(state_dict.items()))
-elif "string_to_param" in state_dict:
-# A1111
-loaded_token = state_dict["name"]
-embedding = state_dict["string_to_param"]["*"]
-if token is not None and loaded_token != token:
-logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-else:
-token = loaded_token
-embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
-# 3. Make sure we don't mess up the tokenizer or text encoder
-vocab = tokenizer.get_vocab()
-if token in vocab:
-raise ValueError(
-f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-)
-elif f"{token}_1" in vocab:
-multi_vector_tokens = [token]
-i = 1
-while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-multi_vector_tokens.append(f"{token}_{i}")
-i += 1
-raise ValueError(
-f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-)
-is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
-if is_multi_vector:
-tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-embeddings = [e for e in embedding] # noqa: C416
-else:
-tokens = [token]
-embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+# 7.3 Increase token embedding matrix
+text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
+input_embeddings = text_encoder.get_input_embeddings().weight
+# 7.4 Load token and embedding
+for token, embedding in zip(tokens, embeddings):
# add tokens and get ids
-tokenizer.add_tokens(tokens)
-token_ids = tokenizer.convert_tokens_to_ids(tokens)
-token_ids_and_embeddings += zip(token_ids, embeddings)
+tokenizer.add_tokens(token)
+token_id = tokenizer.convert_tokens_to_ids(token)
+input_embeddings.data[token_id] = embedding
logger.info(f"Loaded textual inversion embedding for {token}.")
-# resize token embeddings and set all new embeddings
-text_encoder.resize_token_embeddings(len(tokenizer))
-for token_id, embedding in token_ids_and_embeddings:
-text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+input_embeddings.to(dtype=dtype, device=device)
-# offload back
+# 7.5 Offload the model again
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
+# / Unsafe Code >
class LoraLoaderMixin:
r"""
@@ -1009,26 +1088,21 @@ class LoraLoaderMixin:
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
-# Remove any existing hooks.
-is_model_cpu_offload = False
-is_sequential_cpu_offload = False
-recurive = False
-for _, component in self.components.items():
-if isinstance(component, nn.Module):
-if hasattr(component, "_hf_hook"):
-is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-logger.info(
-"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-)
-recurive = is_sequential_cpu_offload
-remove_hook_from_module(component, recurse=recurive)
+# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+is_correct_format = all("lora" in key for key in state_dict.keys())
+if not is_correct_format:
+raise ValueError("Invalid LoRA checkpoint.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(
-state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+state_dict,
+network_alphas=network_alphas,
+unet=self.unet,
+low_cpu_mem_usage=low_cpu_mem_usage,
+_pipeline=self,
)
self.load_lora_into_text_encoder(
state_dict,
@@ -1036,14 +1110,9 @@ class LoraLoaderMixin:
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
+_pipeline=self,
)
-# Offload back.
-if is_model_cpu_offload:
-self.enable_model_cpu_offload()
-elif is_sequential_cpu_offload:
-self.enable_sequential_cpu_offload()
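
The _pipeline plumbing above is what lets an offloaded pipeline survive LoRA loading: the format check now runs before any hooks are touched, and the hooks are removed and re-applied inside the per-model loaders. A sketch of the scenario #5085 targets; the LoRA repo id is borrowed from the integration test at the end of this diff:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
# Hooks are detached, LoRA layers are loaded, and model cpu offload is re-enabled.
pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
```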
@classmethod
def lora_state_dict(
cls,
@@ -1340,7 +1409,7 @@ class LoraLoaderMixin:
return new_state_dict
@classmethod
-def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1382,13 +1451,22 @@ class LoraLoaderMixin:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-warnings.warn(warn_message)
+logger.warn(warn_message)
-unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+unet.load_attn_procs(
+state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+)
@classmethod
def load_lora_into_text_encoder(
-cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+cls,
+state_dict,
+network_alphas,
+text_encoder,
+prefix=None,
+lora_scale=1.0,
+low_cpu_mem_usage=None,
+_pipeline=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1498,11 +1576,15 @@ class LoraLoaderMixin:
low_cpu_mem_usage=low_cpu_mem_usage,
)
# set correct dtype & device
text_encoder_lora_state_dict = {
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
for k, v in text_encoder_lora_state_dict.items()
}
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1518,8 +1600,33 @@ class LoraLoaderMixin:
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2098,6 +2205,7 @@ class FromSingleFileMixin:
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
original_config_file = kwargs.pop("original_config_file", None)
+config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)
@@ -2215,6 +2323,7 @@ class FromSingleFileMixin:
vae=vae,
tokenizer=tokenizer,
original_config_file=original_config_file,
+config_files=config_files,
)
if torch_dtype is not None:
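
Together with the convert_from_ckpt changes further down, this kwarg is what enables fully offline single-file loading. A sketch, with hypothetical paths:

```python
from diffusers import StableDiffusionPipeline

# With a local checkpoint plus a local config YAML, no online config lookup is needed.
pipe = StableDiffusionPipeline.from_single_file(
    "/models/v1-5-pruned-emaonly.safetensors",       # hypothetical local checkpoint
    config_files={"v1": "/configs/v1-inference.yaml"},  # hypothetical local config
    local_files_only=True,
)
```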
@@ -2556,3 +2665,132 @@ class FromOriginalControlnetMixin:
controlnet.to(torch_dtype=torch_dtype)
return controlnet
+class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
+"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
+# Overrride to properly handle the loading and unloading of the additional text encoder.
+def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+"""
+Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+`self.text_encoder`.
+All kwargs are forwarded to `self.lora_state_dict`.
+See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
+`self.unet`.
+See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
+into `self.text_encoder`.
+Parameters:
+pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+kwargs (`dict`, *optional*):
+See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+"""
+# We could have accessed the unet config from `lora_state_dict()` too. We pass
+# it here explicitly to be able to tell that it's coming from an SDXL
+# pipeline.
+# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+state_dict, network_alphas = self.lora_state_dict(
+pretrained_model_name_or_path_or_dict,
+unet_config=self.unet.config,
+**kwargs,
+)
+is_correct_format = all("lora" in key for key in state_dict.keys())
+if not is_correct_format:
+raise ValueError("Invalid LoRA checkpoint.")
+self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
+text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+if len(text_encoder_state_dict) > 0:
+self.load_lora_into_text_encoder(
+text_encoder_state_dict,
+network_alphas=network_alphas,
+text_encoder=self.text_encoder,
+prefix="text_encoder",
+lora_scale=self.lora_scale,
+_pipeline=self,
+)
+text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+if len(text_encoder_2_state_dict) > 0:
+self.load_lora_into_text_encoder(
+text_encoder_2_state_dict,
+network_alphas=network_alphas,
+text_encoder=self.text_encoder_2,
+prefix="text_encoder_2",
+lora_scale=self.lora_scale,
+_pipeline=self,
+)
+@classmethod
+def save_lora_weights(
+self,
+save_directory: Union[str, os.PathLike],
+unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+is_main_process: bool = True,
+weight_name: str = None,
+save_function: Callable = None,
+safe_serialization: bool = True,
+):
+r"""
+Save the LoRA parameters corresponding to the UNet and text encoder.
+Arguments:
+save_directory (`str` or `os.PathLike`):
+Directory to save LoRA parameters to. Will be created if it doesn't exist.
+unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+State dict of the LoRA layers corresponding to the `unet`.
+text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+encoder LoRA state dict because it comes from 🤗 Transformers.
+is_main_process (`bool`, *optional*, defaults to `True`):
+Whether the process calling this is the main process or not. Useful during distributed training and you
+need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+process to avoid race conditions.
+save_function (`Callable`):
+The function to use to save the state dictionary. Useful during distributed training when you need to
+replace `torch.save` with another method. Can be configured with the environment variable
+`DIFFUSERS_SAVE_MODE`.
+safe_serialization (`bool`, *optional*, defaults to `True`):
+Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+"""
+state_dict = {}
+def pack_weights(layers, prefix):
+layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+return layers_state_dict
+if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+raise ValueError(
+"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+)
+if unet_lora_layers:
+state_dict.update(pack_weights(unet_lora_layers, "unet"))
+if text_encoder_lora_layers and text_encoder_2_lora_layers:
+state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+self.write_lora_layers(
+state_dict=state_dict,
+save_directory=save_directory,
+is_main_process=is_main_process,
+weight_name=weight_name,
+save_function=save_function,
+safe_serialization=safe_serialization,
+)
+def _remove_text_encoder_monkey_patch(self):
+self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
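
The new SDXL mixin is what the fuse/unfuse fix (#5003) exercises. A sketch of the round trip it has to keep lossless; repo ids are borrowed from the integration test below:

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
pipe.load_lora_weights("Pclanglais/TintinIA")
pipe.fuse_lora()    # merge LoRA deltas into the base weights
pipe.unfuse_lora()  # must restore the base weights exactly, or later LoRAs drift
```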

View File

@@ -90,6 +90,8 @@ class MultiAdapter(ModelMixin):
features = adapter(x)
if accume_state is None:
accume_state = features
+for i in range(len(accume_state)):
+accume_state[i] = w * accume_state[i]
else:
for i in range(len(features)):
accume_state[i] += w * features[i]
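
A toy version of the accumulation fixed above, with a single tensor standing in for each adapter's feature list: before the patch the first adapter's features skipped their weight w, so the result was not a true weighted sum.

```python
import torch

features_per_adapter = [torch.ones(2, 2), 3 * torch.ones(2, 2)]
weights = [0.25, 0.75]

accume_state = None
for w, features in zip(weights, features_per_adapter):
    if accume_state is None:
        accume_state = w * features  # the fix: scale the first adapter's features too
    else:
        accume_state = accume_state + w * features

print(accume_state)  # every entry is 0.25 * 1 + 0.75 * 3 = 2.5
```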

View File

@@ -304,19 +304,16 @@ class Attention(nn.Module):
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor"):
if (
hasattr(self, "processor")
and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
and self.to_q.lora_layer is not None
):
def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
"In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
)
# TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
# We need to remove all LoRA layers
# Don't forget to remove ALL `_remove_lora` from the codebase
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
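
The deprecation above steers callers toward the supported way of dropping LoRA layers. A sketch (the LoRA repo id is illustrative):

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("sayakpaul/sd-model-finetuned-lora-t4")  # illustrative repo id
pipe.unload_lora_weights()  # preferred over resetting attention processors by hand
```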

View File

@@ -196,7 +196,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
r"""
Sets the attention processor to use to compute attention.
@@ -220,9 +222,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -244,7 +246,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
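
The same _remove_lora plumbing is copied into every model below (ControlNet, PriorTransformer, the UNets, AudioLDM2, UNetFlat). A sketch of the distinction it draws, using the UNet:

```python
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
)
unet.set_attn_processor(AttnProcessor())  # public call: LoRA layers are left in place
unet.set_default_attn_processor()         # internal path now passes _remove_lora=True
```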

View File

@@ -517,7 +517,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
r"""
Sets the attention processor to use to compute attention.
@@ -541,9 +543,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -565,7 +567,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):

View File

@@ -139,7 +139,7 @@ class LoRACompatibleConv(nn.Conv2d):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data
@@ -204,7 +204,7 @@ class LoRACompatibleLinear(nn.Linear):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data

View File

@@ -191,7 +191,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
r"""
Sets the attention processor to use to compute attention.
@@ -215,9 +217,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -239,7 +241,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
def forward(
self,

View File

@@ -613,7 +613,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
r"""
Sets the attention processor to use to compute attention.
@@ -637,9 +639,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -660,7 +662,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
def set_attention_slice(self, slice_size):
r"""

View File

@@ -366,7 +366,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
r"""
Sets the attention processor to use to compute attention.
@@ -390,9 +392,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -454,7 +456,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):

View File

@@ -538,7 +538,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
r"""
Sets the attention processor to use to compute attention.
@@ -562,9 +564,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -586,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):

View File

@@ -1255,7 +1255,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
self._all_hooks = []
hook = None
for model_str in self.model_cpu_offload_seq.split("->"):
-model = all_model_components.pop(model_str)
+model = all_model_components.pop(model_str, None)
if not isinstance(model, torch.nn.Module):
continue
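
The one-line fix above makes enable_model_cpu_offload tolerate pipelines whose offload sequence names a component they do not have (e.g. no safety checker):

```python
# dict.pop with a default returns None instead of raising KeyError,
# letting the offload loop `continue` past missing components.
components = {"unet": object(), "vae": object()}
print(components.pop("safety_checker", None))  # None, no KeyError
```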

View File

@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+config_url = None
# model_type = "v1"
-config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+if config_files is not None and "v1" in config_files:
+original_config_file = config_files["v1"]
+else:
+config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
# model_type = "v2"
-config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+if config_files is not None and "v2" in config_files:
+original_config_file = config_files["v2"]
+else:
+config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
if global_step == 110000:
# v2.1 needs to upcast attention
upcast_attention = True
elif key_name_sd_xl_base in checkpoint:
# only base xl has two text embedders
-config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+if config_files is not None and "xl" in config_files:
+original_config_file = config_files["xl"]
+else:
+config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
elif key_name_sd_xl_refiner in checkpoint:
# only refiner xl has embedder and one text embedders
-config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
-original_config_file = BytesIO(requests.get(config_url).content)
+if config_files is not None and "xl_refiner" in config_files:
+original_config_file = config_files["xl_refiner"]
+else:
+config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+if config_url is not None:
+original_config_file = BytesIO(requests.get(config_url).content)
original_config = OmegaConf.load(original_config_file)
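
The config_files dict consulted above is keyed by detected model type; supplying any subset skips the corresponding raw.githubusercontent.com download. Paths here are hypothetical:

```python
config_files = {
    "v1": "/configs/v1-inference.yaml",
    "v2": "/configs/v2-inference-v.yaml",
    "xl": "/configs/sd_xl_base.yaml",
    "xl_refiner": "/configs/sd_xl_refiner.yaml",
}
```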

View File

@@ -50,13 +50,26 @@ class SafetyConfig(object):
_dummy_objects = {}
_additional_imports = {}
-_import_structure = {
-"pipeline_output": ["StableDiffusionSafePipelineOutput"],
-"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
-"safety_checker": ["StableDiffusionSafetyChecker"],
-}
+_import_structure = {}
_additional_imports.update({"SafetyConfig": SafetyConfig})
+try:
+if not (is_transformers_available() and is_torch_available()):
+raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+from ...utils import dummy_torch_and_transformers_objects
+_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+_import_structure.update(
+{
+"pipeline_output": ["StableDiffusionSafePipelineOutput"],
+"pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
+"safety_checker": ["StableDiffusionSafetyChecker"],
+}
+)
if TYPE_CHECKING:
try:
@@ -70,25 +83,16 @@ if TYPE_CHECKING:
from .safety_checker import SafeStableDiffusionSafetyChecker
else:
-try:
-if not (is_transformers_available() and is_torch_available()):
-raise OptionalDependencyNotAvailable()
-except OptionalDependencyNotAvailable:
-from ...utils import dummy_torch_and_transformers_objects
+import sys
-_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
-sys.modules[__name__] = _LazyModule(
-__name__,
-globals()["__file__"],
-_import_structure,
-module_spec=__spec__,
-)
-else:
-import sys
+sys.modules[__name__] = _LazyModule(
+__name__,
+globals()["__file__"],
+_import_structure,
+module_spec=__spec__,
+)
-for name, value in _dummy_objects.items():
-setattr(sys.modules[__name__], name, value)
-for name, value in _additional_imports.items():
-setattr(sys.modules[__name__], name, value)
+for name, value in _dummy_objects.items():
+setattr(sys.modules[__name__], name, value)
+for name, value in _additional_imports.items():
+setattr(sys.modules[__name__], name, value)

View File

@@ -162,7 +162,6 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
-adapter_weights: Optional[List[float]] = None,
requires_safety_checker: bool = True,
):
super().__init__()
@@ -184,7 +183,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
)
if isinstance(adapter, (list, tuple)):
-adapter = MultiAdapter(adapter, adapter_weights=adapter_weights)
+adapter = MultiAdapter(adapter)
self.register_modules(
vae=vae,
@@ -727,9 +726,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7. Denoising loop
-adapter_state = self.adapter(adapter_input)
-for k, v in enumerate(adapter_state):
-adapter_state[k] = v * adapter_conditioning_scale
+if isinstance(self.adapter, MultiAdapter):
+adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
+for k, v in enumerate(adapter_state):
+adapter_state[k] = v
+else:
+adapter_state = self.adapter(adapter_input)
+for k, v in enumerate(adapter_state):
+adapter_state[k] = v * adapter_conditioning_scale
if num_images_per_prompt > 1:
for k, v in enumerate(adapter_state):
adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
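
With the branch above, adapter_conditioning_scale may now be a per-adapter list when a MultiAdapter is used. A sketch; the adapter ids are illustrative of the T2I-Adapter family, and the blank images stand in for real condition maps (each adapter may expect a specific channel count):

```python
from PIL import Image
from diffusers import StableDiffusionAdapterPipeline, T2IAdapter

adapters = [
    T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd15v2"),
    T2IAdapter.from_pretrained("TencentARC/t2iadapter_sketch_sd14v1"),
]
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", adapter=adapters
)

conds = [Image.new("RGB", (512, 512)), Image.new("RGB", (512, 512))]
image = pipe(
    "a photo of a room",
    image=conds,
    adapter_conditioning_scale=[0.8, 0.5],  # one weight per adapter
).images[0]
```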

View File

@@ -47,3 +47,5 @@ else:
_import_structure,
module_spec=__spec__,
)
+for name, value in _dummy_objects.items():
+setattr(sys.modules[__name__], name, value)

View File

@@ -820,7 +820,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
return processors
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
r"""
Sets the attention processor to use to compute attention.
@@ -844,9 +846,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -868,7 +870,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
f" {next(iter(self.attn_processors.values()))}"
)
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
def set_attention_slice(self, slice_size):
r"""

View File

@@ -51,3 +51,6 @@ else:
_import_structure,
module_spec=__spec__,
)
+for name, value in _dummy_objects.items():
+setattr(sys.modules[__name__], name, value)

View File

@@ -41,7 +41,6 @@ if TYPE_CHECKING:
from .pipeline_wuerstchen import WuerstchenDecoderPipeline
from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
else:
import sys
@@ -51,3 +50,6 @@ else:
_import_structure,
module_spec=__spec__,
)
+for name, value in _dummy_objects.items():
+setattr(sys.modules[__name__], name, value)

View File

@@ -43,7 +43,7 @@ from diffusers.models.attention_processor import (
LoRAAttnProcessor2_0,
XFormersAttnProcessor,
)
-from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, slow, torch_device
+from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device
def create_unet_lora_layers(unet: nn.Module):
@@ -1464,3 +1464,41 @@ class LoraIntegrationTests(unittest.TestCase):
expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])
self.assertTrue(np.allclose(images, expected, atol=1e-3))
+@nightly
+def test_sequential_fuse_unfuse(self):
+pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
+# 1. round
+pipe.load_lora_weights("Pclanglais/TintinIA")
+pipe.fuse_lora()
+generator = torch.Generator().manual_seed(0)
+images = pipe(
+"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+).images
+image_slice = images[0, -3:, -3:, -1].flatten()
+pipe.unfuse_lora()
+# 2. round
+pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style")
+pipe.fuse_lora()
+pipe.unfuse_lora()
+# 3. round
+pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
+pipe.fuse_lora()
+pipe.unfuse_lora()
+# 4. back to 1st round
+pipe.load_lora_weights("Pclanglais/TintinIA")
+pipe.fuse_lora()
+generator = torch.Generator().manual_seed(0)
+images_2 = pipe(
+"masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
+).images
+image_slice_2 = images_2[0, -3:, -3:, -1].flatten()
+self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))

View File

@@ -216,7 +216,9 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
return super().get_dummy_components("multi_adapter")
def get_dummy_inputs(self, device, seed=0):
-return super().get_dummy_inputs(device, seed, num_images=2)
+inputs = super().get_dummy_inputs(device, seed, num_images=2)
+inputs["adapter_conditioning_scale"] = [0.5, 0.5]
+return inputs
def test_stable_diffusion_adapter_default_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator