Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 15:34:17 +08:00)

Compare commits: 13 commits, controlnet ... v0.21.2-pa
| Author | SHA1 | Date |
|---|---|---|
| | 813a1b2ee0 | |
| | a43b8574a9 | |
| | a2f0db52e3 | |
| | 92f6693b37 | |
| | 932897afa8 | |
| | c2940434d0 | |
| | 60ab8fad16 | |
| | d17240457f | |
| | 7512fc4df5 | |
| | 0c2f1ccc97 | |
| | 47f2d2c7be | |
| | af85591593 | |
| | 29f15673ed | |
@@ -56,7 +56,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

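The training-script hunks above and below all make the same change: the scripts previously pinned against the development build `0.21.0.dev0` and now pin against the released `0.21.0`. A minimal hedged sketch of the guard this pin relies on — this is only an illustration of the failure mode, not the actual `diffusers.utils.check_min_version` implementation:

```python
# Hedged sketch: an illustrative stand-in for diffusers.utils.check_min_version,
# not the library's real implementation.
from packaging import version


def check_min_version_sketch(min_version: str) -> None:
    import diffusers  # assumes diffusers is installed in the environment

    if version.parse(diffusers.__version__) < version.parse(min_version):
        raise ImportError(
            f"This script requires diffusers>={min_version}, "
            f"but version {diffusers.__version__} is installed."
        )


# A script pinned to "0.21.0" now runs against the release instead of a .dev0 build:
# check_min_version_sketch("0.21.0")
```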
@@ -59,7 +59,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = logging.getLogger(__name__)

@@ -58,7 +58,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -60,7 +60,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -36,7 +36,7 @@ from diffusers.utils import check_min_version


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 # Cache compiled models across invocations of this script.
 cc.initialize_cache(os.path.expanduser("~/.cache/jax/compilation_cache"))

@@ -70,7 +70,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -58,7 +58,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -52,7 +52,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -55,7 +55,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -58,7 +58,7 @@ if is_wandb_available():
     import wandb

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -53,7 +53,7 @@ if is_wandb_available():


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -33,7 +33,7 @@ from diffusers.utils import check_min_version


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = logging.getLogger(__name__)

@@ -48,7 +48,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -57,7 +57,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -79,7 +79,7 @@ else:


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__)

@@ -56,7 +56,7 @@ else:
 # ------------------------------------------------------------------------------

 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = logging.getLogger(__name__)

@@ -30,7 +30,7 @@ from diffusers.utils.import_utils import is_xformers_available


 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
-check_min_version("0.21.0.dev0")
+check_min_version("0.21.0")

 logger = get_logger(__name__, log_level="INFO")

@@ -154,6 +154,7 @@ if __name__ == "__main__":
     pipe = download_from_original_stable_diffusion_ckpt(
         checkpoint_path_or_dict=args.checkpoint_path,
         original_config_file=args.original_config_file,
+        config_files=args.config_files,
         image_size=args.image_size,
         prediction_type=args.prediction_type,
         model_type=args.pipeline_type,
setup.py (2)
@@ -244,7 +244,7 @@ install_requires = [

 setup(
     name="diffusers",
-    version="0.21.0.dev0",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
+    version="0.21.2",  # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
     description="State-of-the-art diffusion in PyTorch and JAX.",
     long_description=open("README.md", "r", encoding="utf-8").read(),
     long_description_content_type="text/markdown",
@@ -1,4 +1,4 @@
-__version__ = "0.21.0.dev0"
+__version__ = "0.21.2"

 from typing import TYPE_CHECKING

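The two hunks above bump the package metadata from the development version to the 0.21.2 patch release. A quick, hedged way to confirm which build is active in a given environment (assuming `diffusers` is installed):

```python
# Minimal sanity check of the installed diffusers build; assumes diffusers is installed.
import diffusers

print(diffusers.__version__)  # expected to print "0.21.2" for this tag
```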
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 import re
-import warnings
 from collections import defaultdict
 from contextlib import nullcontext
 from io import BytesIO
@@ -41,7 +40,7 @@ from .utils.import_utils import BACKENDS_MAPPING


 if is_transformers_available():
-    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
+    from transformers import CLIPTextModel, CLIPTextModelWithProjection

 if is_accelerate_available():
     from accelerate import init_empty_weights
@@ -307,6 +306,9 @@ class UNet2DConditionLoadersMixin:
         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
         network_alphas = kwargs.pop("network_alphas", None)

+        _pipeline = kwargs.pop("_pipeline", None)
+
         is_network_alphas_none = network_alphas is None

         allow_pickle = False
@@ -460,6 +462,7 @@ class UNet2DConditionLoadersMixin:
                     load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                 else:
                     lora.load_state_dict(value_dict)

         elif is_custom_diffusion:
             attn_processors = {}
             custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ class UNet2DConditionLoadersMixin:
                         cross_attention_dim=cross_attention_dim,
                     )
                    attn_processors[key].load_state_dict(value_dict)

-            self.set_attn_processor(attn_processors)
         else:
             raise ValueError(
                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
             )

+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False
+        if _pipeline is not None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module):
+                    if hasattr(component, "_hf_hook"):
+                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+                        logger.info(
+                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                        )
+                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+        # only custom diffusion needs to set attn processors
+        if is_custom_diffusion:
+            self.set_attn_processor(attn_processors)

         # set lora layers
         for target_module, lora_layer in lora_layers_list:
             target_module.set_lora_layer(lora_layer)

         self.to(dtype=self.dtype, device=self.device)

+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />

     def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
         is_new_lora_format = all(
             key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
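The block above teaches `load_attn_procs` to detect Accelerate CPU-offload hooks on the parent pipeline, remove them before touching the UNet, and re-enable the matching offload mode afterwards. A hedged usage sketch of the behaviour this enables — the model ID and the LoRA repository/file names are placeholders, not taken from this diff:

```python
# Hedged usage sketch: loading LoRA weights into a CPU-offloaded pipeline.
# "runwayml/stable-diffusion-v1-5" and "my-lora" are placeholder identifiers.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # requires accelerate

# With this change, loading LoRA weights removes the accelerate hooks first,
# loads the layers, and then re-enables the same offload mode on its own.
pipe.load_lora_weights("my-lora", weight_name="pytorch_lora_weights.safetensors")
```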
@@ -622,12 +650,87 @@ class UNet2DConditionLoadersMixin:
            module._unfuse_lora()


+def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
+    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+    force_download = kwargs.pop("force_download", False)
+    resume_download = kwargs.pop("resume_download", False)
+    proxies = kwargs.pop("proxies", None)
+    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+    use_auth_token = kwargs.pop("use_auth_token", None)
+    revision = kwargs.pop("revision", None)
+    subfolder = kwargs.pop("subfolder", None)
+    weight_name = kwargs.pop("weight_name", None)
+    use_safetensors = kwargs.pop("use_safetensors", None)
+
+    allow_pickle = False
+    if use_safetensors is None:
+        use_safetensors = True
+        allow_pickle = True
+
+    user_agent = {
+        "file_type": "text_inversion",
+        "framework": "pytorch",
+    }
+    state_dicts = []
+    for pretrained_model_name_or_path in pretrained_model_name_or_paths:
+        if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
+            # 3.1. Load textual inversion file
+            model_file = None
+
+            # Let's first try to load .safetensors weights
+            if (use_safetensors and weight_name is None) or (
+                weight_name is not None and weight_name.endswith(".safetensors")
+            ):
+                try:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        use_auth_token=use_auth_token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                    )
+                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                except Exception as e:
+                    if not allow_pickle:
+                        raise e
+
+                    model_file = None
+
+            if model_file is None:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=weight_name or TEXT_INVERSION_NAME,
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    resume_download=resume_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    use_auth_token=use_auth_token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                )
+                state_dict = torch.load(model_file, map_location="cpu")
+        else:
+            state_dict = pretrained_model_name_or_path
+
+        state_dicts.append(state_dict)
+
+    return state_dicts
+
+
 class TextualInversionLoaderMixin:
     r"""
     Load textual inversion tokens and embeddings to the tokenizer and text encoder.
     """

-    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):
+    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):  # noqa: F821
         r"""
         Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
         be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
@@ -654,7 +757,7 @@ class TextualInversionLoaderMixin:

         return prompts

-    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):
+    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):  # noqa: F821
         r"""
         Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
         to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
@@ -684,12 +787,103 @@ class TextualInversionLoaderMixin:

         return prompt

+    def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
+        if tokenizer is None:
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        if text_encoder is None:
+            raise ValueError(
+                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
+                f" `{self.load_textual_inversion.__name__}`"
+            )
+
+        if len(pretrained_model_name_or_paths) != len(tokens):
+            raise ValueError(
+                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
+                f"Make sure both lists have the same length."
+            )
+
+        valid_tokens = [t for t in tokens if t is not None]
+        if len(set(valid_tokens)) < len(valid_tokens):
+            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
+
+    @staticmethod
+    def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
+        all_tokens = []
+        all_embeddings = []
+        for state_dict, token in zip(state_dicts, tokens):
+            if isinstance(state_dict, torch.Tensor):
+                if token is None:
+                    raise ValueError(
+                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
+                    )
+                loaded_token = token
+                embedding = state_dict
+            elif len(state_dict) == 1:
+                # diffusers
+                loaded_token, embedding = next(iter(state_dict.items()))
+            elif "string_to_param" in state_dict:
+                # A1111
+                loaded_token = state_dict["name"]
+                embedding = state_dict["string_to_param"]["*"]
+            else:
+                raise ValueError(
+                    f"Loaded state dictonary is incorrect: {state_dict}. \n\n"
+                    "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
+                    " input key."
+                )
+
+            if token is not None and loaded_token != token:
+                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
+            else:
+                token = loaded_token
+
+            if token in tokenizer.get_vocab():
+                raise ValueError(
+                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
+                )
+
+            all_tokens.append(token)
+            all_embeddings.append(embedding)
+
+        return all_tokens, all_embeddings
+
+    @staticmethod
+    def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
+        all_tokens = []
+        all_embeddings = []
+
+        for embedding, token in zip(embeddings, tokens):
+            if f"{token}_1" in tokenizer.get_vocab():
+                multi_vector_tokens = [token]
+                i = 1
+                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
+                    multi_vector_tokens.append(f"{token}_{i}")
+                    i += 1
+
+                raise ValueError(
+                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
+                )
+
+            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
+            if is_multi_vector:
+                all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
+                all_embeddings += [e for e in embedding]  # noqa: C416
+            else:
+                all_tokens += [token]
+                all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]
+
+        return all_tokens, all_embeddings
+
     def load_textual_inversion(
         self,
         pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
         token: Optional[Union[str, List[str]]] = None,
-        tokenizer: Optional[PreTrainedTokenizer] = None,
-        text_encoder: Optional[PreTrainedModel] = None,
+        tokenizer: Optional["PreTrainedTokenizer"] = None,  # noqa: F821
+        text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
         **kwargs,
     ):
         r"""
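The new `_retrieve_tokens_and_embeddings` helper above accepts three layouts per entry: a raw tensor (which requires an explicit `token`), a single-key diffusers-style dict, and an A1111-style dict with `name` and `string_to_param` keys. A hedged sketch of what those inputs look like — the token name and shapes are illustrative only:

```python
# Hedged illustration of the state-dict layouts handled above; all values are dummies.
import torch

embedding = torch.randn(2, 768)  # a 2-vector embedding with hidden size 768

raw_tensor = embedding                                   # needs token="<my-token>" when loading
diffusers_style = {"<my-token>": embedding}              # single key: token -> embedding
a1111_style = {"name": "<my-token>", "string_to_param": {"*": embedding}}
```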
@@ -789,25 +983,44 @@ class TextualInversionLoaderMixin:
        ```

        """
+        # 1. Set correct tokenizer and text encoder
         tokenizer = tokenizer or getattr(self, "tokenizer", None)
         text_encoder = text_encoder or getattr(self, "text_encoder", None)

-        if tokenizer is None:
+        # 2. Normalize inputs
+        pretrained_model_name_or_paths = (
+            [pretrained_model_name_or_path]
+            if not isinstance(pretrained_model_name_or_path, list)
+            else pretrained_model_name_or_path
+        )
+        tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token
+
+        # 3. Check inputs
+        self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)
+
+        # 4. Load state dicts of textual embeddings
+        state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
+
+        # 4. Retrieve tokens and embeddings
+        tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)
+
+        # 5. Extend tokens and embeddings for multi vector
+        tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)
+
+        # 6. Make sure all embeddings have the correct size
+        expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
+        if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
             raise ValueError(
-                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
+                "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
+                "to be of shape {input_embeddings.shape[-1]}, but are {embeddings.shape[-1]} "
             )

-        if text_encoder is None:
-            raise ValueError(
-                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
-                f" `{self.load_textual_inversion.__name__}`"
-            )
+        # 7. Now we can be sure that loading the embedding matrix works
+        # < Unsafe code:

-        # Remove any existing hooks.
+        # 7.1 Offload all hooks in case the pipeline was cpu offloaded before make sure, we offload and onload again
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
-        recursive = False
         for _, component in self.components.items():
             if isinstance(component, nn.Module):
                 if hasattr(component, "_hf_hook"):
@@ -816,168 +1029,34 @@ class TextualInversionLoaderMixin:
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                    )
-                    recursive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recursive)
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
-        force_download = kwargs.pop("force_download", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
-        use_auth_token = kwargs.pop("use_auth_token", None)
-        revision = kwargs.pop("revision", None)
-        subfolder = kwargs.pop("subfolder", None)
-        weight_name = kwargs.pop("weight_name", None)
-        use_safetensors = kwargs.pop("use_safetensors", None)
+        # 7.2 save expected device and dtype
+        device = text_encoder.device
+        dtype = text_encoder.dtype

-        allow_pickle = False
-        if use_safetensors is None:
-            use_safetensors = True
-            allow_pickle = True
+        # 7.3 Increase token embedding matrix
+        text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
+        input_embeddings = text_encoder.get_input_embeddings().weight

-        user_agent = {
-            "file_type": "text_inversion",
-            "framework": "pytorch",
-        }
-
-        if not isinstance(pretrained_model_name_or_path, list):
-            pretrained_model_name_or_paths = [pretrained_model_name_or_path]
-        else:
-            pretrained_model_name_or_paths = pretrained_model_name_or_path
-
-        if isinstance(token, str):
-            tokens = [token]
-        elif token is None:
-            tokens = [None] * len(pretrained_model_name_or_paths)
-        else:
-            tokens = token
-
-        if len(pretrained_model_name_or_paths) != len(tokens):
-            raise ValueError(
-                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}"
-                f"Make sure both lists have the same length."
-            )
-
-        valid_tokens = [t for t in tokens if t is not None]
-        if len(set(valid_tokens)) < len(valid_tokens):
-            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")
-
-        token_ids_and_embeddings = []
-
-        for pretrained_model_name_or_path, token in zip(pretrained_model_name_or_paths, tokens):
-            if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
-                # 1. Load textual inversion file
-                model_file = None
-                # Let's first try to load .safetensors weights
-                if (use_safetensors and weight_name is None) or (
-                    weight_name is not None and weight_name.endswith(".safetensors")
-                ):
-                    try:
-                        model_file = _get_model_file(
-                            pretrained_model_name_or_path,
-                            weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
-                            cache_dir=cache_dir,
-                            force_download=force_download,
-                            resume_download=resume_download,
-                            proxies=proxies,
-                            local_files_only=local_files_only,
-                            use_auth_token=use_auth_token,
-                            revision=revision,
-                            subfolder=subfolder,
-                            user_agent=user_agent,
-                        )
-                        state_dict = safetensors.torch.load_file(model_file, device="cpu")
-                    except Exception as e:
-                        if not allow_pickle:
-                            raise e
-
-                        model_file = None
-
-                if model_file is None:
-                    model_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=weight_name or TEXT_INVERSION_NAME,
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        use_auth_token=use_auth_token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                    )
-                    state_dict = torch.load(model_file, map_location="cpu")
-            else:
-                state_dict = pretrained_model_name_or_path
-
-            # 2. Load token and embedding correcly from file
-            loaded_token = None
-            if isinstance(state_dict, torch.Tensor):
-                if token is None:
-                    raise ValueError(
-                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
-                    )
-                embedding = state_dict
-            elif len(state_dict) == 1:
-                # diffusers
-                loaded_token, embedding = next(iter(state_dict.items()))
-            elif "string_to_param" in state_dict:
-                # A1111
-                loaded_token = state_dict["name"]
-                embedding = state_dict["string_to_param"]["*"]
-
-            if token is not None and loaded_token != token:
-                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
-            else:
-                token = loaded_token
-
-            embedding = embedding.to(dtype=text_encoder.dtype, device=text_encoder.device)
-
-            # 3. Make sure we don't mess up the tokenizer or text encoder
-            vocab = tokenizer.get_vocab()
-            if token in vocab:
-                raise ValueError(
-                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
-                )
-            elif f"{token}_1" in vocab:
-                multi_vector_tokens = [token]
-                i = 1
-                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
-                    multi_vector_tokens.append(f"{token}_{i}")
-                    i += 1
-
-                raise ValueError(
-                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
-                )
-
-            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
-
-            if is_multi_vector:
-                tokens = [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
-                embeddings = [e for e in embedding]  # noqa: C416
-            else:
-                tokens = [token]
-                embeddings = [embedding[0]] if len(embedding.shape) > 1 else [embedding]
-
+        # 7.4 Load token and embedding
+        for token, embedding in zip(tokens, embeddings):
             # add tokens and get ids
-            tokenizer.add_tokens(tokens)
-            token_ids = tokenizer.convert_tokens_to_ids(tokens)
-            token_ids_and_embeddings += zip(token_ids, embeddings)
+            tokenizer.add_tokens(token)
+            token_id = tokenizer.convert_tokens_to_ids(token)
+            input_embeddings.data[token_id] = embedding

             logger.info(f"Loaded textual inversion embedding for {token}.")

-        # resize token embeddings and set all new embeddings
-        text_encoder.resize_token_embeddings(len(tokenizer))
-        for token_id, embedding in token_ids_and_embeddings:
-            text_encoder.get_input_embeddings().weight.data[token_id] = embedding
+        input_embeddings.to(dtype=dtype, device=device)

-        # offload back
+        # 7.5 Offload the model again
         if is_model_cpu_offload:
             self.enable_model_cpu_offload()
         elif is_sequential_cpu_offload:
             self.enable_sequential_cpu_offload()

+        # / Unsafe Code >
+

 class LoraLoaderMixin:
     r"""
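After this refactor, `load_textual_inversion` accepts either a single source or a list of sources (with an optional matching list of tokens), extends multi-vector embeddings itself, and handles offloaded pipelines. A hedged usage sketch — the repository names and tokens are placeholders, not part of this diff:

```python
# Hedged usage sketch of the refactored loader; identifiers are placeholders.
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Single embedding; the token is read from the file if not given.
pipe.load_textual_inversion("sd-concepts-library/cat-toy")

# Several embeddings at once; the token list must match the source list in length.
pipe.load_textual_inversion(
    ["embedding-a", "embedding-b"], token=["<token-a>", "<token-b>"]
)
```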
@@ -1009,26 +1088,21 @@ class LoraLoaderMixin:
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
-        # Remove any existing hooks.
-        is_model_cpu_offload = False
-        is_sequential_cpu_offload = False
-        recurive = False
-        for _, component in self.components.items():
-            if isinstance(component, nn.Module):
-                if hasattr(component, "_hf_hook"):
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                    logger.info(
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    recurive = is_sequential_cpu_offload
-                    remove_hook_from_module(component, recurse=recurive)
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")

         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

-        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
         self.load_lora_into_unet(
-            state_dict, network_alphas=network_alphas, unet=self.unet, low_cpu_mem_usage=low_cpu_mem_usage
+            state_dict,
+            network_alphas=network_alphas,
+            unet=self.unet,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
         )
         self.load_lora_into_text_encoder(
             state_dict,
@@ -1036,14 +1110,9 @@ class LoraLoaderMixin:
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
            low_cpu_mem_usage=low_cpu_mem_usage,
+            _pipeline=self,
        )

-        # Offload back.
-        if is_model_cpu_offload:
-            self.enable_model_cpu_offload()
-        elif is_sequential_cpu_offload:
-            self.enable_sequential_cpu_offload()
-
    @classmethod
    def lora_state_dict(
        cls,
@@ -1340,7 +1409,7 @@ class LoraLoaderMixin:
        return new_state_dict

    @classmethod
-    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None):
+    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -1382,13 +1451,22 @@ class LoraLoaderMixin:
            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
            # contain the module names of the `unet` as its keys WITHOUT any prefix.
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
-            warnings.warn(warn_message)
+            logger.warn(warn_message)

-        unet.load_attn_procs(state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage)
+        unet.load_attn_procs(
+            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
+        )

    @classmethod
    def load_lora_into_text_encoder(
-        cls, state_dict, network_alphas, text_encoder, prefix=None, lora_scale=1.0, low_cpu_mem_usage=None
+        cls,
+        state_dict,
+        network_alphas,
+        text_encoder,
+        prefix=None,
+        lora_scale=1.0,
+        low_cpu_mem_usage=None,
+        _pipeline=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1498,11 +1576,15 @@ class LoraLoaderMixin:
                low_cpu_mem_usage=low_cpu_mem_usage,
            )

-            # set correct dtype & device
-            text_encoder_lora_state_dict = {
-                k: v.to(device=text_encoder.device, dtype=text_encoder.dtype)
-                for k, v in text_encoder_lora_state_dict.items()
-            }
+            is_pipeline_offloaded = _pipeline is not None and any(
+                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
+            )
+            if is_pipeline_offloaded and low_cpu_mem_usage:
+                low_cpu_mem_usage = True
+                logger.info(
+                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
+                )

            if low_cpu_mem_usage:
                device = next(iter(text_encoder_lora_state_dict.values())).device
                dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1518,8 +1600,33 @@ class LoraLoaderMixin:
                    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                )

+            # <Unsafe code
+            # We can be sure that the following works as all we do is change the dtype and device of the text encoder
+            # Now we remove any existing hooks to
+            is_model_cpu_offload = False
+            is_sequential_cpu_offload = False
+            if _pipeline is not None:
+                for _, component in _pipeline.components.items():
+                    if isinstance(component, torch.nn.Module):
+                        if hasattr(component, "_hf_hook"):
+                            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                            is_sequential_cpu_offload = isinstance(
+                                getattr(component, "_hf_hook"), AlignDevicesHook
+                            )
+                            logger.info(
+                                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                            )
+                            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

+            # Offload back.
+            if is_model_cpu_offload:
+                _pipeline.enable_model_cpu_offload()
+            elif is_sequential_cpu_offload:
+                _pipeline.enable_sequential_cpu_offload()
+            # Unsafe code />
+
    @property
    def lora_scale(self) -> float:
        # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2098,6 +2205,7 @@ class FromSingleFileMixin:
        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

        original_config_file = kwargs.pop("original_config_file", None)
+        config_files = kwargs.pop("config_files", None)
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
@@ -2215,6 +2323,7 @@ class FromSingleFileMixin:
            vae=vae,
            tokenizer=tokenizer,
            original_config_file=original_config_file,
+            config_files=config_files,
        )

        if torch_dtype is not None:
@@ -2556,3 +2665,131 @@ class FromOriginalControlnetMixin:
            controlnet.to(torch_dtype=torch_dtype)

        return controlnet
+
+
+class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
+    """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""
+
+    # Overrride to properly handle the loading and unloading of the additional text encoder.
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
+        `self.text_encoder`.
+
+        All kwargs are forwarded to `self.lora_state_dict`.
+
+        See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+
+        See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
+        `self.unet`.
+
+        See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
+        into `self.text_encoder`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
+        """
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+                _pipeline=self,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+                _pipeline=self,
+            )
+
+    @classmethod
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the UNet and text encoder.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `unet`.
+            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
+                encoder LoRA state dict because it comes from 🤗 Transformers.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+        """
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
+            raise ValueError(
+                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
+            )
+
+        if unet_lora_layers:
+            state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
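The new `StableDiffusionXLLoraLoaderMixin` above routes LoRA weights into the UNet and into both SDXL text encoders, and `save_lora_weights` packs the three groups under `unet.`, `text_encoder.`, and `text_encoder_2.` prefixes. A hedged usage sketch — the model and weight file names are placeholders, not taken from this diff:

```python
# Hedged usage sketch for SDXL LoRA loading; identifiers are placeholders.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# load_lora_weights() splits the state dict by prefix and loads the
# text_encoder / text_encoder_2 parts into the matching encoders.
pipe.load_lora_weights("some-sdxl-lora", weight_name="pytorch_lora_weights.safetensors")
```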
@@ -90,6 +90,8 @@ class MultiAdapter(ModelMixin):
            features = adapter(x)
            if accume_state is None:
                accume_state = features
+                for i in range(len(accume_state)):
+                    accume_state[i] = w * accume_state[i]
            else:
                for i in range(len(features)):
                    accume_state[i] += w * features[i]
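The fix above makes the first adapter's features respect its weight `w` too, so the accumulated result is a true weighted sum. A hedged numeric sketch of the corrected accumulation, with plain tensors standing in for adapter feature lists:

```python
# Hedged sketch of the corrected weighted accumulation; tensors are dummies.
import torch

features_per_adapter = [[torch.ones(2, 2)], [torch.ones(2, 2) * 3.0]]
weights = [0.25, 0.75]

accume_state = None
for features, w in zip(features_per_adapter, weights):
    if accume_state is None:
        accume_state = list(features)
        for i in range(len(accume_state)):
            # previously the first adapter's features were accumulated unweighted
            accume_state[i] = w * accume_state[i]
    else:
        for i in range(len(features)):
            accume_state[i] += w * features[i]

print(accume_state[0])  # 0.25 * 1.0 + 0.75 * 3.0 = 2.5 everywhere
```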
@@ -1255,7 +1255,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        self._all_hooks = []
        hook = None
        for model_str in self.model_cpu_offload_seq.split("->"):
-            model = all_model_components.pop(model_str)
+            model = all_model_components.pop(model_str, None)
            if not isinstance(model, torch.nn.Module):
                continue

@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
         key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
         key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
         key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+        config_url = None

         # model_type = "v1"
-        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+        if config_files is not None and "v1" in config_files:
+            original_config_file = config_files["v1"]
+        else:
+            config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

         if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
             # model_type = "v2"
-            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
+            if config_files is not None and "v2" in config_files:
+                original_config_file = config_files["v2"]
+            else:
+                config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
             if global_step == 110000:
                 # v2.1 needs to upcast attention
                 upcast_attention = True
         elif key_name_sd_xl_base in checkpoint:
             # only base xl has two text embedders
-            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+            if config_files is not None and "xl" in config_files:
+                original_config_file = config_files["xl"]
+            else:
+                config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
         elif key_name_sd_xl_refiner in checkpoint:
             # only refiner xl has embedder and one text embedders
-            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+            if config_files is not None and "xl_refiner" in config_files:
+                original_config_file = config_files["xl_refiner"]
+            else:
+                config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"

-        original_config_file = BytesIO(requests.get(config_url).content)
+        if config_url is not None:
+            original_config_file = BytesIO(requests.get(config_url).content)

     original_config = OmegaConf.load(original_config_file)

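Note: the new `config_files` argument lets the single-file conversion use local YAML configs (keyed by "v1", "v2", "xl", or "xl_refiner") instead of fetching them from GitHub; `config_url` is only downloaded when no matching local file was supplied. A hedged sketch follows; the checkpoint and config paths are placeholders.

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Placeholder paths; the "v1" entry is used because the checkpoint is detected
# as SD 1.x, so no request to raw.githubusercontent.com is made.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./v1-5-pruned-emaonly.safetensors",
    from_safetensors=True,
    config_files={"v1": "./configs/v1-inference.yaml"},
)
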
@@ -50,13 +50,26 @@ class SafetyConfig(object):

 _dummy_objects = {}
 _additional_imports = {}
-_import_structure = {
-    "pipeline_output": ["StableDiffusionSafePipelineOutput"],
-    "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
-    "safety_checker": ["StableDiffusionSafetyChecker"],
-}
+_import_structure = {}
 _additional_imports.update({"SafetyConfig": SafetyConfig})

+try:
+    if not (is_transformers_available() and is_torch_available()):
+        raise OptionalDependencyNotAvailable()
+except OptionalDependencyNotAvailable:
+    from ...utils import dummy_torch_and_transformers_objects
+
+    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
+else:
+    _import_structure.update(
+        {
+            "pipeline_output": ["StableDiffusionSafePipelineOutput"],
+            "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
+            "safety_checker": ["StableDiffusionSafetyChecker"],
+        }
+    )
+
+
 if TYPE_CHECKING:
     try:

@@ -70,25 +83,16 @@ if TYPE_CHECKING:
         from .safety_checker import SafeStableDiffusionSafetyChecker

 else:
-    try:
-        if not (is_transformers_available() and is_torch_available()):
-            raise OptionalDependencyNotAvailable()
-    except OptionalDependencyNotAvailable:
-        from ...utils import dummy_torch_and_transformers_objects
-
-        _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
-    else:
-        import sys
-
-        sys.modules[__name__] = _LazyModule(
-            __name__,
-            globals()["__file__"],
-            _import_structure,
-            module_spec=__spec__,
-        )
-
-        for name, value in _dummy_objects.items():
-            setattr(sys.modules[__name__], name, value)
-        for name, value in _additional_imports.items():
-            setattr(sys.modules[__name__], name, value)
-
+    import sys
+
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        _import_structure,
+        module_spec=__spec__,
+    )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)
+    for name, value in _additional_imports.items():
+        setattr(sys.modules[__name__], name, value)

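Note: the two hunks above (and the similar `__init__.py` hunks further down) align this package with the library's lazy-import layout — the optional-dependency check now only decides what goes into `_import_structure` or `_dummy_objects`, while the module-level `else:` branch always installs the `_LazyModule` and then attaches the dummy and additional objects. Below is a stripped-down sketch of the general idea using PEP 562 module `__getattr__`, not diffusers' actual `_LazyModule` class:

# simplified lazy_pkg/__init__.py sketch: submodules are imported only when
# one of their attributes is first accessed.
import importlib

_import_structure = {
    "pipeline_stable_diffusion_safe": ["StableDiffusionPipelineSafe"],
    "safety_checker": ["StableDiffusionSafetyChecker"],
}
_attr_to_module = {attr: mod for mod, attrs in _import_structure.items() for attr in attrs}


def __getattr__(name):  # called only for attributes not found the normal way
    if name in _attr_to_module:
        module = importlib.import_module(f".{_attr_to_module[name]}", __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
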
@@ -162,7 +162,6 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
         scheduler: KarrasDiffusionSchedulers,
         safety_checker: StableDiffusionSafetyChecker,
         feature_extractor: CLIPFeatureExtractor,
-        adapter_weights: Optional[List[float]] = None,
         requires_safety_checker: bool = True,
     ):
         super().__init__()

@@ -184,7 +183,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
         )

         if isinstance(adapter, (list, tuple)):
-            adapter = MultiAdapter(adapter, adapter_weights=adapter_weights)
+            adapter = MultiAdapter(adapter)

         self.register_modules(
             vae=vae,

@@ -727,9 +726,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

         # 7. Denoising loop
-        adapter_state = self.adapter(adapter_input)
-        for k, v in enumerate(adapter_state):
-            adapter_state[k] = v * adapter_conditioning_scale
+        if isinstance(self.adapter, MultiAdapter):
+            adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
+            for k, v in enumerate(adapter_state):
+                adapter_state[k] = v
+        else:
+            adapter_state = self.adapter(adapter_input)
+            for k, v in enumerate(adapter_state):
+                adapter_state[k] = v * adapter_conditioning_scale
         if num_images_per_prompt > 1:
             for k, v in enumerate(adapter_state):
                 adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)

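Note: with the three hunks above, per-adapter weighting moves from the constructor (the removed `adapter_weights` argument) to the call site — when `adapter` is a `MultiAdapter`, `adapter_conditioning_scale` may be a list with one entry per adapter and is forwarded to `MultiAdapter`'s forward pass. A hedged usage sketch; the checkpoint IDs are illustrative and the conditioning images are assumed to be prepared elsewhere.

from diffusers import MultiAdapter, StableDiffusionAdapterPipeline, T2IAdapter

# Illustrative checkpoints; any T2I-Adapters trained for the same base model work.
adapter = MultiAdapter(
    [
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_keypose_sd14v1"),
        T2IAdapter.from_pretrained("TencentARC/t2iadapter_depth_sd14v1"),
    ]
)
pipe = StableDiffusionAdapterPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", adapter=adapter
)

# keypose_image / depth_image are placeholder PIL images; one scale per adapter
# replaces the removed adapter_weights constructor argument.
result = pipe(
    "a person riding a bike",
    image=[keypose_image, depth_image],
    adapter_conditioning_scale=[0.8, 0.8],
)
image = result.images[0]
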
@@ -47,3 +47,5 @@ else:
         _import_structure,
         module_spec=__spec__,
     )
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

@@ -51,3 +51,6 @@ else:
         _import_structure,
         module_spec=__spec__,
     )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

@@ -41,7 +41,6 @@ if TYPE_CHECKING:
         from .pipeline_wuerstchen import WuerstchenDecoderPipeline
         from .pipeline_wuerstchen_combined import WuerstchenCombinedPipeline
         from .pipeline_wuerstchen_prior import WuerstchenPriorPipeline

 else:
     import sys

@@ -51,3 +50,6 @@ else:
         _import_structure,
         module_spec=__spec__,
     )
+
+    for name, value in _dummy_objects.items():
+        setattr(sys.modules[__name__], name, value)

@@ -216,7 +216,9 @@ class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterM
         return super().get_dummy_components("multi_adapter")

     def get_dummy_inputs(self, device, seed=0):
-        return super().get_dummy_inputs(device, seed, num_images=2)
+        inputs = super().get_dummy_inputs(device, seed, num_images=2)
+        inputs["adapter_conditioning_scale"] = [0.5, 0.5]
+        return inputs

     def test_stable_diffusion_adapter_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator