Compare commits

...

16 Commits

Author SHA1 Message Date
Sayak Paul
c37c840b1b update 2023-09-27 22:13:19 +05:30
Sayak Paul
9858053bfe Release: v0.21.3 2023-09-27 22:10:50 +05:30
Patrick von Platen
6a3301fe34 resolve conflicts. 2023-09-27 22:09:04 +05:30
Patrick von Platen
813a1b2ee0 Fix one more 2023-09-19 00:09:30 +02:00
Patrick von Platen
a43b8574a9 Patch release: v0.21.2 2023-09-19 00:06:16 +02:00
Sayak Paul
a2f0db52e3 [LoRA] don't break offloading for incompatible lora ckpts. (#5085)
* don't break offloading for incompatible lora ckpts.

* debugging

* better condition.

* fix

* fix

* fix

* fix

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
2023-09-19 00:05:34 +02:00
Will Berman
92f6693b37 remove unused adapter weights in constructor (#5088)
remove adapter weights in MultiAdapter constructor
2023-09-19 00:05:28 +02:00
Will Berman
932897afa8 t2i Adapter community member fix (#5090)
* convert tensorrt controlnet

* Fix code quality

* Fix code quality

* Fix code quality

* Fix code quality

* Fix code quality

* Fix code quality

* Fix number controlnet condition

* Add convert SD XL to onnx

* Add convert SD XL to tensorrt

* Add convert SD XL to tensorrt

* Add examples in comments

* Add examples in comments

* Add test onnx controlnet

* Add tensorrt test

* Remove copied

* Move file test to examples/community

* Remove script

* Remove script

* Remove text

* Fix import

* Fix T2I MultiAdapter

* fix tests

---------

Co-authored-by: dotieuthien <thien.do@mservice.com.vn>
Co-authored-by: dotieuthien <dotieuthien9997@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: dotieuthien <hades@cinnamon.is>
2023-09-19 00:05:20 +02:00
Patrick von Platen
c2940434d0 [Textual inversion] Refactor textual inversion to make it cleaner (#5076)
* [Textual inversion] Clean loading

* [Textual inversion] Clean loading

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up

* [Textual inversion] Clean up
2023-09-19 00:04:53 +02:00
Patrick von Platen
60ab8fad16 Patch release: v0.21.1 2023-09-14 13:06:57 +02:00
Patrick von Platen
d17240457f [Import] Add missing settings / Correct some dummy imports (#5036)
* [Import] Add missing settings

* up

* up

* up
2023-09-14 12:47:55 +02:00
Vladimir Mandic
7512fc4df5 allow loading of sd models from safetensors without online lookups using local config files (#5019)
finish config_files implementation
2023-09-14 12:47:41 +02:00
Patrick von Platen
0c2f1ccc97 [Import] Don't force transformers to be installed (#5035)
* [Import] Don't force transformers to be installed

* make style
2023-09-14 12:47:34 +02:00
Dhruv Nair
47f2d2c7be Fix model offload bug when key isn't present (#5030)
* fix model offload bug when key isn't present

* make style
2023-09-14 12:47:25 +02:00
Patrick von Platen
af85591593 Patch release: v0.21.1 2023-09-14 12:46:39 +02:00
Patrick von Platen
29f15673ed Release: v0.21.0 2023-09-13 15:58:24 +02:00
40 changed files with 576 additions and 295 deletions

View File

@@ -56,7 +56,7 @@ if is_wandb_available():
 import wandb
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -59,7 +59,7 @@ if is_wandb_available():
 import wandb
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = logging.getLogger(__name__)

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
 import wandb
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -60,7 +60,7 @@ if is_wandb_available():
 import wandb
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 # Cache compiled models across invocations of this script.
 cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

View File

@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__, log_level="INFO")

View File

@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__, log_level="INFO")

View File

@@ -58,7 +58,7 @@ if is_wandb_available():
 import wandb
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -53,7 +53,7 @@ if is_wandb_available():
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__, log_level="INFO")

View File

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = logging.getLogger(__name__)

View File

@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__, log_level="INFO")

View File

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -79,7 +79,7 @@ else:
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__)

View File

@@ -56,7 +56,7 @@ else:
 # ------------------------------------------------------------------------------
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = logging.getLogger(__name__)

View File

@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")
 logger = get_logger(__name__, log_level="INFO")

View File

@@ -154,6 +154,7 @@ if __name__ == "__main__":
 pipe = download_from_original_stable_diffusion_ckpt(
 checkpoint_path_or_dict=args.checkpoint_path,
 original_config_file=args.original_config_file,
+config_files=args.config_files,
 image_size=args.image_size,
 prediction_type=args.prediction_type,
 model_type=args.pipeline_type,

View File

@@ -244,7 +244,7 @@ install_requires = [
 setup(
 name="diffusers",
-version="0.21.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+version="0.21.3", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
 description="State-of-the-art diffusion in PyTorch and JAX.",
 long_description=open("README.md", "r", encoding="utf-8").read(),
 long_description_content_type="text/markdown",

View File

@@ -1,4 +1,4 @@
-__version__ = "0.21.0.dev0"
+__version__ = "0.21.3"
 from typing import TYPE_CHECKING

View File

@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
@@ -41,7 +40,7 @@ from .utils.import_utils import BACKENDS_MAPPING
 if is_transformers_available():
-from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
+from transformers import CLIPTextModel, CLIPTextModelWithProjection
 if is_accelerate_available():
 from accelerate import init_empty_weights
@@ -307,6 +306,9 @@ class UNet2DConditionLoadersMixin:
 # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
 # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
 network_alphas = kwargs.pop("network_alphas", None)
+_pipeline = kwargs.pop("_pipeline", None)
 is_network_alphas_none = network_alphas is None
 allow_pickle = False
@@ -460,6 +462,7 @@ class UNet2DConditionLoadersMixin:
 load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
 else:
 lora.load_state_dict(value_dict)
 elif is_custom_diffusion:
 attn_processors = {}
 custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ class UNet2DConditionLoadersMixin:
 cross_attention_dim=cross_attention_dim,
 )
 attn_processors[key].load_state_dict(value_dict)
-self.set_attn_processor(attn_processors)
 else:
 raise ValueError(
 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
 )
+# <Unsafe code
+# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+# Now we remove any existing hooks to
+is_model_cpu_offload = False
+is_sequential_cpu_offload = False
+if _pipeline is not None:
+for _, component in _pipeline.components.items():
+if isinstance(component, nn.Module):
+if hasattr(component, "_hf_hook"):
+is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+logger.info(
+"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+)
+remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+# only custom diffusion needs to set attn processors
+if is_custom_diffusion:
+self.set_attn_processor(attn_processors)
 # set lora layers
 for target_module, lora_layer in lora_layers_list:
 target_module.set_lora_layer(lora_layer)
 self.to(dtype=self.dtype, device=self.device)
+# Offload back.
+if is_model_cpu_offload:
+_pipeline.enable_model_cpu_offload()
+elif is_sequential_cpu_offload:
+_pipeline.enable_sequential_cpu_offload()
+# Unsafe code />
 def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
 is_new_lora_format = all(
 key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -622,12 +650,87 @@ class UNet2DConditionLoadersMixin:
 module._unfuse_lora()
+def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
+cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+force_download = kwargs.pop("force_download", False)
+resume_download = kwargs.pop("resume_download", False)
+proxies = kwargs.pop("proxies", None)
+local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+use_auth_token = kwargs.pop("use_auth_token", None)
+revision = kwargs.pop("revision", None)
+subfolder = kwargs.pop("subfolder", None)
+weight_name = kwargs.pop("weight_name", None)
+use_safetensors = kwargs.pop("use_safetensors", None)
+allow_pickle = False
+if use_safetensors is None:
+use_safetensors = True
+allow_pickle = True
+user_agent = {
+"file_type": "text_inversion",
+"framework": "pytorch",
+}
+state_dicts = []
+for pretrained_model_name_or_path in pretrained_model_name_or_paths:
+if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
+# 3.1. Load textual inversion file
+model_file = None
+# Let's first try to load .safetensors weights
+if (use_safetensors and weight_name is None) or (
+weight_name is not None and weight_name.endswith(".safetensors")
+):
+try:
+model_file = _get_model_file(
+pretrained_model_name_or_path,
+weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+cache_dir=cache_dir,
+force_download=force_download,
+resume_download=resume_download,
+proxies=proxies,
+local_files_only=local_files_only,
+use_auth_token=use_auth_token,
+revision=revision,
+subfolder=subfolder,
+user_agent=user_agent,
+)
+state_dict = safetensors.torch.load_file(model_file, device="cpu")
+except Exception as e:
+if not allow_pickle:
+raise e
+model_file = None
+if model_file is None:
+model_file = _get_model_file(
+pretrained_model_name_or_path,
+weights_name=weight_name or TEXT_INVERSION_NAME,
+cache_dir=cache_dir,
+force_download=force_download,
+resume_download=resume_download,
+proxies=proxies,
+local_files_only=local_files_only,
+use_auth_token=use_auth_token,
+revision=revision,
+subfolder=subfolder,
+user_agent=user_agent,
+)
+state_dict = torch.load(model_file, map_location="cpu")
+else:
+state_dict = pretrained_model_name_or_path
+state_dicts.append(state_dict)
+return state_dicts
 class TextualInversionLoaderMixin:
 r"""
 Load textual inversion tokens and embeddings to the tokenizer and text encoder.
 """
-def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
+def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"): # noqa: F821
 r"""
 Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
 be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
@@ -654,7 +757,7 @@ class TextualInversionLoaderMixin:
 return prompts
-def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
+def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"): # noqa: F821
 r"""
 Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
 to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
@@ -684,12 +787,103 @@ class TextualInversionLoaderMixin:
 return prompt
+def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
+if tokenizer is None:
+raise ValueError(
+f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
+f" `{self.load_textual_inversion.__name__}`"
+)
+if text_encoder is None:
+raise ValueError(
+f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
+f" `{self.load_textual_inversion.__name__}`"
+)
+if len(pretrained_model_name_or_paths) != len(tokens):
+raise ValueError(
+f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
+f"Make sure both lists have the same length."
+)
+valid_tokens = [t for t in tokens if t is not None]
+if len(set(valid_tokens)) < len(valid_tokens):
+raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+@staticmethod
+def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
+all_tokens = []
+all_embeddings = []
+for state_dict, token in zip(state_dicts, tokens):
+if isinstance(state_dict, torch.Tensor):
+if token is None:
+raise ValueError(
+"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+)
+loaded_token = token
+embedding = state_dict
+elif len(state_dict) == 1:
+# diffusers
+loaded_token, embedding = next(iter(state_dict.items()))
+elif "string_to_param" in state_dict:
+# A1111
+loaded_token = state_dict["name"]
+embedding = state_dict["string_to_param"]["*"]
+else:
+raise ValueError(
+f"Loaded state dictonary is incorrect: {state_dict}. \n\n"
+"Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
+" input key."
+)
+if token is not None and loaded_token != token:
+logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+else:
+token = loaded_token
+if token in tokenizer.get_vocab():
+raise ValueError(
+f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+)
+all_tokens.append(token)
+all_embeddings.append(embedding)
+return all_tokens, all_embeddings
+@staticmethod
+def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
+all_tokens = []
+all_embeddings = []
+for embedding, token in zip(embeddings, tokens):
+if f"{token}_1" in tokenizer.get_vocab():
+multi_vector_tokens = [token]
+i = 1
+while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+multi_vector_tokens.append(f"{token}_{i}")
+i += 1
+raise ValueError(
+f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+)
+is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+if is_multi_vector:
+all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+all_embeddings += [e for e in embedding] # noqa: C416
+else:
+all_tokens += [token]
+all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+return all_tokens, all_embeddings
 def load_textual_inversion(
 self,
 pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
 token: Optional[Union[str, List[str]]] = None,
-tokenizer: Optional[PreTrainedTokenizer] = None,
-text_encoder: Optional[PreTrainedModel] = None,
+tokenizer: Optional["PreTrainedTokenizer"] = None, # noqa: F821
+text_encoder: Optional["PreTrainedModel"] = None, # noqa: F821
 **kwargs,
 ):
 r"""
@@ -789,25 +983,44 @@ class TextualInversionLoaderMixin:
 ```
 """
+# 1. Set correct tokenizer and text encoder
 tokenizer = tokenizer or getattr(self, "tokenizer", None)
 text_encoder = text_encoder or getattr(self, "text_encoder", None)
-if tokenizer is None:
+# 2. Normalize inputs
+pretrained_model_name_or_paths = (
+[pretrained_model_name_or_path]
+if not isinstance(pretrained_model_name_or_path, list)
+else pretrained_model_name_or_path
+)
+tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
+# 3. Check inputs
+self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
+# 4. Load state dicts of textual embeddings
+state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
+# 4. Retrieve tokens and embeddings
+tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
+# 5. Extend tokens and embeddings for multi vector
+tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
+# 6. Make sure all embeddings have the correct size
+expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
+if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
 raise ValueError(
-f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
-f" `{self.load_textual_inversion.__name__}`"
+"Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
+"to be of shape {input_embeddings.shape[-1]}, but are {embeddings.shape[-1]} "
 )
-if text_encoder is None:
-raise ValueError(
-f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
-f" `{self.load_textual_inversion.__name__}`"
-)
-# Remove any existing hooks.
+# 7. Now we can be sure that loading the embedding matrix works
+# < Unsafe code:
+# 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
 is_model_cpu_offload = False
 is_sequential_cpu_offload = False
-recursive = False
 for _, component in self.components.items():
 if isinstance(component, nn.Module):
 if hasattr(component, "_hf_hook"):
@@ -816,168 +1029,34 @@ class TextualInversionLoaderMixin:
 logger.info(
 "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
 )
-recursive = is_sequential_cpu_offload
-remove_hook_from_module(component, recurse=recursive)
+remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
-cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-force_download = kwargs.pop("force_download", False)
-resume_download = kwargs.pop("resume_download", False)
-proxies = kwargs.pop("proxies", None)
-local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-use_auth_token = kwargs.pop("use_auth_token", None)
-revision = kwargs.pop("revision", None)
-subfolder = kwargs.pop("subfolder", None)
-weight_name = kwargs.pop("weight_name", None)
-use_safetensors = kwargs.pop("use_safetensors", None)
-allow_pickle = False
-if use_safetensors is None:
-use_safetensors = True
-allow_pickle = True
-user_agent = {
-"file_type": "text_inversion",
-"framework": "pytorch",
-}
-if not isinstance(pretrained_model_name_or_path, list):
-pretrained_model_name_or_paths = [pretrained_model_name_or_path]
-else:
-pretrained_model_name_or_paths = pretrained_model_name_or_path
-if isinstance(token, str):
-tokens = [token]
-elif token is None:
-tokens = [None] * len(pretrained_model_name_or_paths)
-else:
-tokens = token
-if len(pretrained_model_name_or_paths) != len(tokens):
-raise ValueError(
-f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
-f"Make sure both lists have the same length."
-)
-valid_tokens = [t for t in tokens if t is not None]
-if len(set(valid_tokens)) < len(valid_tokens):
-raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
-token_ids_and_embeddings = []
-for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
-# 1. Load textual inversion file
-model_file = None
-# Let's first try to load .safetensors weights
-if (use_safetensors and weight_name is None) or (
-weight_name is not None and weight_name.endswith(".safetensors")
-):
-try:
-model_file = _get_model_file(
-pretrained_model_name_or_path,
-weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
-cache_dir=cache_dir,
-force_download=force_download,
-resume_download=resume_download,
-proxies=proxies,
-local_files_only=local_files_only,
-use_auth_token=use_auth_token,
-revision=revision,
-subfolder=subfolder,
-user_agent=user_agent,
-)
-state_dict = safetensors.torch.load_file(model_file, device="cpu")
-except Exception as e:
-if not allow_pickle:
-raise e
-model_file = None
-if model_file is None:
-model_file = _get_model_file(
-pretrained_model_name_or_path,
-weights_name=weight_name or TEXT_INVERSION_NAME,
-cache_dir=cache_dir,
-force_download=force_download,
-resume_download=resume_download,
-proxies=proxies,
-local_files_only=local_files_only,
-use_auth_token=use_auth_token,
-revision=revision,
-subfolder=subfolder,
-user_agent=user_agent,
-)
-state_dict = torch.load(model_file, map_location="cpu")
-else:
-state_dict = pretrained_model_name_or_path
-# 2. Load token and embedding correcly from file
-loaded_token = None
-if isinstance(state_dict, torch.Tensor):
-if token is None:
-raise ValueError(
-"You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
-)
-embedding = state_dict
-elif len(state_dict) == 1:
-# diffusers
-loaded_token, embedding = next(iter(state_dict.items()))
-elif "string_to_param" in state_dict:
-# A1111
-loaded_token = state_dict["name"]
-embedding = state_dict["string_to_param"]["*"]
-if token is not None and loaded_token != token:
-logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-else:
-token = loaded_token
-embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
-# 3. Make sure we don't mess up the tokenizer or text encoder
-vocab = tokenizer.get_vocab()
-if token in vocab:
-raise ValueError(
-f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-)
-elif f"{token}_1" in vocab:
-multi_vector_tokens = [token]
-i = 1
-while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-multi_vector_tokens.append(f"{token}_{i}")
-i += 1
-raise ValueError(
-f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-)
-is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
-if is_multi_vector:
-tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-embeddings = [e for e in embedding] # noqa: C416
-else:
-tokens = [token]
-embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+# 7.2 save expected device and dtype
+device = text_encoder.device
+dtype = text_encoder.dtype
+# 7.3 Increase token embedding matrix
+text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
+input_embeddings = text_encoder.get_input_embeddings().weight
+# 7.4 Load token and embedding
+for token, embedding in zip(tokens, embeddings):
 # add tokens and get ids
-tokenizer.add_tokens(tokens)
-token_ids = tokenizer.convert_tokens_to_ids(tokens)
-token_ids_and_embeddings += zip(token_ids, embeddings)
+tokenizer.add_tokens(token)
+token_id = tokenizer.convert_tokens_to_ids(token)
+input_embeddings.data[token_id] = embedding
 logger.info(f"Loaded textual inversion embedding for {token}.")
-# resize token embeddings and set all new embeddings
-text_encoder.resize_token_embeddings(len(tokenizer))
-for token_id, embedding in token_ids_and_embeddings:
-text_encoder.get_input_embeddings().weight.data[token_id] = embedding
-# offload back
+input_embeddings.to(dtype=dtype, device=device)
+# 7.5 Offload the model again
 if is_model_cpu_offload:
 self.enable_model_cpu_offload()
 elif is_sequential_cpu_offload:
 self.enable_sequential_cpu_offload()
+# / Unsafe Code >
 class LoraLoaderMixin:
 r"""
@@ -1009,26 +1088,21 @@ class LoraLoaderMixin:
 kwargs (`dict`, *optional*):
 See [`~loaders.LoraLoaderMixin.lora_state_dict`].
 """
-# Remove any existing hooks.
-is_model_cpu_offload = False
-is_sequential_cpu_offload = False
-recurive = False
-for _, component in self.components.items():
-if isinstance(component, nn.Module):
-if hasattr(component, "_hf_hook"):
-is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-logger.info(
-"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-)
-recurive = is_sequential_cpu_offload
-remove_hook_from_module(component, recurse=recurive)
+# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+is_correct_format = all("lora" in key for key in state_dict.keys())
+if not is_correct_format:
+raise ValueError("Invalid LoRA checkpoint.")
 low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 self.load_lora_into_unet(
-state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+state_dict,
+network_alphas=network_alphas,
+unet=self.unet,
+low_cpu_mem_usage=low_cpu_mem_usage,
+_pipeline=self,
 )
 self.load_lora_into_text_encoder(
 state_dict,
@@ -1036,14 +1110,9 @@ class LoraLoaderMixin:
 text_encoder=self.text_encoder,
 lora_scale=self.lora_scale,
 low_cpu_mem_usage=low_cpu_mem_usage,
+_pipeline=self,
 )
-# Offload back.
-if is_model_cpu_offload:
-self.enable_model_cpu_offload()
-elif is_sequential_cpu_offload:
-self.enable_sequential_cpu_offload()
 @classmethod
 def lora_state_dict(
 cls,
@@ -1340,7 +1409,7 @@ class LoraLoaderMixin:
 return new_state_dict
 @classmethod
-def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
 """
 This will load the LoRA layers specified in `state_dict` into `unet`.
@@ -1382,13 +1451,22 @@ class LoraLoaderMixin:
 # Otherwise, we're dealing with the old format. This means the `state_dict` should only
 # contain the module names of the `unet` as its keys WITHOUT any prefix.
 warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-warnings.warn(warn_message)
+logger.warn(warn_message)
-unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+unet.load_attn_procs(
+state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+)
 @classmethod
 def load_lora_into_text_encoder(
-cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+cls,
+state_dict,
+network_alphas,
+text_encoder,
+prefix=None,
+lora_scale=1.0,
+low_cpu_mem_usage=None,
+_pipeline=None,
 ):
 """
 This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1498,11 +1576,15 @@ class LoraLoaderMixin:
 low_cpu_mem_usage=low_cpu_mem_usage,
 )
-# set correct dtype & device
-text_encoder_lora_state_dict = {
-k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-for k, v in text_encoder_lora_state_dict.items()
-}
+is_pipeline_offloaded = _pipeline is not None and any(
+isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+)
+if is_pipeline_offloaded and low_cpu_mem_usage:
+low_cpu_mem_usage = True
+logger.info(
+f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+)
 if low_cpu_mem_usage:
 device = next(iter(text_encoder_lora_state_dict.values())).device
 dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1518,8 +1600,33 @@ class LoraLoaderMixin:
f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}" f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
) )
# <Unsafe code
# We can be sure that the following works as all we do is change the dtype and device of the text encoder
# Now we remove any existing hooks to
is_model_cpu_offload = False
is_sequential_cpu_offload = False
if _pipeline is not None:
for _, component in _pipeline.components.items():
if isinstance(component, torch.nn.Module):
if hasattr(component, "_hf_hook"):
is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
is_sequential_cpu_offload = isinstance(
getattr(component, "_hf_hook"), AlignDevicesHook
)
logger.info(
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
)
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype) text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)
# Offload back.
if is_model_cpu_offload:
_pipeline.enable_model_cpu_offload()
elif is_sequential_cpu_offload:
_pipeline.enable_sequential_cpu_offload()
# Unsafe code />
@property @property
def lora_scale(self) -> float: def lora_scale(self) -> float:
# property function that returns the lora scale which can be set at run time by the pipeline. # property function that returns the lora scale which can be set at run time by the pipeline.
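
The `_pipeline` plumbing added above is what lets `load_lora_weights()` strip the accelerate hooks before loading and re-enable offloading afterwards. A minimal sketch of the scenario this patch fixes (the LoRA repository id is hypothetical):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()  # installs accelerate offload hooks on the components
pipe.load_lora_weights("some-user/some-sdxl-lora")  # hypothetical repo id; hooks are removed and re-applied internally
image = pipe("an astronaut riding a horse", num_inference_steps=25).images[0]
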
@@ -2098,6 +2205,7 @@ class FromSingleFileMixin:
 from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt
 original_config_file = kwargs.pop("original_config_file", None)
+config_files = kwargs.pop("config_files", None)
 cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
 resume_download = kwargs.pop("resume_download", False)
 force_download = kwargs.pop("force_download", False)
@@ -2215,6 +2323,7 @@ class FromSingleFileMixin:
 vae=vae,
 tokenizer=tokenizer,
 original_config_file=original_config_file,
+config_files=config_files,
 )
 if torch_dtype is not None:
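
The `config_files` argument threaded through `from_single_file` above is what enables loading original-format checkpoints without online config lookups. A minimal sketch (the local paths are hypothetical, and the "v1" key is assumed from the conversion utilities):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_single_file(
    "./v1-5-pruned-emaonly.safetensors",  # hypothetical local checkpoint path
    config_files={"v1": "./configs/v1-inference.yaml"},  # hypothetical local config used instead of an online lookup
    local_files_only=True,
)
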
@@ -2556,3 +2665,132 @@ class FromOriginalControlnetMixin:
 controlnet.to(torch_dtype=torch_dtype)
 return controlnet
+class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
+"""This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
+# Overrride to properly handle the loading and unloading of the additional text encoder.
+def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+"""
+Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+`self.text_encoder`.
+All kwargs are forwarded to `self.lora_state_dict`.
+See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
+`self.unet`.
+See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
+into `self.text_encoder`.
+Parameters:
+pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+kwargs (`dict`, *optional*):
+See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+"""
+# We could have accessed the unet config from `lora_state_dict()` too. We pass
+# it here explicitly to be able to tell that it's coming from an SDXL
+# pipeline.
+# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+state_dict, network_alphas = self.lora_state_dict(
+pretrained_model_name_or_path_or_dict,
+unet_config=self.unet.config,
+**kwargs,
+)
+is_correct_format = all("lora" in key for key in state_dict.keys())
+if not is_correct_format:
+raise ValueError("Invalid LoRA checkpoint.")
+self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
+text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+if len(text_encoder_state_dict) > 0:
+self.load_lora_into_text_encoder(
+text_encoder_state_dict,
+network_alphas=network_alphas,
+text_encoder=self.text_encoder,
+prefix="text_encoder",
+lora_scale=self.lora_scale,
+_pipeline=self,
+)
+text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+if len(text_encoder_2_state_dict) > 0:
+self.load_lora_into_text_encoder(
+text_encoder_2_state_dict,
+network_alphas=network_alphas,
+text_encoder=self.text_encoder_2,
+prefix="text_encoder_2",
+lora_scale=self.lora_scale,
+_pipeline=self,
+)
+@classmethod
+def save_lora_weights(
+self,
+save_directory: Union[str, os.PathLike],
+unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+is_main_process: bool = True,
+weight_name: str = None,
+save_function: Callable = None,
+safe_serialization: bool = True,
+):
+r"""
+Save the LoRA parameters corresponding to the UNet and text encoder.
+Arguments:
+save_directory (`str` or `os.PathLike`):
+Directory to save LoRA parameters to. Will be created if it doesn't exist.
+unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+State dict of the LoRA layers corresponding to the `unet`.
+text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+encoder LoRA state dict because it comes from 🤗 Transformers.
+is_main_process (`bool`, *optional*, defaults to `True`):
+Whether the process calling this is the main process or not. Useful during distributed training and you
+need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+process to avoid race conditions.
+save_function (`Callable`):
+The function to use to save the state dictionary. Useful during distributed training when you need to
+replace `torch.save` with another method. Can be configured with the environment variable
+`DIFFUSERS_SAVE_MODE`.
+safe_serialization (`bool`, *optional*, defaults to `True`):
+Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+"""
+state_dict = {}
+def pack_weights(layers, prefix):
+layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+return layers_state_dict
+if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+raise ValueError(
+"You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+)
+if unet_lora_layers:
+state_dict.update(pack_weights(unet_lora_layers, "unet"))
+if text_encoder_lora_layers and text_encoder_2_lora_layers:
+state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+self.write_lora_layers(
+state_dict=state_dict,
+save_directory=save_directory,
+is_main_process=is_main_process,
+weight_name=weight_name,
+save_function=save_function,
+safe_serialization=safe_serialization,
+)
+def _remove_text_encoder_monkey_patch(self):
+self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
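
The new `StableDiffusionXLLoraLoaderMixin.save_lora_weights` shown above packs the UNet and both text encoder LoRA layers under their own prefixes. A minimal sketch of calling it (the three layer dicts are assumed to have been collected during a training loop):

from diffusers import StableDiffusionXLPipeline

StableDiffusionXLPipeline.save_lora_weights(
    save_directory="./sdxl-lora",                        # created if it doesn't exist
    unet_lora_layers=unet_lora_layers,                   # assumed: LoRA state dict for the UNet
    text_encoder_lora_layers=text_encoder_one_layers,    # assumed: LoRA layers of the first text encoder
    text_encoder_2_lora_layers=text_encoder_two_layers,  # assumed: LoRA layers of the second text encoder
    safe_serialization=True,
)
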

View File

@@ -90,6 +90,8 @@ class MultiAdapter(ModelMixin):
 features = adapter(x)
 if accume_state is None:
 accume_state = features
+for i in range(len(accume_state)):
+accume_state[i] = w * accume_state[i]
 else:
 for i in range(len(features)):
 accume_state[i] += w * features[i]

View File

@@ -304,19 +304,16 @@ class Attention(nn.Module):
 self.set_processor(processor)
-def set_processor(self, processor: "AttnProcessor"):
+def set_processor(self, processor: "AttnProcessor", _remove_lora=False):
-if (
-hasattr(self, "processor")
-and not isinstance(processor, LORA_ATTENTION_PROCESSORS)
-and self.to_q.lora_layer is not None
-):
+if hasattr(self, "processor") and _remove_lora and self.to_q.lora_layer is not None:
 deprecate(
 "set_processor to offload LoRA",
 "0.26.0",
-"In detail, removing LoRA layers via calling `set_processor` or `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
+"In detail, removing LoRA layers via calling `set_default_attn_processor` is deprecated. Please make sure to call `pipe.unload_lora_weights()` instead.",
 )
 # TODO(Patrick, Sayak) - this can be deprecated once PEFT LoRA integration is complete
 # We need to remove all LoRA layers
+# Don't forget to remove ALL `_remove_lora` from the codebase
 for module in self.modules():
 if hasattr(module, "set_lora_layer"):
 module.set_lora_layer(None)
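
The reworded deprecation above steers users toward `unload_lora_weights()` instead of resetting attention processors to drop LoRA layers. A minimal sketch of that flow (the LoRA repository id is hypothetical):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.load_lora_weights("some-user/some-lora")  # hypothetical LoRA repo id
# ... run inference with the LoRA applied ...
pipe.unload_lora_weights()  # recommended way to remove the LoRA layers again
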

View File

@@ -196,7 +196,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 return processors
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
 r"""
 Sets the attention processor to use to compute attention.
@@ -220,9 +222,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
 if hasattr(module, "set_processor"):
 if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
 else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 for sub_name, child in module.named_children():
 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -244,7 +246,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
 )
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
 @apply_forward_hook
 def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:

View File

@@ -517,7 +517,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
 return processors
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
 r"""
 Sets the attention processor to use to compute attention.
@@ -541,9 +543,9 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
 def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
 if hasattr(module, "set_processor"):
 if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
 else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 for sub_name, child in module.named_children():
 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -565,7 +567,7 @@ class ControlNetModel(ModelMixin, ConfigMixin, FromOriginalControlnetMixin):
 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
 )
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
 def set_attention_slice(self, slice_size):

View File

@@ -191,7 +191,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
 return processors
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+def set_attn_processor(
+self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+):
 r"""
 Sets the attention processor to use to compute attention.
@@ -215,9 +217,9 @@ class PriorTransformer(ModelMixin, ConfigMixin):
 def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
 if hasattr(module, "set_processor"):
 if not isinstance(processor, dict):
-module.set_processor(processor)
+module.set_processor(processor, _remove_lora=_remove_lora)
 else:
-module.set_processor(processor.pop(f"{name}.processor"))
+module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
 for sub_name, child in module.named_children():
 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -239,7 +241,7 @@ class PriorTransformer(ModelMixin, ConfigMixin):
 f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
 )
-self.set_attn_processor(processor)
+self.set_attn_processor(processor, _remove_lora=True)
 def forward(
 self,

View File

@@ -613,7 +613,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        return processors

-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
        r"""
        Sets the attention processor to use to compute attention.
@@ -637,9 +639,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -660,7 +662,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

    def set_attention_slice(self, slice_size):
        r"""

View File

@@ -366,7 +366,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
            fn_recursive_set_attention_slice(module, reversed_slice_size)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
        r"""
        Sets the attention processor to use to compute attention.
@@ -390,9 +392,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -454,7 +456,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):

View File

@@ -538,7 +538,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
        return processors

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
        r"""
        Sets the attention processor to use to compute attention.
@@ -562,9 +564,9 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -586,7 +588,7 @@ class AudioLDM2UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoad
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

    # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attention_slice
    def set_attention_slice(self, slice_size):

View File

@@ -1255,7 +1255,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        self._all_hooks = []
        hook = None
        for model_str in self.model_cpu_offload_seq.split("->"):
-            model = all_model_components.pop(model_str)
+            model = all_model_components.pop(model_str, None)
            if not isinstance(model, torch.nn.Module):
                continue
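Switching to `all_model_components.pop(model_str, None)` together with the existing `isinstance` guard means a name listed in `model_cpu_offload_seq` that has no matching component (for example a pipeline built without a safety checker) is skipped rather than raising a `KeyError`. An illustrative Python sketch with stand-in components, not the pipeline's real registry:

import torch

all_model_components = {"unet": torch.nn.Linear(4, 4), "safety_checker": None}

for model_str in "text_encoder->unet->vae".split("->"):
    model = all_model_components.pop(model_str, None)  # a bare .pop(model_str) would raise KeyError for "text_encoder"
    if not isinstance(model, torch.nn.Module):
        continue  # missing or non-module entries are skipped instead of crashing
    print(f"would offload {model_str}")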

View File

@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
    key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
    key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+    config_url = None

    # model_type = "v1"
-    config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+    if config_files is not None and "v1" in config_files:
+        original_config_file = config_files["v1"]
+    else:
+        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

    if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
        # model_type = "v2"
-        config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+        if config_files is not None and "v2" in config_files:
+            original_config_file = config_files["v2"]
+        else:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"

        if global_step == 110000:
            # v2.1 needs to upcast attention
            upcast_attention = True
    elif key_name_sd_xl_base in checkpoint:
        # only base xl has two text embedders
-        config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+        if config_files is not None and "xl" in config_files:
+            original_config_file = config_files["xl"]
+        else:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
    elif key_name_sd_xl_refiner in checkpoint:
        # only refiner xl has embedder and one text embedders
-        config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+        if config_files is not None and "xl_refiner" in config_files:
+            original_config_file = config_files["xl_refiner"]
+        else:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

-    original_config_file = BytesIO(requests.get(config_url).content)
+    if config_url is not None:
+        original_config_file = BytesIO(requests.get(config_url).content)

    original_config = OmegaConf.load(original_config_file)
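With this change, a local YAML passed via `config_files` short-circuits the `config_url` download entirely. A hypothetical Python usage sketch: the `checkpoint_path` and `from_safetensors` keyword names and the file paths are assumptions made for illustration; only `download_from_original_stable_diffusion_ckpt` and the `config_files` keys ("v1", "v2", "xl", "xl_refiner") come from the hunk above:

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./models/sd15.safetensors",  # assumed local checkpoint path
    from_safetensors=True,  # assumed keyword for a safetensors input
    config_files={"v1": "./configs/v1-inference.yaml"},  # local config, so no online lookup
)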

View File

@@ -50,13 +50,26 @@ class SafetyConfig(object):
_dummy_objects = {}
_additional_imports = {}
-_import_structure = {
-    "pipeline_output": ["StableDiffusionSafePipelineOutput"],
-    "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
-    "safety_checker": ["StableDiffusionSafetyChecker"],
-}
+_import_structure = {}

_additional_imports.update({"SafetyConfig": SafetyConfig})

+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure.update(
+        {
+            "pipeline_output": ["StableDiffusionSafePipelineOutput"],
+            "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
+            "safety_checker": ["StableDiffusionSafetyChecker"],
+        }
+    )

if TYPE_CHECKING:
    try:
@@ -70,25 +83,16 @@ if TYPE_CHECKING:
    from .safety_checker import SafeStableDiffusionSafetyChecker
else:
-    try:
-        if not (is_transformers_available() and is_torch_available()):
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        from ...utils import dummy_torch_and_transformers_objects
-        _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
-    else:
-        import sys
-        sys.modules[__name__] = _LazyModule(
-            __name__,
-            globals()["__file__"],
-            _import_structure,
-            module_spec=__spec__,
-        )
-        for name, value in _dummy_objects.items():
-            setattr(sys.modules[__name__], name, value)
-        for name, value in _additional_imports.items():
-            setattr(sys.modules[__name__], name, value)
+    import sys
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
+    for name, value in _additional_imports.items():
+        setattr(sys.modules[__name__], name, value)
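After the restructure, the real classes are registered in `_import_structure` only when torch and transformers are importable; otherwise dummy placeholders are attached to the lazy module via the `setattr` loop above. A small illustrative Python check (which branch you hit depends on the optional dependencies installed in your environment):

from diffusers.pipelines.stable_diffusion_safe import SafetyConfig, StableDiffusionPipelineSafe

# SafetyConfig is always re-exported through _additional_imports; StableDiffusionPipelineSafe
# resolves lazily to either the real pipeline class or a dummy object that errors on use.
print(SafetyConfig)
print(StableDiffusionPipelineSafe)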

View File

@@ -162,7 +162,6 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
-        adapter_weights: Optional[List[float]] = None,
        requires_safety_checker: bool = True,
    ):
        super().__init__()
@@ -184,7 +183,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
            )

        if isinstance(adapter, (list, tuple)):
-            adapter = MultiAdapter(adapter, adapter_weights=adapter_weights)
+            adapter = MultiAdapter(adapter)

        self.register_modules(
            vae=vae,
@@ -727,9 +726,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
-        adapter_state = self.adapter(adapter_input)
-        for k, v in enumerate(adapter_state):
-            adapter_state[k] = v * adapter_conditioning_scale
+        if isinstance(self.adapter, MultiAdapter):
+            adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
+            for k, v in enumerate(adapter_state):
+                adapter_state[k] = v
+        else:
+            adapter_state = self.adapter(adapter_input)
+            for k, v in enumerate(adapter_state):
+                adapter_state[k] = v * adapter_conditioning_scale
        if num_images_per_prompt > 1:
            for k, v in enumerate(adapter_state):
                adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
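The branch above forwards the conditioning scale into the `MultiAdapter` call itself, so each sub-adapter can receive its own weight, while a single adapter keeps the old scalar multiply. A hypothetical Python usage sketch: the adapter and base-model ids are placeholders, and blank images stand in for real depth/canny conditioning maps:

from PIL import Image
from diffusers import MultiAdapter, StableDiffusionAdapterPipeline, T2IAdapter

adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd15v2"),
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_canny_sd15v2"),
    ]
)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", adapter=adapters
)

depth_map = Image.new("RGB", (512, 512))  # stand-in for a preprocessed depth map
canny_map = Image.new("RGB", (512, 512))  # stand-in for a preprocessed canny edge map

# With a MultiAdapter, a list gives one scale per sub-adapter instead of a single
# post-hoc multiplication of every adapter state.
image = pipe(
    "a photo of a living room",
    image=[depth_map, canny_map],
    adapter_conditioning_scale=[0.8, 0.5],
).images[0]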

View File

@@ -47,3 +47,5 @@ else:
        _import_structure,
        module_spec=__spec__,
    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

View File

@@ -820,7 +820,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        return processors

-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+    def set_attn_processor(
+        self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+    ):
        r"""
        Sets the attention processor to use to compute attention.
@@ -844,9 +846,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
-                    module.set_processor(processor)
+                    module.set_processor(processor, _remove_lora=_remove_lora)
                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
+                    module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
@@ -868,7 +870,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                f" {next(iter(self.attn_processors.values()))}"
            )

-        self.set_attn_processor(processor)
+        self.set_attn_processor(processor, _remove_lora=True)

    def set_attention_slice(self, slice_size):
        r"""

View File

@@ -51,3 +51,6 @@ else:
        _import_structure,
        module_spec=__spec__,
    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

View File

@@ -41,7 +41,6 @@ if TYPE_CHECKING:
    from .pipeline_wuerstchen import WuerstchenDecoderPipeline
    from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
    from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline
else:
    import sys
@@ -51,3 +50,6 @@ else:
        _import_structure,
        module_spec=__spec__,
    )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

View File

@@ -216,7 +216,9 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
        return super().get_dummy_components("multi_adapter")

    def get_dummy_inputs(self, device, seed=0):
-        return super().get_dummy_inputs(device, seed, num_images=2)
+        inputs = super().get_dummy_inputs(device, seed, num_images=2)
+        inputs["adapter_conditioning_scale"] = [0.5, 0.5]
+        return inputs

    def test_stable_diffusion_adapter_default_case(self):
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator