Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-06 20:44:33 +08:00

Compare commits: 18 commits (add_svd_in...v0.21.4-pa)
| Author | SHA1 | Date |
|---|---|---|
|  | f5942649f5 |  |
|  | edea57749e |  |
|  | c37c840b1b |  |
|  | 9858053bfe |  |
|  | 6a3301fe34 |  |
|  | 813a1b2ee0 |  |
|  | a43b8574a9 |  |
|  | a2f0db52e3 |  |
|  | 92f6693b37 |  |
|  | 932897afa8 |  |
|  | c2940434d0 |  |
|  | 60ab8fad16 |  |
|  | d17240457f |  |
|  | 7512fc4df5 |  |
|  | 0c2f1ccc97 |  |
|  | 47f2d2c7be |  |
|  | af85591593 |  |
|  | 29f15673ed |  |
@@ -56,7 +56,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -59,7 +59,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -60,7 +60,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
# Cache compiled models across invocations of this script.
cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

@@ -58,7 +58,7 @@ if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -53,7 +53,7 @@ if is_wandb_available():
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)

@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -79,7 +79,7 @@ else:
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__)

@@ -56,7 +56,7 @@ else:
# ------------------------------------------------------------------------------
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = logging.getLogger(__name__)

@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.21.0.dev0")
check_min_version("0.21.0")
logger = get_logger(__name__, log_level="INFO")
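Every example training script carries the same one-line guard, so the hunks above are the same mechanical change repeated per script. A minimal sketch of how that guard behaves (illustrative, not copied from any one script):

```python
# Minimal sketch: the version guard used at the top of the example training scripts.
from diffusers.utils import check_min_version

# Raises an ImportError if the installed diffusers is older than the pinned
# release; this patch branch moves the pin from "0.21.0.dev0" to "0.21.0".
check_min_version("0.21.0")
```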
@@ -154,6 +154,7 @@ if __name__ == "__main__":
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
checkpoint_path_or_dict=args.checkpoint_path,
|
||||
original_config_file=args.original_config_file,
|
||||
config_files=args.config_files,
|
||||
image_size=args.image_size,
|
||||
prediction_type=args.prediction_type,
|
||||
model_type=args.pipeline_type,
|
||||
|
||||
setup.py

@@ -244,7 +244,7 @@ install_requires = [
setup(
name="diffusers",
version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.21.4", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

@@ -1,4 +1,4 @@
__version__ = "0.21.0.dev0"
__version__ = "0.21.4"
from typing import TYPE_CHECKING
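The package metadata and `__version__` move from the development version to the patch release. A quick, hedged way to confirm what is actually installed:

```python
# Quick sanity check after installing the patch release.
import diffusers

print(diffusers.__version__)  # expected to print "0.21.4" for this branch
```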
@@ -13,7 +13,6 @@
# limitations under the License.
import os
import re
import warnings
from collections import defaultdict
from contextlib import nullcontext
from io import BytesIO
@@ -41,7 +40,7 @@ from .utils.import_utils import BACKENDS_MAPPING
if is_transformers_available():
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
from transformers import CLIPTextModel, CLIPTextModelWithProjection
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -120,7 +119,7 @@ class PatchedLoraProjection(nn.Module):
self.lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.regular_linear_layer.weight.data
@@ -307,6 +306,9 @@ class UNet2DConditionLoadersMixin:
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alphas = kwargs.pop("network_alphas", None)
_pipeline = kwargs.pop("_pipeline", None)
is_network_alphas_none = network_alphas is None
allow_pickle = False
@@ -460,6 +462,7 @@ class UNet2DConditionLoadersMixin:
load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
else:
lora.load_state_dict(value_dict)
elif is_custom_diffusion:
attn_processors = {}
custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ class UNet2DConditionLoadersMixin:
cross_attention_dim=cross_attention_dim,
)
attn_processors[key].load_state_dict(value_dict)
self.set_attn_processor(attn_processors)
else:
raise ValueError(
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
)
# <Unsafe code
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
# only custom diffusion needs to set attn processors
if is_custom_diffusion:
self.set_attn_processor(attn_processors)
# set lora layers
for target_module, lora_layer in lora_layers_list:
target_module.set_lora_layer(lora_layer)
self.to(dtype=self.dtype, device=self.device)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
is_new_lora_format = all(
key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
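The block above is the pattern this patch threads through the loaders: detect which accelerate offload hook is attached to each pipeline component, strip it, load the weights, then restore the same offload mode. A standalone sketch of the detection step, assuming a pipeline that exposes a `components` dict as diffusers pipelines do (the helper name is ours, not the library's):

```python
# Illustrative sketch (helper name is hypothetical): inspect accelerate hooks to
# decide which offloading mode to restore after weights have been loaded.
import torch.nn as nn
from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

def detect_and_strip_offload_hooks(pipeline):
    is_model_cpu_offload = False
    is_sequential_cpu_offload = False
    for component in pipeline.components.values():
        if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
            is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
            is_sequential_cpu_offload = is_sequential_cpu_offload or isinstance(component._hf_hook, AlignDevicesHook)
            # Sequential offload attaches hooks to submodules, so remove recursively in that case.
            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
    return is_model_cpu_offload, is_sequential_cpu_offload
```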
@@ -622,12 +650,87 @@ class UNet2DConditionLoadersMixin:
module._unfuse_lora()
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "text_inversion",
"framework": "pytorch",
}
state_dicts = []
for pretrained_model_name_or_path in pretrained_model_name_or_paths:
if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
# 3.1. Load textual inversion file
model_file = None
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except Exception as e:
if not allow_pickle:
raise e
model_file = None
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path
state_dicts.append(state_dict)
return state_dicts
class TextualInversionLoaderMixin:
r"""
Load textual inversion tokens and embeddings to the tokenizer and text encoder.
"""
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
@@ -654,7 +757,7 @@ class TextualInversionLoaderMixin:
return prompts
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
r"""
Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
@@ -684,12 +787,103 @@ class TextualInversionLoaderMixin:
return prompt
def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
if tokenizer is None:
raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
if text_encoder is None:
raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
if len(pretrained_model_name_or_paths) != len(tokens):
raise ValueError(
f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
f"Make sure both lists have the same length."
)
valid_tokens = [t for t in tokens if t is not None]
if len(set(valid_tokens)) < len(valid_tokens):
raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
@staticmethod
def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
all_tokens = []
all_embeddings = []
for state_dict, token in zip(state_dicts, tokens):
if isinstance(state_dict, torch.Tensor):
if token is None:
raise ValueError(
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
)
loaded_token = token
embedding = state_dict
elif len(state_dict) == 1:
# diffusers
loaded_token, embedding = next(iter(state_dict.items()))
elif "string_to_param" in state_dict:
# A1111
loaded_token = state_dict["name"]
embedding = state_dict["string_to_param"]["*"]
else:
raise ValueError(
f"Loaded state dictonary is incorrect: {state_dict}. \n\n"
"Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
" input key."
)
if token is not None and loaded_token != token:
logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
else:
token = loaded_token
if token in tokenizer.get_vocab():
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
)
all_tokens.append(token)
all_embeddings.append(embedding)
return all_tokens, all_embeddings
@staticmethod
def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
all_tokens = []
all_embeddings = []
for embedding, token in zip(embeddings, tokens):
if f"{token}_1" in tokenizer.get_vocab():
multi_vector_tokens = [token]
i = 1
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
multi_vector_tokens.append(f"{token}_{i}")
i += 1
raise ValueError(
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
)
is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
if is_multi_vector:
all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
all_embeddings += [e for e in embedding] # noqa: C416
else:
all_tokens += [token]
all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
return all_tokens, all_embeddings
def load_textual_inversion(
self,
pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
token: Optional[Union[str, List[str]]] = None,
tokenizer: Optional[PreTrainedTokenizer] = None,
text_encoder: Optional[PreTrainedModel] = None,
tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
**kwargs,
):
r"""
@@ -789,25 +983,44 @@ class TextualInversionLoaderMixin:
"""
# 1. Set correct tokenizer and text encoder
tokenizer = tokenizer or getattr(self, "tokenizer", None)
text_encoder = text_encoder or getattr(self, "text_encoder", None)
if tokenizer is None:
# 2. Normalize inputs
pretrained_model_name_or_paths = (
[pretrained_model_name_or_path]
if not isinstance(pretrained_model_name_or_path, list)
else pretrained_model_name_or_path
)
tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
# 3. Check inputs
self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
# 4. Load state dicts of textual embeddings
state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
# 4. Retrieve tokens and embeddings
tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
# 5. Extend tokens and embeddings for multi vector
tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
# 6. Make sure all embeddings have the correct size
expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
raise ValueError(
f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
f" `{self.load_textual_inversion.__name__}`"
"Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
"to be of shape {input_embeddings.shape[-1]}, but are {embeddings.shape[-1]} "
)
if text_encoder is None:
raise ValueError(
f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
f" `{self.load_textual_inversion.__name__}`"
)
# 7. Now we can be sure that loading the embedding matrix works
# < Unsafe code:
# Remove any existing hooks.
# 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recursive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
@@ -816,168 +1029,34 @@ class TextualInversionLoaderMixin:
logger.info(
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
)
recursive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recursive)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
# 7.2 save expected device and dtype
device = text_encoder.device
dtype = text_encoder.dtype
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {
"file_type": "text_inversion",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path, list):
pretrained_model_name_or_paths = [pretrained_model_name_or_path]
else:
pretrained_model_name_or_paths = pretrained_model_name_or_path
if isinstance(token, str):
tokens = [token]
elif token is None:
tokens = [None] * len(pretrained_model_name_or_paths)
else:
tokens = token
if len(pretrained_model_name_or_paths) != len(tokens):
raise ValueError(
f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
f"Make sure both lists have the same length."
)
valid_tokens = [t for t in tokens if t is not None]
if len(set(valid_tokens)) < len(valid_tokens):
raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
token_ids_and_embeddings = []
for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
# 1. Load textual inversion file
model_file = None
# Let's first try to load .safetensors weights
if (use_safetensors and weight_name is None) or (
weight_name is not None and weight_name.endswith(".safetensors")
):
try:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except Exception as e:
if not allow_pickle:
raise e
model_file = None
if model_file is None:
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=weight_name or TEXT_INVERSION_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path
# 2. Load token and embedding correcly from file
loaded_token = None
if isinstance(state_dict, torch.Tensor):
if token is None:
raise ValueError(
"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
)
embedding = state_dict
elif len(state_dict) == 1:
# diffusers
loaded_token, embedding = next(iter(state_dict.items()))
elif "string_to_param" in state_dict:
# A1111
loaded_token = state_dict["name"]
embedding = state_dict["string_to_param"]["*"]
if token is not None and loaded_token != token:
logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
else:
token = loaded_token
embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
# 3. Make sure we don't mess up the tokenizer or text encoder
vocab = tokenizer.get_vocab()
if token in vocab:
raise ValueError(
f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
)
elif f"{token}_1" in vocab:
multi_vector_tokens = [token]
i = 1
while f"{token}_{i}" in tokenizer.added_tokens_encoder:
multi_vector_tokens.append(f"{token}_{i}")
i += 1
raise ValueError(
f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
)
is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
if is_multi_vector:
tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
embeddings = [e for e in embedding] # noqa: C416
else:
tokens = [token]
embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
# 7.3 Increase token embedding matrix
text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
input_embeddings = text_encoder.get_input_embeddings().weight
# 7.4 Load token and embedding
for token, embedding in zip(tokens, embeddings):
# add tokens and get ids
tokenizer.add_tokens(tokens)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
token_ids_and_embeddings += zip(token_ids, embeddings)
tokenizer.add_tokens(token)
token_id = tokenizer.convert_tokens_to_ids(token)
input_embeddings.data[token_id] = embedding
logger.info(f"Loaded textual inversion embedding for {token}.")
# resize token embeddings and set all new embeddings
text_encoder.resize_token_embeddings(len(tokenizer))
for token_id, embedding in token_ids_and_embeddings:
text_encoder.get_input_embeddings().weight.data[token_id] = embedding
input_embeddings.to(dtype=dtype, device=device)
# offload back
# 7.5 Offload the model again
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
# / Unsafe Code >
class LoraLoaderMixin:
r"""
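The refactor above splits `load_textual_inversion` into discrete steps (load the state dicts, retrieve tokens and embeddings, extend multi-vector tokens, resize the embedding matrix, restore offload hooks) without changing the public call. A usage sketch; the model and embedding ids are illustrative, not part of this diff:

```python
# Usage sketch (illustrative ids): the public API exercised by the refactor above.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_textual_inversion("sd-concepts-library/cat-toy")  # adds the <cat-toy> token
image = pipe("A <cat-toy> backpack", num_inference_steps=30).images[0]
image.save("cat_backpack.png")
```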
@@ -1009,26 +1088,21 @@ class LoraLoaderMixin:
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# Remove any existing hooks.
is_model_cpu_offload = False
is_sequential_cpu_offload = False
recurive = False
for _, component in self.components.items():
if isinstance(component, nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
recurive = is_sequential_cpu_offload
remove_hook_from_module(component, recurse=recurive)
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(
state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
state_dict,
network_alphas=network_alphas,
unet=self.unet,
low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self,
)
self.load_lora_into_text_encoder(
state_dict,

@@ -1036,14 +1110,9 @@ class LoraLoaderMixin:
text_encoder=self.text_encoder,
lora_scale=self.lora_scale,
low_cpu_mem_usage=low_cpu_mem_usage,
_pipeline=self,
)
# Offload back.
if is_model_cpu_offload:
self.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
self.enable_sequential_cpu_offload()
@classmethod
def lora_state_dict(
cls,
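With this change the hook bookkeeping moves out of `load_lora_weights` and into `load_lora_into_unet` / `load_lora_into_text_encoder` via the new `_pipeline` argument, so an offloaded pipeline keeps working across a LoRA load. A hedged usage sketch (the LoRA repo id is a placeholder):

```python
# Usage sketch (placeholder LoRA id): offload hooks are removed and re-applied
# internally when LoRA weights are loaded into an offloaded pipeline.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()
pipe.load_lora_weights("your-username/your-lora-checkpoint")  # placeholder id
```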
@@ -1340,7 +1409,7 @@ class LoraLoaderMixin:
return new_state_dict
@classmethod
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
"""
This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1382,13 +1451,22 @@ class LoraLoaderMixin:
# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)
logger.warn(warn_message)
unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
unet.load_attn_procs(
state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
)
@classmethod
def load_lora_into_text_encoder(
cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
cls,
state_dict,
network_alphas,
text_encoder,
prefix=None,
lora_scale=1.0,
low_cpu_mem_usage=None,
_pipeline=None,
):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1498,11 +1576,15 @@ class LoraLoaderMixin:
low_cpu_mem_usage=low_cpu_mem_usage,
)
# set correct dtype & device
text_encoder_lora_state_dict = {
k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
for k, v in text_encoder_lora_state_dict.items()
}
is_pipeline_offloaded = _pipeline is not None and any(
isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
)
if is_pipeline_offloaded and low_cpu_mem_usage:
low_cpu_mem_usage = True
logger.info(
f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
)
if low_cpu_mem_usage:
device = next(iter(text_encoder_lora_state_dict.values())).device
dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1518,8 +1600,33 @@ class LoraLoaderMixin:
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
)
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@property
def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2098,6 +2205,7 @@ class FromSingleFileMixin:
from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
original_config_file = kwargs.pop("original_config_file", None)
config_files = kwargs.pop("config_files", None)
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
resume_download = kwargs.pop("resume_download", False)
force_download = kwargs.pop("force_download", False)

@@ -2215,6 +2323,7 @@ class FromSingleFileMixin:
vae=vae,
tokenizer=tokenizer,
original_config_file=original_config_file,
config_files=config_files,
)
if torch_dtype is not None:
@@ -2556,3 +2665,132 @@ class FromOriginalControlnetMixin:
controlnet.to(torch_dtype=torch_dtype)
return controlnet
class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
# Overrride to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
"""
Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
`self.text_encoder`.
All kwargs are forwarded to `self.lora_state_dict`.
See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
`self.unet`.
See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
into `self.text_encoder`.
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
kwargs (`dict`, *optional*):
See [`~loaders.LoraLoaderMixin.lora_state_dict`].
"""
# We could have accessed the unet config from `lora_state_dict()` too. We pass
# it here explicitly to be able to tell that it's coming from an SDXL
# pipeline.
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
state_dict, network_alphas = self.lora_state_dict(
pretrained_model_name_or_path_or_dict,
unet_config=self.unet.config,
**kwargs,
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
_pipeline=self,
)
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alphas=network_alphas,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
_pipeline=self,
)
@classmethod
def save_lora_weights(
self,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
):
r"""
Save the LoRA parameters corresponding to the UNet and text encoder.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to save LoRA parameters to. Will be created if it doesn't exist.
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `unet`.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
encoder LoRA state dict because it comes from 🤗 Transformers.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful during distributed training and you
need to call this function on all processes. In this case, set `is_main_process=True` only on the main
process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful during distributed training when you need to
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
"""
state_dict = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
raise ValueError(
"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
)
if unet_lora_layers:
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers and text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
self.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
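The new `StableDiffusionXLLoraLoaderMixin` routes LoRA layers to the UNet and to both SDXL text encoders (`text_encoder` and `text_encoder_2`). A usage sketch with a placeholder checkpoint id:

```python
# Usage sketch (placeholder LoRA id): SDXL pipelines now load LoRA into the UNet
# and both text encoders via the mixin added above.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("your-username/your-sdxl-lora")  # placeholder id
```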
@@ -90,6 +90,8 @@ class MultiAdapter(ModelMixin):
features = adapter(x)
if accume_state is None:
accume_state = features
for i in range(len(accume_state)):
accume_state[i] = w * accume_state[i]
else:
for i in range(len(features)):
accume_state[i] += w * features[i]
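The fix above makes the first adapter's features go through the same per-adapter weighting as the others. A minimal standalone sketch of the intended accumulation (plain floats instead of tensors, for illustration):

```python
# Illustrative sketch of the weighted accumulation: every adapter's features,
# including the first one's, are scaled by its weight before being summed.
def accumulate(features_per_adapter, weights):
    accume_state = None
    for feats, w in zip(features_per_adapter, weights):
        if accume_state is None:
            accume_state = [w * f for f in feats]
        else:
            for i, f in enumerate(feats):
                accume_state[i] += w * f
    return accume_state

print(accumulate([[1.0, 2.0], [3.0, 4.0]], [0.5, 0.5]))  # [2.0, 3.0]
```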
@@ -304,19 +304,16 @@ class Attention(nn.Module):
self.set_processor(processor)
def set_processor(self, processor: "AttnProcessor"):
if (
hasattr(self, "processor")
and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
and self.to_q.lora_layer is not None
):
def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
deprecate(
"set_processor to offload LoRA",
"0.26.0",
"In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
)
# TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
# We need to remove all LoRA layers
# Don't forget to remove ALL `_remove_lora` from the codebase
for module in self.modules():
if hasattr(module, "set_lora_layer"):
module.set_lora_layer(None)
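The deprecation added here points users at `unload_lora_weights()` instead of removing LoRA layers by resetting attention processors. A hedged usage sketch (placeholder LoRA id):

```python
# Usage sketch (placeholder LoRA id): the supported way to strip LoRA layers
# after this change is `unload_lora_weights()`, not resetting the processors.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("your-username/your-lora-checkpoint")  # placeholder id
pipe.unload_lora_weights()  # removes the LoRA layers from the UNet and text encoder
```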
@@ -196,7 +196,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.

@@ -220,9 +222,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -244,7 +246,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
@apply_forward_hook
def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
@@ -517,7 +517,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.

@@ -541,9 +543,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -565,7 +567,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
@@ -139,7 +139,7 @@ class LoRACompatibleConv(nn.Conv2d):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data

@@ -204,7 +204,7 @@ class LoRACompatibleLinear(nn.Linear):
self._lora_scale = lora_scale
def _unfuse_lora(self):
if not (hasattr(self, "w_up") and hasattr(self, "w_down")):
if not (getattr(self, "w_up", None) is not None and getattr(self, "w_down", None) is not None):
return
fused_weight = self.weight.data
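The `_unfuse_lora` guard changes from `hasattr` to a `getattr(...) is not None` check: after an earlier fuse/unfuse cycle the `w_up` / `w_down` attributes can exist but hold `None`, which the old check would let through. A tiny self-contained illustration:

```python
# Why the guard changed (illustrative): an attribute can exist yet be None, in
# which case `hasattr` is not enough to tell whether fused weights are present.
class Layer:
    w_up = None  # attribute exists, but no weights are stored

layer = Layer()
print(hasattr(layer, "w_up"))                    # True  -> old guard would proceed
print(getattr(layer, "w_up", None) is not None)  # False -> new guard returns early
```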
@@ -191,7 +191,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.

@@ -215,9 +217,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -239,7 +241,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
def forward(
self,
@@ -613,7 +613,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
return processors
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.

@@ -637,9 +639,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -660,7 +662,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
def set_attention_slice(self, slice_size):
r"""
@@ -366,7 +366,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
fn_recursive_set_attention_slice(module, reversed_slice_size)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.

@@ -390,9 +392,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -454,7 +456,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
@@ -538,7 +538,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
return processors
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
):
r"""
Sets the attention processor to use to compute attention.

@@ -562,9 +564,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
module.set_processor(processor, _remove_lora=_remove_lora)
else:
module.set_processor(processor.pop(f"{name}.processor"))
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

@@ -586,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
self.set_attn_processor(processor, _remove_lora=True)
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
def set_attention_slice(self, slice_size):
@@ -1255,7 +1255,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        self._all_hooks = []
        hook = None
        for model_str in self.model_cpu_offload_seq.split("->"):
            model = all_model_components.pop(model_str)
            model = all_model_components.pop(model_str, None)
            if not isinstance(model, torch.nn.Module):
                continue
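The `.pop(model_str, None)` change lets the offload loop skip components that are named in `model_cpu_offload_seq` but absent from (or set to None in) a particular pipeline, instead of raising a KeyError. The loop is driven by the public `enable_model_cpu_offload` API; a minimal sketch, assuming a CUDA device and accelerate are installed (the checkpoint id is only illustrative):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Each component in model_cpu_offload_seq is hooked so its weights stay on the
# CPU and move to the GPU only while that component is executing.
pipe.enable_model_cpu_offload()

image = pipe("an astronaut riding a horse", num_inference_steps=20).images[0]
image.save("astronaut.png")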
@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
        key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
        key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
        key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
        config_url = None

        # model_type = "v1"
        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
        if config_files is not None and "v1" in config_files:
            original_config_file = config_files["v1"]
        else:
            config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

        if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
            # model_type = "v2"
            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

            if config_files is not None and "v2" in config_files:
                original_config_file = config_files["v2"]
            else:
                config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
            if global_step == 110000:
                # v2.1 needs to upcast attention
                upcast_attention = True
        elif key_name_sd_xl_base in checkpoint:
            # only base xl has two text embedders
            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
            if config_files is not None and "xl" in config_files:
                original_config_file = config_files["xl"]
            else:
                config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
        elif key_name_sd_xl_refiner in checkpoint:
            # only refiner xl has embedder and one text embedders
            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

        original_config_file = BytesIO(requests.get(config_url).content)
            if config_files is not None and "xl_refiner" in config_files:
                original_config_file = config_files["xl_refiner"]
            else:
                config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
        if config_url is not None:
            original_config_file = BytesIO(requests.get(config_url).content)

    original_config = OmegaConf.load(original_config_file)
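The new `config_files` argument lets a caller hand in local inference YAMLs for the detected model family ("v1", "v2", "xl", "xl_refiner"), so the function only falls back to downloading `config_url` when no local file is supplied. A minimal sketch of the call, assuming a local SDXL checkpoint and config file; the paths are placeholders:

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Placeholder paths; only the "v1", "v2", "xl", and "xl_refiner" keys are consulted.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./sd_xl_base_1.0.safetensors",
    from_safetensors=True,
    config_files={"xl": "./configs/sd_xl_base.yaml"},
)
pipe.save_pretrained("./sdxl-base-diffusers")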
@@ -50,13 +50,26 @@ class SafetyConfig(object):

_dummy_objects = {}
_additional_imports = {}
_import_structure = {
    "pipeline_output": ["StableDiffusionSafePipelineOutput"],
    "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
    "safety_checker": ["StableDiffusionSafetyChecker"],
}
_import_structure = {}

_additional_imports.update({"SafetyConfig": SafetyConfig})

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure.update(
        {
            "pipeline_output": ["StableDiffusionSafePipelineOutput"],
            "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
            "safety_checker": ["StableDiffusionSafetyChecker"],
        }
    )


if TYPE_CHECKING:
    try:
@@ -70,25 +83,16 @@ if TYPE_CHECKING:
        from .safety_checker import SafeStableDiffusionSafetyChecker

else:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils import dummy_torch_and_transformers_objects
    import sys

        _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
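The restructuring above brings the safe Stable Diffusion package in line with the lazy-import layout used by the other pipeline packages: `_import_structure` only names the public objects, dummy objects fill in when torch or transformers is missing, and the real submodules are imported on first attribute access via `_LazyModule`. A stripped-down sketch of the pattern for a hypothetical package (module and class names are illustrative, not part of diffusers):

# __init__.py of a hypothetical "my_pipeline" package -- illustrative only.
from typing import TYPE_CHECKING

from diffusers.utils import _LazyModule

# Map submodule name -> names it exports; nothing is imported at this point.
_import_structure = {"pipeline_my_model": ["MyModelPipeline"]}

if TYPE_CHECKING:
    # Type checkers and IDEs see the real import.
    from .pipeline_my_model import MyModelPipeline
else:
    import sys

    # At runtime the package is replaced by a proxy that imports
    # pipeline_my_model only when MyModelPipeline is first accessed.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )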
@@ -162,7 +162,6 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        adapter_weights: Optional[List[float]] = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
@@ -184,7 +183,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
            )

        if isinstance(adapter, (list, tuple)):
            adapter = MultiAdapter(adapter, adapter_weights=adapter_weights)
            adapter = MultiAdapter(adapter)

        self.register_modules(
            vae=vae,
@@ -727,9 +726,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        adapter_state = self.adapter(adapter_input)
        for k, v in enumerate(adapter_state):
            adapter_state[k] = v * adapter_conditioning_scale
        if isinstance(self.adapter, MultiAdapter):
            adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
            for k, v in enumerate(adapter_state):
                adapter_state[k] = v
        else:
            adapter_state = self.adapter(adapter_input)
            for k, v in enumerate(adapter_state):
                adapter_state[k] = v * adapter_conditioning_scale
        if num_images_per_prompt > 1:
            for k, v in enumerate(adapter_state):
                adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
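Taken together, these hunks move the per-adapter weighting out of the constructor (`adapter_weights`) and into the call itself: when several adapters are wrapped in a `MultiAdapter`, `adapter_conditioning_scale` may now be a list with one scale per adapter. A minimal sketch, assuming two SD-1.5 T2I-Adapter checkpoints and local conditioning images; the ids and file names are illustrative:

import torch
from diffusers import MultiAdapter, StableDiffusionAdapterPipeline, T2IAdapter
from diffusers.utils import load_image

# Illustrative adapter checkpoints; any pair of SD-1.5 T2I adapters works the same way.
adapter = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_sketch_sd15v2"),
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd15v2"),
    ]
)

pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", adapter=adapter
)

sketch = load_image("sketch.png")  # placeholder conditioning images
depth = load_image("depth.png")

# One scale per adapter, replacing the removed adapter_weights constructor argument.
image = pipe(
    "a house in the mountains",
    image=[sketch, depth],
    adapter_conditioning_scale=[0.8, 0.5],
).images[0]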
@@ -47,3 +47,5 @@ else:
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
@@ -820,7 +820,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
    def set_attn_processor(
        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
    ):
        r"""
        Sets the attention processor to use to compute attention.

@@ -844,9 +846,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -868,7 +870,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                f" {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)
        self.set_attn_processor(processor, _remove_lora=True)

    def set_attention_slice(self, slice_size):
        r"""
@@ -51,3 +51,6 @@ else:
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
@@ -41,7 +41,6 @@ if TYPE_CHECKING:
        from .pipeline_wuerstchen import WuerstchenDecoderPipeline
        from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
        from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline

else:
    import sys

@@ -51,3 +50,6 @@ else:
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
@@ -43,7 +43,7 @@ from diffusers.models.attention_processor import (
    LoRAAttnProcessor2_0,
    XFormersAttnProcessor,
)
from diffusers.utils.testing_utils import floats_tensor, require_torch_gpu, slow, torch_device
from diffusers.utils.testing_utils import floats_tensor, nightly, require_torch_gpu, slow, torch_device


def create_unet_lora_layers(unet: nn.Module):
@@ -1464,3 +1464,41 @@ class LoraIntegrationTests(unittest.TestCase):
        expected = np.array([0.4468, 0.4087, 0.4134, 0.366, 0.3202, 0.3505, 0.3786, 0.387, 0.3535])

        self.assertTrue(np.allclose(images, expected, atol=1e-3))

    @nightly
    def test_sequential_fuse_unfuse(self):
        pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")

        # 1. round
        pipe.load_lora_weights("Pclanglais/TintinIA")
        pipe.fuse_lora()

        generator = torch.Generator().manual_seed(0)
        images = pipe(
            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
        ).images
        image_slice = images[0, -3:, -3:, -1].flatten()

        pipe.unfuse_lora()

        # 2. round
        pipe.load_lora_weights("ProomptEngineer/pe-balloon-diffusion-style")
        pipe.fuse_lora()
        pipe.unfuse_lora()

        # 3. round
        pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")
        pipe.fuse_lora()
        pipe.unfuse_lora()

        # 4. back to 1st round
        pipe.load_lora_weights("Pclanglais/TintinIA")
        pipe.fuse_lora()

        generator = torch.Generator().manual_seed(0)
        images_2 = pipe(
            "masterpiece, best quality, mountain", output_type="np", generator=generator, num_inference_steps=2
        ).images
        image_slice_2 = images_2[0, -3:, -3:, -1].flatten()

        self.assertTrue(np.allclose(image_slice, image_slice_2, atol=1e-3))
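The new @nightly test exercises repeated fuse_lora / unfuse_lora cycles and asserts that returning to the first LoRA reproduces the original output. In user code the same calls bake LoRA weights into the base model for inference and restore them before switching LoRAs; a minimal sketch (the LoRA repo id is illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Load a LoRA and fuse it so no separate LoRA layers run during inference.
pipe.load_lora_weights("ostris/crayon_style_lora_sdxl")  # illustrative repo id
pipe.fuse_lora()

image = pipe("a crayon drawing of a lighthouse", num_inference_steps=25).images[0]

# Restore the base weights before loading a different LoRA.
pipe.unfuse_lora()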
@@ -216,7 +216,9 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
        return super().get_dummy_components("multi_adapter")

    def get_dummy_inputs(self, device, seed=0):
        return super().get_dummy_inputs(device, seed, num_images=2)
        inputs = super().get_dummy_inputs(device, seed, num_images=2)
        inputs["adapter_conditioning_scale"] = [0.5, 0.5]
        return inputs

    def test_stable_diffusion_adapter_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator