@@ -13,7 +13,6 @@
# limitations under the License.
import os
import re
import warnings
from collections import defaultdict
from contextlib import nullcontext
from io import BytesIO
@@ -307,6 +306,9 @@ class UNet2DConditionLoadersMixin:
        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
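        # (Added note, a hedged sketch of the convention: `network_alpha` rescales each
        # LoRA update roughly as `weight += (network_alpha / rank) * up @ down`, so smaller
        # alpha values damp the learned delta. This describes the kohya-ss convention, not
        # the exact code path below.)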
        network_alphas = kwargs.pop("network_alphas", None)

        _pipeline = kwargs.pop("_pipeline", None)

        is_network_alphas_none = network_alphas is None

        allow_pickle = False
@@ -460,6 +462,7 @@ class UNet2DConditionLoadersMixin:
                    load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
                else:
                    lora.load_state_dict(value_dict)

        elif is_custom_diffusion:
            attn_processors = {}
            custom_diffusion_grouped_dict = defaultdict(dict)
@@ -489,19 +492,44 @@ class UNet2DConditionLoadersMixin:
                        cross_attention_dim=cross_attention_dim,
                    )
                    attn_processors[key].load_state_dict(value_dict)
        else:
            raise ValueError(
                f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
            )

        # <Unsafe code
        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
        # Now we remove any existing hooks so that they can be re-applied after loading.
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        if _pipeline is not None:
            for _, component in _pipeline.components.items():
                if isinstance(component, nn.Module):
                    if hasattr(component, "_hf_hook"):
                        is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                        is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
                        logger.info(
                            "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                        )
                        remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        # only custom diffusion needs to set attn processors
        if is_custom_diffusion:
            self.set_attn_processor(attn_processors)

        # set lora layers
        for target_module, lora_layer in lora_layers_list:
            target_module.set_lora_layer(lora_layer)

        self.to(dtype=self.dtype, device=self.device)

        # Offload back.
        if is_model_cpu_offload:
            _pipeline.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            _pipeline.enable_sequential_cpu_offload()
        # Unsafe code />
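
        # (Added usage note, a hedged sketch: threading `_pipeline` through lets the loader
        # detect accelerate offloading hooks, strip them, load, and re-enable the same
        # offloading mode afterwards. Roughly:
        #
        #     pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
        #     pipe.enable_model_cpu_offload()
        #     pipe.load_lora_weights("some-user/some-lora")  # hypothetical repo id
        #
        # Without the hook handling above, loading into an offloaded pipeline could leave
        # parameters on the wrong device.)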

    def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
        is_new_lora_format = all(
            key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
@@ -622,6 +650,81 @@ class UNet2DConditionLoadersMixin:
            module._unfuse_lora()


def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
    cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
    force_download = kwargs.pop("force_download", False)
    resume_download = kwargs.pop("resume_download", False)
    proxies = kwargs.pop("proxies", None)
    local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
    use_auth_token = kwargs.pop("use_auth_token", None)
    revision = kwargs.pop("revision", None)
    subfolder = kwargs.pop("subfolder", None)
    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    allow_pickle = False
    if use_safetensors is None:
        use_safetensors = True
        allow_pickle = True

    user_agent = {
        "file_type": "text_inversion",
        "framework": "pytorch",
    }
    state_dicts = []
    for pretrained_model_name_or_path in pretrained_model_name_or_paths:
        if not isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
            # 1. Load textual inversion file
            model_file = None

            # Let's first try to load .safetensors weights
            if (use_safetensors and weight_name is None) or (
                weight_name is not None and weight_name.endswith(".safetensors")
            ):
                try:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        use_auth_token=use_auth_token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                    )
                    state_dict = safetensors.torch.load_file(model_file, device="cpu")
                except Exception as e:
                    if not allow_pickle:
                        raise e

                    model_file = None

            if model_file is None:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=weight_name or TEXT_INVERSION_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                )
                state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path

        state_dicts.append(state_dict)

    return state_dicts
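
# (Added note, a hedged usage sketch of the helper above; the file name is hypothetical:
#
#     state_dicts = load_textual_inversion_state_dicts(["./learned_embeds.safetensors"])
#
# Each returned entry is either a raw tensor, a one-key diffusers dict, or an A1111-style
# dict with "name" and "string_to_param" keys.)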


class TextualInversionLoaderMixin:
    r"""
    Load textual inversion tokens and embeddings to the tokenizer and text encoder.
@@ -684,6 +787,97 @@ class TextualInversionLoaderMixin:
        return prompt

    def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
        if tokenizer is None:
            raise ValueError(
                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if text_encoder is None:
            raise ValueError(
                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if len(pretrained_model_name_or_paths) != len(tokens):
            raise ValueError(
                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)}. "
                f"Make sure both lists have the same length."
            )

        valid_tokens = [t for t in tokens if t is not None]
        if len(set(valid_tokens)) < len(valid_tokens):
            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")

    @staticmethod
    def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
        all_tokens = []
        all_embeddings = []
        for state_dict, token in zip(state_dicts, tokens):
            if isinstance(state_dict, torch.Tensor):
                if token is None:
                    raise ValueError(
                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                    )
                loaded_token = token
                embedding = state_dict
            elif len(state_dict) == 1:
                # diffusers
                loaded_token, embedding = next(iter(state_dict.items()))
            elif "string_to_param" in state_dict:
                # A1111
                loaded_token = state_dict["name"]
                embedding = state_dict["string_to_param"]["*"]
            else:
                raise ValueError(
                    f"Loaded state dictionary is incorrect: {state_dict}. \n\n"
                    "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
                    " input key."
                )

            if token is not None and loaded_token != token:
                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
            else:
                token = loaded_token

            if token in tokenizer.get_vocab():
                raise ValueError(
                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                )

            all_tokens.append(token)
            all_embeddings.append(embedding)

        return all_tokens, all_embeddings
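
    # (Added note: the two dict formats accepted above, as a sketch:
    #   diffusers:  {"<my-token>": tensor(dim)}                                  # a single key
    #   A1111:      {"name": "<my-token>", "string_to_param": {"*": tensor(n, dim)}}
    # "<my-token>" is a hypothetical placeholder name.)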

    @staticmethod
    def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
        all_tokens = []
        all_embeddings = []

        for embedding, token in zip(embeddings, tokens):
            if f"{token}_1" in tokenizer.get_vocab():
                multi_vector_tokens = [token]
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    multi_vector_tokens.append(f"{token}_{i}")
                    i += 1

                raise ValueError(
                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
                )

            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
            if is_multi_vector:
                all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
                all_embeddings += [e for e in embedding]  # noqa: C416
            else:
                all_tokens += [token]
                all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]

        return all_tokens, all_embeddings
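
    # (Added note, a sketch of the expansion above: a token "<concept>" with an embedding of
    # shape (3, 768) becomes tokens ["<concept>", "<concept>_1", "<concept>_2"], one vector
    # each; "<concept>" is a hypothetical name.)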

    def load_textual_inversion(
        self,
        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
@@ -789,25 +983,44 @@ class TextualInversionLoaderMixin:
        ```

        """
        # 1. Set correct tokenizer and text encoder
        tokenizer = tokenizer or getattr(self, "tokenizer", None)
        text_encoder = text_encoder or getattr(self, "text_encoder", None)

        # 2. Normalize inputs
        pretrained_model_name_or_paths = (
            [pretrained_model_name_or_path]
            if not isinstance(pretrained_model_name_or_path, list)
            else pretrained_model_name_or_path
        )
        tokens = len(pretrained_model_name_or_paths) * [token] if (isinstance(token, str) or token is None) else token

        # 3. Check inputs
        self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)

        # 4. Load state dicts of textual embeddings
        state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)

        # 5. Retrieve tokens and embeddings
        tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)

        # 6. Extend tokens and embeddings for multi vector
        tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)

        # 7. Make sure all embeddings have the correct size
        expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
        if any(expected_emb_dim != emb.shape[-1] for emb in embeddings):
            raise ValueError(
                "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
                f"to be of shape ({expected_emb_dim},), but found embeddings of shapes {[tuple(e.shape) for e in embeddings]}."
            )

        # 8. Now we can be sure that loading the embedding matrix works
        # < Unsafe code:

        # 8.1 Offload all hooks in case the pipeline was cpu offloaded before, to make sure we offload and onload again
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        for _, component in self.components.items():
            if isinstance(component, nn.Module):
                if hasattr(component, "_hf_hook"):
@@ -816,168 +1029,34 @@ class TextualInversionLoaderMixin:
                    logger.info(
                        "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                    )
                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        # 8.2 save expected device and dtype
        device = text_encoder.device
        dtype = text_encoder.dtype

        # 8.3 Increase token embedding matrix
        text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
        input_embeddings = text_encoder.get_input_embeddings().weight

        # 8.4 Load token and embedding
        for token, embedding in zip(tokens, embeddings):
            # add tokens and get ids
            tokenizer.add_tokens(token)
            token_id = tokenizer.convert_tokens_to_ids(token)
            input_embeddings.data[token_id] = embedding
            logger.info(f"Loaded textual inversion embedding for {token}.")

        input_embeddings.to(dtype=dtype, device=device)

        # 8.5 Offload the model again
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

        # / Unsafe Code >
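
        # (Added usage note, a hedged sketch based on the diffusers docs; the concept repo
        # and token are examples, not part of this diff:
        #
        #     pipe.load_textual_inversion("sd-concepts-library/cat-toy")
        #     image = pipe("A <cat-toy> backpack").images[0]
        #
        # After loading, the new token embeds like any other vocabulary entry.)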


class LoraLoaderMixin:
    r"""
@@ -1009,26 +1088,21 @@ class LoraLoaderMixin:
        kwargs (`dict`, *optional*):
            See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)

        self.load_lora_into_unet(
            state_dict,
            network_alphas=network_alphas,
            unet=self.unet,
            low_cpu_mem_usage=low_cpu_mem_usage,
            _pipeline=self,
        )
        self.load_lora_into_text_encoder(
            state_dict,
@@ -1036,14 +1110,9 @@ class LoraLoaderMixin:
            text_encoder=self.text_encoder,
            lora_scale=self.lora_scale,
            low_cpu_mem_usage=low_cpu_mem_usage,
            _pipeline=self,
        )
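
        # (Added usage note, a hedged sketch; the LoRA repo id is hypothetical:
        #
        #     pipe.load_lora_weights("some-user/some-lora")
        #
        # Hook removal and re-offloading now happen inside `load_lora_into_unet` /
        # `load_lora_into_text_encoder`, which receive the pipeline via `_pipeline=self`.)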

    @classmethod
    def lora_state_dict(
        cls,
@@ -1340,7 +1409,7 @@ class LoraLoaderMixin:
        return new_state_dict

    @classmethod
    def load_lora_into_unet(cls, state_dict, network_alphas, unet, low_cpu_mem_usage=None, _pipeline=None):
        """
        This will load the LoRA layers specified in `state_dict` into `unet`.

@@ -1382,13 +1451,22 @@ class LoraLoaderMixin:
            # Otherwise, we're dealing with the old format. This means the `state_dict` should only
            # contain the module names of the `unet` as its keys WITHOUT any prefix.
            warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
            logger.warn(warn_message)

        unet.load_attn_procs(
            state_dict, network_alphas=network_alphas, low_cpu_mem_usage=low_cpu_mem_usage, _pipeline=_pipeline
        )
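
        # (Added note, a sketch of the two key layouts: new-format keys carry a component
        # prefix such as "unet." or "text_encoder.", e.g. "unet.<module>.lora.down.weight",
        # while the legacy format stores the same entries without any prefix; "<module>" is
        # a hypothetical placeholder.)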

    @classmethod
    def load_lora_into_text_encoder(
        cls,
        state_dict,
        network_alphas,
        text_encoder,
        prefix=None,
        lora_scale=1.0,
        low_cpu_mem_usage=None,
        _pipeline=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`
@@ -1498,11 +1576,15 @@ class LoraLoaderMixin:
                    low_cpu_mem_usage=low_cpu_mem_usage,
                )

            is_pipeline_offloaded = _pipeline is not None and any(
                isinstance(c, torch.nn.Module) and hasattr(c, "_hf_hook") for c in _pipeline.components.values()
            )
            if is_pipeline_offloaded and not low_cpu_mem_usage:
                low_cpu_mem_usage = True
                logger.info(
                    f"Pipeline {_pipeline.__class__} is offloaded. Therefore low cpu mem usage loading is forced."
                )

            if low_cpu_mem_usage:
                device = next(iter(text_encoder_lora_state_dict.values())).device
                dtype = next(iter(text_encoder_lora_state_dict.values())).dtype
@@ -1518,8 +1600,33 @@ class LoraLoaderMixin:
                    f"failed to load text encoder state dict, unexpected keys: {load_state_dict_results.unexpected_keys}"
                )

            # <Unsafe code
            # We can be sure that the following works as all we do is change the dtype and device of the text encoder
            # Now we remove any existing hooks so that they can be re-applied after loading.
            is_model_cpu_offload = False
            is_sequential_cpu_offload = False
            if _pipeline is not None:
                for _, component in _pipeline.components.items():
                    if isinstance(component, torch.nn.Module):
                        if hasattr(component, "_hf_hook"):
                            is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
                            is_sequential_cpu_offload = isinstance(
                                getattr(component, "_hf_hook"), AlignDevicesHook
                            )
                            logger.info(
                                "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
                            )
                            remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

            text_encoder.to(device=text_encoder.device, dtype=text_encoder.dtype)

            # Offload back.
            if is_model_cpu_offload:
                _pipeline.enable_model_cpu_offload()
            elif is_sequential_cpu_offload:
                _pipeline.enable_sequential_cpu_offload()
            # Unsafe code />

    @property
    def lora_scale(self) -> float:
        # property function that returns the lora scale which can be set at run time by the pipeline.
@@ -2558,3 +2665,131 @@ class FromOriginalControlnetMixin:
        controlnet.to(torch_dtype=torch_dtype)

        return controlnet


class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
    """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL"""

    # Override to properly handle the loading and unloading of the additional text encoder.
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.unet` and
        `self.text_encoder`.

        All kwargs are forwarded to `self.lora_state_dict`.

        See [`~loaders.LoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.

        See [`~loaders.LoraLoaderMixin.load_lora_into_unet`] for more details on how the state dict is loaded into
        `self.unet`.

        See [`~loaders.LoraLoaderMixin.load_lora_into_text_encoder`] for more details on how the state dict is loaded
        into `self.text_encoder`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
            kwargs (`dict`, *optional*):
                See [`~loaders.LoraLoaderMixin.lora_state_dict`].
        """
        # We could have accessed the unet config from `lora_state_dict()` too. We pass
        # it here explicitly to be able to tell that it's coming from an SDXL
        # pipeline.

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict, network_alphas = self.lora_state_dict(
            pretrained_model_name_or_path_or_dict,
            unet_config=self.unet.config,
            **kwargs,
        )
        is_correct_format = all("lora" in key for key in state_dict.keys())
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet, _pipeline=self)
        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
                _pipeline=self,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alphas=network_alphas,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
                _pipeline=self,
            )
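
        # (Added usage note, a hedged sketch; the LoRA repo id is hypothetical:
        #
        #     pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
        #     pipe.load_lora_weights("some-user/some-sdxl-lora")
        #
        # Keys prefixed "text_encoder." are routed to the first encoder and keys prefixed
        # "text_encoder_2." to the second, mirroring the filtering above.)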

    @classmethod
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = True,
    ):
        r"""
        Save the LoRA parameters corresponding to the UNet and text encoder.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `unet`.
            text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `text_encoder`. Must explicitly pass the text
                encoder LoRA state dict because it comes from 🤗 Transformers.
            text_encoder_2_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to `text_encoder_2`, SDXL's second text encoder.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
        """
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict
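
        # (Added note, a sketch of `pack_weights`: {"down.weight": w} with prefix "unet"
        # becomes {"unet.down.weight": w}; the key name is a hypothetical placeholder.)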

        if not (unet_lora_layers or text_encoder_lora_layers or text_encoder_2_lora_layers):
            raise ValueError(
                "You must pass at least one of `unet_lora_layers`, `text_encoder_lora_layers` or `text_encoder_2_lora_layers`."
            )

        if unet_lora_layers:
            state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )
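
        # (Added usage note, a hedged sketch; `unet_lora_layers` here is whatever state
        # dict your training loop produced:
        #
        #     StableDiffusionXLPipeline.save_lora_weights("./lora-out", unet_lora_layers=unet_lora_layers)
        #
        # Note that the text encoder LoRA layers are only packed when both encoders'
        # layers are passed, mirroring the check above.)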

    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)