Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-06 20:44:33 +08:00

Compare commits: v0.27.0 ... support-si (12 commits)
| Author | SHA1 | Date |
|---|---|---|
| | abf24f116e | |
| | da441cefc6 | |
| | 4d2ca28e24 | |
| | d7d757a387 | |
| | e9e41981d7 | |
| | 807fa22bfc | |
| | e3881c3bd9 | |
| | 45ab4399cc | |
| | 9c734f78e8 | |
| | 46f4c4399c | |
| | 64998bca1b | |
| | 34efcc2034 | |
@@ -143,4 +143,5 @@ class FromOriginalVAEMixin:
        if torch_dtype is not None:
            vae = vae.to(torch_dtype)

        vae.eval()
        return vae
@@ -133,4 +133,5 @@ class FromOriginalControlNetMixin:
        if torch_dtype is not None:
            controlnet = controlnet.to(torch_dtype)

        controlnet.eval()
        return controlnet
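Both hunks above put the freshly converted module into evaluation mode before returning it, so train-only behavior such as dropout is disabled by default. A minimal standalone sketch (not diffusers code) of what eval() changes:

import torch

drop = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

drop.train()
print(drop(x))  # train mode: roughly half the entries are zeroed, the rest rescaled by 2

drop.eval()
print(drop(x))  # eval mode: dropout is a no-op and the input passes through unchanged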
@@ -39,6 +39,7 @@ from ..utils import (
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_single_file_checkpoint,
    is_torch_version,
    logging,
)
@@ -48,6 +49,8 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat

logger = logging.get_logger(__name__)


SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}

if is_torch_version(">=", "1.9.0"):
    _LOW_CPU_MEM_USAGE_DEFAULT = True
else:
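Only the two classes named in SINGLE_FILE_LOADABLE_CLASSES may take the new single-file branch of from_pretrained shown in the next hunk. A hedged usage sketch; the local file name is a placeholder:

import torch
from diffusers import AutoencoderKL

# AutoencoderKL is in the allow-list, so a checkpoint-file path is routed through
# from_single_file instead of the usual repo/config loading path.
vae = AutoencoderKL.from_pretrained("./my_vae.safetensors", torch_dtype=torch.float16)

# A class outside the allow-list would raise ValueError on the same call.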
@@ -497,102 +500,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
        ```
        """
        if is_single_file_checkpoint(pretrained_model_name_or_path):
            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
                raise ValueError(
                    f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
                )
            logger.info("Single file checkpoint detected...")
            model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
            model = model.eval()
            return model
        else:
            cache_dir = kwargs.pop("cache_dir", None)
            ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
            force_download = kwargs.pop("force_download", False)
            from_flax = kwargs.pop("from_flax", False)
            resume_download = kwargs.pop("resume_download", False)
            proxies = kwargs.pop("proxies", None)
            output_loading_info = kwargs.pop("output_loading_info", False)
            local_files_only = kwargs.pop("local_files_only", None)
            token = kwargs.pop("token", None)
            revision = kwargs.pop("revision", None)
            torch_dtype = kwargs.pop("torch_dtype", None)
            subfolder = kwargs.pop("subfolder", None)
            device_map = kwargs.pop("device_map", None)
            max_memory = kwargs.pop("max_memory", None)
            offload_folder = kwargs.pop("offload_folder", None)
            offload_state_dict = kwargs.pop("offload_state_dict", False)
            low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
            variant = kwargs.pop("variant", None)
            use_safetensors = kwargs.pop("use_safetensors", None)

            allow_pickle = False
            if use_safetensors is None:
                use_safetensors = True
                allow_pickle = True

            if low_cpu_mem_usage and not is_accelerate_available():
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

            if device_map is not None and not is_accelerate_available():
                raise NotImplementedError(
                    "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                    " `device_map=None`. You can install accelerate with `pip install accelerate`."
                )

            # Check if we can handle device_map and dispatching the weights
            if device_map is not None and not is_torch_version(">=", "1.9.0"):
                raise NotImplementedError(
                    "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                    " `device_map=None`."
                )

            if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
                raise NotImplementedError(
                    "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                    " `low_cpu_mem_usage=False`."
                )

            if low_cpu_mem_usage is False and device_map is not None:
                raise ValueError(
                    f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                    " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
                )

            # Load config if we don't provide a configuration
            config_path = pretrained_model_name_or_path

            user_agent = {
                "diffusers": __version__,
                "file_type": "model",
                "framework": "pytorch",
            }

            # load config
            config, unused_kwargs, commit_hash = cls.load_config(
                config_path,
                cache_dir=cache_dir,
                return_unused_kwargs=True,
                return_commit_hash=True,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
@@ -600,40 +591,20 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                local_files_only=local_files_only,
                token=token,
                revision=revision,
                subfolder=subfolder,
                device_map=device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                offload_state_dict=offload_state_dict,
                user_agent=user_agent,
                **kwargs,
            )

            # load model
            model_file = None
            if from_flax:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=FLAX_WEIGHTS_NAME,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    resume_download=resume_download,
@@ -645,76 +616,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    subfolder=subfolder,
                    user_agent=user_agent,
                    commit_hash=commit_hash,
                )
                model = cls.from_config(config, **unused_kwargs)

                # Convert the weights
                from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model

                model = load_flax_checkpoint_in_pytorch_model(model, model_file)
            else:
                if use_safetensors:
                    try:
                        model_file = _get_model_file(
                            pretrained_model_name_or_path,
                            weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                            cache_dir=cache_dir,
                            force_download=force_download,
                            resume_download=resume_download,
                            proxies=proxies,
                            local_files_only=local_files_only,
                            token=token,
                            revision=revision,
                            subfolder=subfolder,
                            user_agent=user_agent,
                            commit_hash=commit_hash,
                        )
                    except IOError as e:
                        if not allow_pickle:
                            raise e
                        pass
                if model_file is None:
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=_add_variant(WEIGHTS_NAME, variant),
                        cache_dir=cache_dir,
                        force_download=force_download,
                        resume_download=resume_download,
                        proxies=proxies,
                        local_files_only=local_files_only,
                        token=token,
                        revision=revision,
                        subfolder=subfolder,
                        user_agent=user_agent,
                        commit_hash=commit_hash,
                    )

                if low_cpu_mem_usage:
                    # Instantiate model with empty weights
                    with accelerate.init_empty_weights():
                        model = cls.from_config(config, **unused_kwargs)

                    # if device_map is None, load the state dict and move the params from meta device to the cpu
                    if device_map is None:
                        param_device = "cpu"
                        state_dict = load_state_dict(model_file, variant=variant)
                        model._convert_deprecated_attention_blocks(state_dict)
                        # move the params from meta device to cpu
                        missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
                        if len(missing_keys) > 0:
                            raise ValueError(
                                f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
                                f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
                                " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
                                " those weights or else make sure your checkpoint file is correct."
                            )

                        unexpected_keys = load_model_dict_into_meta(
                            model,
                            state_dict,
                            device=param_device,
                            dtype=torch_dtype,
                            model_name_or_path=pretrained_model_name_or_path,
                        )

                        if cls._keys_to_ignore_on_load_unexpected is not None:
                            for pat in cls._keys_to_ignore_on_load_unexpected:
                                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

                        if len(unexpected_keys) > 0:
                            logger.warn(
                                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
                            )

                    else:  # else let accelerate handle loading and dispatching.
                        # Load weights and dispatch according to the device_map
                        # by default the device_map is None and the weights are loaded on the CPU
                        try:
                            accelerate.load_checkpoint_and_dispatch(
                                model,
                                model_file,
                                device_map,
                                max_memory=max_memory,
                                offload_folder=offload_folder,
@@ -724,52 +709,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                                offload_state_dict=offload_state_dict,
                                dtype=torch_dtype,
                            )
                        except AttributeError as e:
                            # When using accelerate loading, we do not have the ability to load the state
                            # dict and rename the weight names manually. Additionally, accelerate skips
                            # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
                            # (which look like they should be private variables?), so we can't use the standard hooks
                            # to rename parameters on load. We need to mimic the original weight names so the correct
                            # attributes are available. After we have loaded the weights, we convert the deprecated
                            # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
                            # the weights so we don't have to do this again.

                            if "'Attention' object has no attribute" in str(e):
                                logger.warn(
                                    f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
                                    " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
                                    " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
                                    " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
                                    " please also re-upload it or open a PR on the original repository."
                                )
                                model._temp_convert_self_to_deprecated_attention_blocks()
                                accelerate.load_checkpoint_and_dispatch(
                                    model,
                                    model_file,
                                    device_map,
                                    max_memory=max_memory,
                                    offload_folder=offload_folder,
                                    offload_state_dict=offload_state_dict,
                                    dtype=torch_dtype,
                                )
                                model._undo_temp_convert_self_to_deprecated_attention_blocks()
                            else:
                                raise e

                    loading_info = {
                        "missing_keys": [],
                        "unexpected_keys": [],
                        "mismatched_keys": [],
                        "error_msgs": [],
                    }
                else:
                    model = cls.from_config(config, **unused_kwargs)

                    state_dict = load_state_dict(model_file, variant=variant)
                    model._convert_deprecated_attention_blocks(state_dict)

                    model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
                        model,
                        state_dict,
                        model_file,
                        pretrained_model_name_or_path,
                        ignore_mismatched_sizes=ignore_mismatched_sizes,
                    )

                    loading_info = {
                        "missing_keys": missing_keys,
                        "unexpected_keys": unexpected_keys,
                        "mismatched_keys": mismatched_keys,
                        "error_msgs": error_msgs,
                    }

            if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
                raise ValueError(
                    f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
                )
            elif torch_dtype is not None:
                model = model.to(torch_dtype)

            model.register_to_config(_name_or_path=pretrained_model_name_or_path)

            # Set model in evaluation mode to deactivate DropOut modules by default
            model.eval()
            if output_loading_info:
                return model, loading_info

            return model

    @classmethod
    def _load_pretrained_model(
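Anything that is not a single-file checkpoint still follows the full branch above, including output_loading_info. A short sketch against a public repo id (the repo id is a real Hub checkpoint; behavior as in the code above):

from diffusers import AutoencoderKL

# A repo id has no accepted single-file suffix, so the regular config + weights
# path runs and the loading diagnostics remain available.
vae, info = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", output_loading_info=True)
print(info["missing_keys"], info["unexpected_keys"])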
@@ -57,6 +57,7 @@ from ..utils import (
    is_accelerate_available,
    is_accelerate_version,
    is_peft_available,
    is_single_file_checkpoint,
    is_torch_version,
    is_transformers_available,
    logging,
@@ -110,6 +111,20 @@ LOADABLE_CLASSES = {
    },
}

SINGLE_FILE_LOADABLE_CLASSES = {
    "StableDiffusionPipeline",
    "StableDiffusionImg2ImgPipeline",
    "StableDiffusionInpaintPipeline",
    "StableDiffusionUpscalePipeline",
    "StableDiffusionControlNetPipeline",
    "StableDiffusionControlNetImg2ImgPipeline",
    "StableDiffusionControlNetInpaintPipeline",
    "StableDiffusionXLPipeline",
    "StableDiffusionXLImg2ImgPipeline",
    "StableDiffusionXLInpaintPipeline",
    "StableDiffusionXLControlNetImg2ImgPipeline",
}

ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
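The pipeline-level allow-list mirrors the model-level one. A hedged sketch of what the new branch enables; the local file is a placeholder for any exported SDXL checkpoint:

from diffusers import StableDiffusionXLPipeline

# StableDiffusionXLPipeline is in SINGLE_FILE_LOADABLE_CLASSES, so from_pretrained
# can be pointed at a monolithic checkpoint and delegates to from_single_file.
pipe = StableDiffusionXLPipeline.from_pretrained("./sd_xl_base_1.0.safetensors")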
@@ -1056,308 +1071,334 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
        >>> pipeline.scheduler = scheduler
        ```
        """
        if is_single_file_checkpoint(pretrained_model_name_or_path):
            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
                raise ValueError(
                    f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
                )
            logger.info("Single file checkpoint detected...")
            model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
            return model
        else:
            cache_dir = kwargs.pop("cache_dir", None)
            resume_download = kwargs.pop("resume_download", False)
            force_download = kwargs.pop("force_download", False)
            proxies = kwargs.pop("proxies", None)
            local_files_only = kwargs.pop("local_files_only", None)
            token = kwargs.pop("token", None)
            revision = kwargs.pop("revision", None)
            from_flax = kwargs.pop("from_flax", False)
            torch_dtype = kwargs.pop("torch_dtype", None)
            custom_pipeline = kwargs.pop("custom_pipeline", None)
            custom_revision = kwargs.pop("custom_revision", None)
            provider = kwargs.pop("provider", None)
            sess_options = kwargs.pop("sess_options", None)
            device_map = kwargs.pop("device_map", None)
            max_memory = kwargs.pop("max_memory", None)
            offload_folder = kwargs.pop("offload_folder", None)
            offload_state_dict = kwargs.pop("offload_state_dict", False)
            low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
            variant = kwargs.pop("variant", None)
            use_safetensors = kwargs.pop("use_safetensors", None)
            use_onnx = kwargs.pop("use_onnx", None)
            load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)

            if low_cpu_mem_usage and not is_accelerate_available():
                low_cpu_mem_usage = False
                logger.warning(
                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                    " install accelerate\n```\n."
                )

            if device_map is not None and not is_torch_version(">=", "1.9.0"):
                raise NotImplementedError(
                    "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                    " `device_map=None`."
                )

            if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
                raise NotImplementedError(
                    "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                    " `low_cpu_mem_usage=False`."
                )

            if low_cpu_mem_usage is False and device_map is not None:
                raise ValueError(
                    f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
                    " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
                )

            # 1. Download the checkpoints and configs
            # use snapshot download here to get it working from from_pretrained
            if not os.path.isdir(pretrained_model_name_or_path):
                if pretrained_model_name_or_path.count("/") > 1:
                    raise ValueError(
                        f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
                        " is neither a valid local path nor a valid repo id. Please check the parameter."
                    )
                cached_folder = cls.download(
                    pretrained_model_name_or_path,
                    cache_dir=cache_dir,
                    resume_download=resume_download,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    revision=revision,
                    from_flax=from_flax,
                    use_safetensors=use_safetensors,
                    use_onnx=use_onnx,
                    custom_pipeline=custom_pipeline,
                    custom_revision=custom_revision,
                    variant=variant,
                    load_connected_pipeline=load_connected_pipeline,
                    **kwargs,
                )
            else:
                cached_folder = pretrained_model_name_or_path

            config_dict = cls.load_config(cached_folder)

            # pop out "_ignore_files" as it is only needed for download
            config_dict.pop("_ignore_files", None)

            # 2. Define which model components should load variants
            # We retrieve the information by matching whether variant
            # model checkpoints exist in the subfolders
            model_variants = {}
            if variant is not None:
                for folder in os.listdir(cached_folder):
                    folder_path = os.path.join(cached_folder, folder)
                    is_folder = os.path.isdir(folder_path) and folder in config_dict
                    variant_exists = is_folder and any(
                        p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
                    )
                    if variant_exists:
                        model_variants[folder] = variant

            # 3. Load the pipeline class, if using custom module then load it from the hub
            # if we load from explicit class, let's use it
            custom_class_name = None
            if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
                custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
            elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
                os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
            ):
                custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
                custom_class_name = config_dict["_class_name"][1]

            pipeline_class = _get_pipeline_class(
                cls,
                config_dict,
                load_connected_pipeline=load_connected_pipeline,
                custom_pipeline=custom_pipeline,
                class_name=custom_class_name,
                cache_dir=cache_dir,
                revision=custom_revision,
            )

            # DEPRECATED: To be removed in 1.0.0
            if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
                version.parse(config_dict["_diffusers_version"]).base_version
            ) <= version.parse("0.5.1"):
                from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy

                pipeline_class = StableDiffusionInpaintPipelineLegacy

                deprecation_message = (
                    "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
                    f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
                    " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
                    " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
                    f" checkpoint {pretrained_model_name_or_path} to the format of"
                    " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
                    " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
                )
                deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)

            # 4. Define expected modules given pipeline signature
            # and define non-None initialized modules (=`init_kwargs`)

            # some modules can be passed directly to the init
            # in this case they are already instantiated in `kwargs`
            # extract them here
            expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
            passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
            passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}

            init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)

            # define init kwargs and make sure that optional component modules are filtered out
            init_kwargs = {
                k: init_dict.pop(k)
                for k in optional_kwargs
                if k in init_dict and k not in pipeline_class._optional_components
            }
            init_kwargs = {**init_kwargs, **passed_pipe_kwargs}

            # remove `null` components
            def load_module(name, value):
                if value[0] is None:
                    return False
                if name in passed_class_obj and passed_class_obj[name] is None:
                    return False
                return True

            init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}

            # Special case: safety_checker must be loaded separately when using `from_flax`
            if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
                raise NotImplementedError(
                    "The safety checker cannot be automatically loaded when loading weights `from_flax`."
                    " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
                    " separately if you need it."
                )

            # 5. Throw nice warnings / errors for fast accelerate loading
            if len(unused_kwargs) > 0:
                logger.warning(
                    f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
                )

            # import it here to avoid circular import
            from diffusers import pipelines

            # 6. Load each module in the pipeline
            for name, (library_name, class_name) in logging.tqdm(
                init_dict.items(), desc="Loading pipeline components..."
            ):
                # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
                class_name = class_name[4:] if class_name.startswith("Flax") else class_name

                # 6.2 Define all importable classes
                is_pipeline_module = hasattr(pipelines, library_name)
                importable_classes = ALL_IMPORTABLE_CLASSES
                loaded_sub_model = None

                # 6.3 Use passed sub model or load class_name from library_name
                if name in passed_class_obj:
                    # if the model is in a pipeline module, then we load it from the pipeline
                    # check that passed_class_obj has correct parent class
                    maybe_raise_or_warn(
                        library_name,
                        library,
                        class_name,
                        importable_classes,
                        passed_class_obj,
                        name,
                        is_pipeline_module,
                    )

                    loaded_sub_model = passed_class_obj[name]
                else:
                    # load sub model
                    loaded_sub_model = load_sub_model(
                        library_name=library_name,
                        class_name=class_name,
                        importable_classes=importable_classes,
                        pipelines=pipelines,
                        is_pipeline_module=is_pipeline_module,
                        pipeline_class=pipeline_class,
                        torch_dtype=torch_dtype,
                        provider=provider,
                        sess_options=sess_options,
                        device_map=device_map,
                        max_memory=max_memory,
                        offload_folder=offload_folder,
                        offload_state_dict=offload_state_dict,
                        model_variants=model_variants,
                        name=name,
                        from_flax=from_flax,
                        variant=variant,
                        low_cpu_mem_usage=low_cpu_mem_usage,
                        cached_folder=cached_folder,
                    )
                    logger.info(
                        f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
                    )

                init_kwargs[name] = loaded_sub_model  # UNet(...), # DiffusionSchedule(...)

            if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
                modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
                connected_pipes = {
                    prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS
                }
                load_kwargs = {
                    "cache_dir": cache_dir,
                    "resume_download": resume_download,
                    "force_download": force_download,
                    "proxies": proxies,
                    "local_files_only": local_files_only,
                    "token": token,
                    "revision": revision,
                    "torch_dtype": torch_dtype,
                    "custom_pipeline": custom_pipeline,
                    "custom_revision": custom_revision,
                    "provider": provider,
                    "sess_options": sess_options,
                    "device_map": device_map,
                    "max_memory": max_memory,
                    "offload_folder": offload_folder,
                    "offload_state_dict": offload_state_dict,
                    "low_cpu_mem_usage": low_cpu_mem_usage,
                    "variant": variant,
                    "use_safetensors": use_safetensors,
                }

                def get_connected_passed_kwargs(prefix):
                    connected_passed_class_obj = {
                        k.replace(f"{prefix}_", ""): w
                        for k, w in passed_class_obj.items()
                        if k.split("_")[0] == prefix
                    }
                    connected_passed_pipe_kwargs = {
                        k.replace(f"{prefix}_", ""): w
                        for k, w in passed_pipe_kwargs.items()
                        if k.split("_")[0] == prefix
                    }

                    connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
                    return connected_passed_kwargs

                connected_pipes = {
                    prefix: DiffusionPipeline.from_pretrained(
                        repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
                    )
                    for prefix, repo_id in connected_pipes.items()
                    if repo_id is not None
                }

                for prefix, connected_pipe in connected_pipes.items():
                    # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
                    init_kwargs.update(
                        {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
                    )

            # 7. Potentially add passed objects if expected
            missing_modules = set(expected_modules) - set(init_kwargs.keys())
            passed_modules = list(passed_class_obj.keys())
            optional_modules = pipeline_class._optional_components
            if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
                for module in missing_modules:
                    init_kwargs[module] = passed_class_obj.get(module, None)
            elif len(missing_modules) > 0:
                passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
                raise ValueError(
                    f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
                )

            # 8. Instantiate the pipeline
            model = pipeline_class(**init_kwargs)

            # 9. Save where the model was instantiated from
            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
            return model

    @property
    def name_or_path(self) -> str:
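Steps 4-7 above also govern how preloaded components short-circuit loading: a component passed as a keyword is captured in passed_class_obj (step 4) and reused in step 6.3 instead of being read from its subfolder. A short sketch with real Hub repo ids:

from diffusers import AutoencoderKL, StableDiffusionPipeline

# The vae keyword lands in passed_class_obj, so no weights are fetched for the
# pipeline's own vae subfolder; the passed instance is used directly.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae)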
@@ -19,6 +19,7 @@ from packaging import version

from .. import __version__
from .constants import (
    _ACCEPTED_SINGLE_FILE_FORMATS,
    CONFIG_NAME,
    DEPRECATED_REVISION_ARGS,
    DIFFUSERS_DYNAMIC_MODULE_NAME,
@@ -83,7 +84,7 @@ from .import_utils import (
    is_xformers_available,
    requires_backends,
)
from .loading_utils import is_single_file_checkpoint, load_image
from .logging import get_logger
from .outputs import BaseOutput
from .peft_utils import (
@@ -37,6 +37,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")

# Below should be `True` if the current version of `peft` and `transformers` are compatible with
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
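_ACCEPTED_SINGLE_FILE_FORMATS is deliberately a tuple: str.endswith accepts a tuple of suffixes, which lets the detection helper in the next hunk match any accepted format in one call:

_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")

print("model.ckpt".endswith(_ACCEPTED_SINGLE_FILE_FORMATS))        # True
print("model-index.json".endswith(_ACCEPTED_SINGLE_FILE_FORMATS))  # False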
@@ -1,10 +1,28 @@
import os
from typing import Callable, Union
from urllib.parse import urlparse

import PIL.Image
import PIL.ImageOps
import requests

from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS


def is_single_file_checkpoint(filepath):
    def is_valid_url(url):
        result = urlparse(url)
        if result.scheme and result.netloc:
            return True

    filepath = str(filepath)
    if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
        if is_valid_url(filepath):
            return True
        elif os.path.isfile(filepath):
            return True
    return False


def load_image(
    image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
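The helper returns True only when the suffix matches an accepted format and the path is either a well-formed URL or an existing local file; plain repo ids fall through to False. A quick sanity check (URLs and paths are illustrative):

from diffusers.utils import is_single_file_checkpoint

# URL with a matching suffix -> True (urlparse yields both scheme and netloc)
print(is_single_file_checkpoint("https://huggingface.co/foo/bar/resolve/main/model.safetensors"))

# Matching suffix, but neither a URL nor an existing local file -> False
print(is_single_file_checkpoint("./does_not_exist.ckpt"))

# Repo id without an accepted suffix -> False
print(is_single_file_checkpoint("runwayml/stable-diffusion-v1-5"))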