Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-18 10:24:47 +08:00)

Compare commits: custom-cod...support-si (12 commits)
| Author | SHA1 | Date |
|---|---|---|
| | abf24f116e | |
| | da441cefc6 | |
| | 4d2ca28e24 | |
| | d7d757a387 | |
| | e9e41981d7 | |
| | 807fa22bfc | |
| | e3881c3bd9 | |
| | 45ab4399cc | |
| | 9c734f78e8 | |
| | 46f4c4399c | |
| | 64998bca1b | |
| | 34efcc2034 | |
@@ -143,4 +143,5 @@ class FromOriginalVAEMixin:
         if torch_dtype is not None:
             vae = vae.to(torch_dtype)
 
+        vae.eval()
         return vae
@@ -133,4 +133,5 @@ class FromOriginalControlNetMixin:
         if torch_dtype is not None:
             controlnet = controlnet.to(torch_dtype)
 
+        controlnet.eval()
         return controlnet
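Context for the two hunks above: `Module.eval()` is what makes freshly loaded weights behave deterministically at inference time. A minimal, self-contained sketch in plain PyTorch (not diffusers code) of what the call changes:

```python
import torch

# `.eval()` flips `module.training` to False, which disables dropout and
# makes batch-norm use its running statistics, so a loaded model produces
# deterministic outputs.
layer = torch.nn.Dropout(p=0.5)
x = torch.ones(4)

layer.train()
print(layer(x))  # elements randomly zeroed and rescaled by 1 / (1 - p)

layer.eval()
print(layer(x))  # identity: tensor([1., 1., 1., 1.])
```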
@@ -39,6 +39,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_single_file_checkpoint,
     is_torch_version,
     logging,
 )
@@ -48,6 +49,8 @@ from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populat
 logger = logging.get_logger(__name__)
 
 
+SINGLE_FILE_LOADABLE_CLASSES = {"ControlNetModel", "AutoencoderKL"}
+
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
 else:
@@ -497,102 +500,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
         ```
         """
-        cache_dir = kwargs.pop("cache_dir", None)
-        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
-        force_download = kwargs.pop("force_download", False)
-        from_flax = kwargs.pop("from_flax", False)
-        resume_download = kwargs.pop("resume_download", False)
-        proxies = kwargs.pop("proxies", None)
-        output_loading_info = kwargs.pop("output_loading_info", False)
-        local_files_only = kwargs.pop("local_files_only", None)
-        token = kwargs.pop("token", None)
-        revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
-        subfolder = kwargs.pop("subfolder", None)
-        device_map = kwargs.pop("device_map", None)
-        max_memory = kwargs.pop("max_memory", None)
-        offload_folder = kwargs.pop("offload_folder", None)
-        offload_state_dict = kwargs.pop("offload_state_dict", False)
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-        variant = kwargs.pop("variant", None)
-        use_safetensors = kwargs.pop("use_safetensors", None)
+        if is_single_file_checkpoint(pretrained_model_name_or_path):
+            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
+                raise ValueError(
+                    f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
+                )
+            logger.info("Single file checkpoint detected...")
+            model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
+            model = model.eval()
+            return model
+        else:
+            cache_dir = kwargs.pop("cache_dir", None)
+            ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
+            force_download = kwargs.pop("force_download", False)
+            from_flax = kwargs.pop("from_flax", False)
+            resume_download = kwargs.pop("resume_download", False)
+            proxies = kwargs.pop("proxies", None)
+            output_loading_info = kwargs.pop("output_loading_info", False)
+            local_files_only = kwargs.pop("local_files_only", None)
+            token = kwargs.pop("token", None)
+            revision = kwargs.pop("revision", None)
+            torch_dtype = kwargs.pop("torch_dtype", None)
+            subfolder = kwargs.pop("subfolder", None)
+            device_map = kwargs.pop("device_map", None)
+            max_memory = kwargs.pop("max_memory", None)
+            offload_folder = kwargs.pop("offload_folder", None)
+            offload_state_dict = kwargs.pop("offload_state_dict", False)
+            low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+            variant = kwargs.pop("variant", None)
+            use_safetensors = kwargs.pop("use_safetensors", None)
 
             allow_pickle = False
             if use_safetensors is None:
                 use_safetensors = True
                 allow_pickle = True
 
             if low_cpu_mem_usage and not is_accelerate_available():
                 low_cpu_mem_usage = False
                 logger.warning(
                     "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
                     " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
                     " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
                     " install accelerate\n```\n."
                 )
 
             if device_map is not None and not is_accelerate_available():
                 raise NotImplementedError(
                     "Loading and dispatching requires `accelerate`. Please make sure to install accelerate or set"
                     " `device_map=None`. You can install accelerate with `pip install accelerate`."
                 )
 
             # Check if we can handle device_map and dispatching the weights
             if device_map is not None and not is_torch_version(">=", "1.9.0"):
                 raise NotImplementedError(
                     "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
                     " `device_map=None`."
                 )
 
             if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
                 raise NotImplementedError(
                     "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
                     " `low_cpu_mem_usage=False`."
                 )
 
             if low_cpu_mem_usage is False and device_map is not None:
                 raise ValueError(
                     f"You cannot set `low_cpu_mem_usage` to `False` while using device_map={device_map} for loading and"
                     " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
                 )
 
             # Load config if we don't provide a configuration
             config_path = pretrained_model_name_or_path
 
             user_agent = {
                 "diffusers": __version__,
                 "file_type": "model",
                 "framework": "pytorch",
             }
 
             # load config
             config, unused_kwargs, commit_hash = cls.load_config(
                 config_path,
-            cache_dir=cache_dir,
-            return_unused_kwargs=True,
-            return_commit_hash=True,
-            force_download=force_download,
-            resume_download=resume_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            token=token,
-            revision=revision,
-            subfolder=subfolder,
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
-            user_agent=user_agent,
-            **kwargs,
-        )
-
-        # load model
-        model_file = None
-        if from_flax:
-            model_file = _get_model_file(
-                pretrained_model_name_or_path,
-                weights_name=FLAX_WEIGHTS_NAME,
                 cache_dir=cache_dir,
+                return_unused_kwargs=True,
+                return_commit_hash=True,
                 force_download=force_download,
                 resume_download=resume_download,
                 proxies=proxies,
@@ -600,40 +591,20 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 token=token,
                 revision=revision,
                 subfolder=subfolder,
+                device_map=device_map,
+                max_memory=max_memory,
+                offload_folder=offload_folder,
+                offload_state_dict=offload_state_dict,
                 user_agent=user_agent,
-                commit_hash=commit_hash,
+                **kwargs,
             )
-            model = cls.from_config(config, **unused_kwargs)
 
-            # Convert the weights
-            from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
-
-            model = load_flax_checkpoint_in_pytorch_model(model, model_file)
-        else:
-            if use_safetensors:
-                try:
-                    model_file = _get_model_file(
-                        pretrained_model_name_or_path,
-                        weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
-                        cache_dir=cache_dir,
-                        force_download=force_download,
-                        resume_download=resume_download,
-                        proxies=proxies,
-                        local_files_only=local_files_only,
-                        token=token,
-                        revision=revision,
-                        subfolder=subfolder,
-                        user_agent=user_agent,
-                        commit_hash=commit_hash,
-                    )
-                except IOError as e:
-                    if not allow_pickle:
-                        raise e
-                    pass
-            if model_file is None:
+
+            # load model
+            model_file = None
+            if from_flax:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
-                    weights_name=_add_variant(WEIGHTS_NAME, variant),
+                    weights_name=FLAX_WEIGHTS_NAME,
                     cache_dir=cache_dir,
                     force_download=force_download,
                     resume_download=resume_download,
@@ -645,76 +616,90 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     user_agent=user_agent,
                     commit_hash=commit_hash,
                 )
-
-        if low_cpu_mem_usage:
-            # Instantiate model with empty weights
-            with accelerate.init_empty_weights():
-                model = cls.from_config(config, **unused_kwargs)
-
-            # if device_map is None, load the state dict and move the params from meta device to the cpu
-            if device_map is None:
-                param_device = "cpu"
-                state_dict = load_state_dict(model_file, variant=variant)
-                model._convert_deprecated_attention_blocks(state_dict)
-                # move the params from meta device to cpu
-                missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
-                if len(missing_keys) > 0:
-                    raise ValueError(
-                        f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
-                        f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
-                        " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
-                        " those weights or else make sure your checkpoint file is correct."
-                    )
-
-                unexpected_keys = load_model_dict_into_meta(
-                    model,
-                    state_dict,
-                    device=param_device,
-                    dtype=torch_dtype,
-                    model_name_or_path=pretrained_model_name_or_path,
-                )
-
-                if cls._keys_to_ignore_on_load_unexpected is not None:
-                    for pat in cls._keys_to_ignore_on_load_unexpected:
-                        unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-                if len(unexpected_keys) > 0:
-                    logger.warn(
-                        f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-                    )
-
-            else: # else let accelerate handle loading and dispatching.
-                # Load weights and dispatch according to the device_map
-                # by default the device_map is None and the weights are loaded on the CPU
-                try:
-                    accelerate.load_checkpoint_and_dispatch(
-                        model,
-                        model_file,
-                        device_map,
-                        max_memory=max_memory,
-                        offload_folder=offload_folder,
-                        offload_state_dict=offload_state_dict,
-                        dtype=torch_dtype,
-                    )
-                except AttributeError as e:
-                    # When using accelerate loading, we do not have the ability to load the state
-                    # dict and rename the weight names manually. Additionally, accelerate skips
-                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
-                    # (which look like they should be private variables?), so we can't use the standard hooks
-                    # to rename parameters on load. We need to mimic the original weight names so the correct
-                    # attributes are available. After we have loaded the weights, we convert the deprecated
-                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
-                    # the weights so we don't have to do this again.
-
-                    if "'Attention' object has no attribute" in str(e):
-                        logger.warn(
-                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
-                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
-                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
-                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
-                            " please also re-upload it or open a PR on the original repository."
-                        )
-                        model._temp_convert_self_to_deprecated_attention_blocks()
+                model = cls.from_config(config, **unused_kwargs)
+
+                # Convert the weights
+                from .modeling_pytorch_flax_utils import load_flax_checkpoint_in_pytorch_model
+
+                model = load_flax_checkpoint_in_pytorch_model(model, model_file)
+            else:
+                if use_safetensors:
+                    try:
+                        model_file = _get_model_file(
+                            pretrained_model_name_or_path,
+                            weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                            cache_dir=cache_dir,
+                            force_download=force_download,
+                            resume_download=resume_download,
+                            proxies=proxies,
+                            local_files_only=local_files_only,
+                            token=token,
+                            revision=revision,
+                            subfolder=subfolder,
+                            user_agent=user_agent,
+                            commit_hash=commit_hash,
+                        )
+                    except IOError as e:
+                        if not allow_pickle:
+                            raise e
+                        pass
+                if model_file is None:
+                    model_file = _get_model_file(
+                        pretrained_model_name_or_path,
+                        weights_name=_add_variant(WEIGHTS_NAME, variant),
+                        cache_dir=cache_dir,
+                        force_download=force_download,
+                        resume_download=resume_download,
+                        proxies=proxies,
+                        local_files_only=local_files_only,
+                        token=token,
+                        revision=revision,
+                        subfolder=subfolder,
+                        user_agent=user_agent,
+                        commit_hash=commit_hash,
+                    )
+
+            if low_cpu_mem_usage:
+                # Instantiate model with empty weights
+                with accelerate.init_empty_weights():
+                    model = cls.from_config(config, **unused_kwargs)
+
+                # if device_map is None, load the state dict and move the params from meta device to the cpu
+                if device_map is None:
+                    param_device = "cpu"
+                    state_dict = load_state_dict(model_file, variant=variant)
+                    model._convert_deprecated_attention_blocks(state_dict)
+                    # move the params from meta device to cpu
+                    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                    if len(missing_keys) > 0:
+                        raise ValueError(
+                            f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
+                            f" missing: \n {', '.join(missing_keys)}. \n Please make sure to pass"
+                            " `low_cpu_mem_usage=False` and `device_map=None` if you want to randomly initialize"
+                            " those weights or else make sure your checkpoint file is correct."
+                        )
+
+                    unexpected_keys = load_model_dict_into_meta(
+                        model,
+                        state_dict,
+                        device=param_device,
+                        dtype=torch_dtype,
+                        model_name_or_path=pretrained_model_name_or_path,
+                    )
+
+                    if cls._keys_to_ignore_on_load_unexpected is not None:
+                        for pat in cls._keys_to_ignore_on_load_unexpected:
+                            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+                    if len(unexpected_keys) > 0:
+                        logger.warn(
+                            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+                        )
+
+                else: # else let accelerate handle loading and dispatching.
+                    # Load weights and dispatch according to the device_map
+                    # by default the device_map is None and the weights are loaded on the CPU
+                    try:
                         accelerate.load_checkpoint_and_dispatch(
                             model,
                             model_file,
@@ -724,52 +709,80 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                             offload_state_dict=offload_state_dict,
                             dtype=torch_dtype,
                         )
-                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
-                    else:
-                        raise e
-
-            loading_info = {
-                "missing_keys": [],
-                "unexpected_keys": [],
-                "mismatched_keys": [],
-                "error_msgs": [],
-            }
-        else:
-            model = cls.from_config(config, **unused_kwargs)
-
-            state_dict = load_state_dict(model_file, variant=variant)
-            model._convert_deprecated_attention_blocks(state_dict)
-
-            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
-                model,
-                state_dict,
-                model_file,
-                pretrained_model_name_or_path,
-                ignore_mismatched_sizes=ignore_mismatched_sizes,
-            )
-
-            loading_info = {
-                "missing_keys": missing_keys,
-                "unexpected_keys": unexpected_keys,
-                "mismatched_keys": mismatched_keys,
-                "error_msgs": error_msgs,
-            }
-
-        if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
-            raise ValueError(
-                f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
-            )
-        elif torch_dtype is not None:
-            model = model.to(torch_dtype)
-
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
-
-        # Set model in evaluation mode to deactivate DropOut modules by default
-        model.eval()
-        if output_loading_info:
-            return model, loading_info
-
-        return model
+                    except AttributeError as e:
+                        # When using accelerate loading, we do not have the ability to load the state
+                        # dict and rename the weight names manually. Additionally, accelerate skips
+                        # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                        # (which look like they should be private variables?), so we can't use the standard hooks
+                        # to rename parameters on load. We need to mimic the original weight names so the correct
+                        # attributes are available. After we have loaded the weights, we convert the deprecated
+                        # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+                        # the weights so we don't have to do this again.
+
+                        if "'Attention' object has no attribute" in str(e):
+                            logger.warn(
+                                f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                                " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                                " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                                " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+                                " please also re-upload it or open a PR on the original repository."
+                            )
+                            model._temp_convert_self_to_deprecated_attention_blocks()
+                            accelerate.load_checkpoint_and_dispatch(
+                                model,
+                                model_file,
+                                device_map,
+                                max_memory=max_memory,
+                                offload_folder=offload_folder,
+                                offload_state_dict=offload_state_dict,
+                                dtype=torch_dtype,
+                            )
+                            model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                        else:
+                            raise e
+
+                loading_info = {
+                    "missing_keys": [],
+                    "unexpected_keys": [],
+                    "mismatched_keys": [],
+                    "error_msgs": [],
+                }
+            else:
+                model = cls.from_config(config, **unused_kwargs)
+
+                state_dict = load_state_dict(model_file, variant=variant)
+                model._convert_deprecated_attention_blocks(state_dict)
+
+                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                    model,
+                    state_dict,
+                    model_file,
+                    pretrained_model_name_or_path,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                )
+
+                loading_info = {
+                    "missing_keys": missing_keys,
+                    "unexpected_keys": unexpected_keys,
+                    "mismatched_keys": mismatched_keys,
+                    "error_msgs": error_msgs,
+                }
+
+            if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
+                raise ValueError(
+                    f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
+                )
+            elif torch_dtype is not None:
+                model = model.to(torch_dtype)
+
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+
+            # Set model in evaluation mode to deactivate DropOut modules by default
+            model.eval()
+            if output_loading_info:
+                return model, loading_info
+
+            return model
 
     @classmethod
     def _load_pretrained_model(
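Taken together, the `ModelMixin.from_pretrained` hunks above route single-file checkpoints to `from_single_file` before the regular folder-based path runs. A sketch of the calling pattern this enables — the local path below is hypothetical, and only classes named in `SINGLE_FILE_LOADABLE_CLASSES` (`ControlNetModel`, `AutoencoderKL`) take the new branch:

```python
from diffusers import AutoencoderKL

# Hypothetical local checkpoint: it ends in an accepted extension and
# exists on disk, so `is_single_file_checkpoint` returns True and the call
# should be forwarded to `AutoencoderKL.from_single_file`.
vae = AutoencoderKL.from_pretrained("./checkpoints/vae.safetensors")

# A repo id has no accepted extension, so it still takes the regular
# `load_config` + weight-file loading path.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
```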
@@ -57,6 +57,7 @@ from ..utils import (
     is_accelerate_available,
     is_accelerate_version,
     is_peft_available,
+    is_single_file_checkpoint,
     is_torch_version,
     is_transformers_available,
     logging,
@@ -110,6 +111,20 @@ LOADABLE_CLASSES = {
     },
 }
 
+SINGLE_FILE_LOADABLE_CLASSES = {
+    "StableDiffusionPipeline",
+    "StableDiffusionImg2ImgPipeline",
+    "StableDiffusionInpaintPipeline",
+    "StableDiffusionUpscalePipeline",
+    "StableDiffusionControlNetPipeline",
+    "StableDiffusionControlNetImg2ImgPipeline",
+    "StableDiffusionControlNetInpaintPipeline",
+    "StableDiffusionXLPipeline",
+    "StableDiffusionXLImg2ImgPipeline",
+    "StableDiffusionXLInpaintPipeline",
+    "StableDiffusionXLControlNetImg2ImgPipeline",
+}
+
 ALL_IMPORTABLE_CLASSES = {}
 for library in LOADABLE_CLASSES:
     ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
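The set above is consulted via `cls.__name__`, so dispatch is keyed on the concrete pipeline class rather than on inheritance. A minimal sketch of that gating pattern, under the assumption of abbreviated stand-in classes (not the real diffusers ones):

```python
SINGLE_FILE_LOADABLE_CLASSES = {"StableDiffusionPipeline", "StableDiffusionXLPipeline"}


class Loader:
    @classmethod
    def from_pretrained(cls, path):
        # Gate on the concrete class name, mirroring the check in the diff.
        if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
            raise ValueError(f"{cls.__name__} is not supported.")
        return f"would load {path} via from_single_file"


class StableDiffusionPipeline(Loader):
    pass


class UnsupportedPipeline(Loader):
    pass


print(StableDiffusionPipeline.from_pretrained("model.ckpt"))  # dispatches
UnsupportedPipeline.from_pretrained("model.ckpt")  # raises ValueError
```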
@@ -1056,308 +1071,334 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         >>> pipeline.scheduler = scheduler
         ```
         """
-        cache_dir = kwargs.pop("cache_dir", None)
-        resume_download = kwargs.pop("resume_download", False)
-        force_download = kwargs.pop("force_download", False)
-        proxies = kwargs.pop("proxies", None)
-        local_files_only = kwargs.pop("local_files_only", None)
-        token = kwargs.pop("token", None)
-        revision = kwargs.pop("revision", None)
-        from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", None)
-        custom_pipeline = kwargs.pop("custom_pipeline", None)
-        custom_revision = kwargs.pop("custom_revision", None)
-        provider = kwargs.pop("provider", None)
-        sess_options = kwargs.pop("sess_options", None)
-        device_map = kwargs.pop("device_map", None)
-        max_memory = kwargs.pop("max_memory", None)
-        offload_folder = kwargs.pop("offload_folder", None)
-        offload_state_dict = kwargs.pop("offload_state_dict", False)
-        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
-        variant = kwargs.pop("variant", None)
-        use_safetensors = kwargs.pop("use_safetensors", None)
-        use_onnx = kwargs.pop("use_onnx", None)
-        load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
-
-        if low_cpu_mem_usage and not is_accelerate_available():
-            low_cpu_mem_usage = False
-            logger.warning(
-                "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
-                " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
-                " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
-                " install accelerate\n```\n."
-            )
-
-        if device_map is not None and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError(
-                "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                " `device_map=None`."
-            )
-
-        if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
-            raise NotImplementedError(
-                "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
-                " `low_cpu_mem_usage=False`."
-            )
-
-        if low_cpu_mem_usage is False and device_map is not None:
-            raise ValueError(
-                f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
-                " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
-            )
-
-        # 1. Download the checkpoints and configs
-        # use snapshot download here to get it working from from_pretrained
-        if not os.path.isdir(pretrained_model_name_or_path):
-            if pretrained_model_name_or_path.count("/") > 1:
-                raise ValueError(
-                    f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
-                    " is neither a valid local path nor a valid repo id. Please check the parameter."
-                )
-            cached_folder = cls.download(
-                pretrained_model_name_or_path,
-                cache_dir=cache_dir,
-                resume_download=resume_download,
-                force_download=force_download,
-                proxies=proxies,
-                local_files_only=local_files_only,
-                token=token,
-                revision=revision,
-                from_flax=from_flax,
-                use_safetensors=use_safetensors,
-                use_onnx=use_onnx,
-                custom_pipeline=custom_pipeline,
-                custom_revision=custom_revision,
-                variant=variant,
-                load_connected_pipeline=load_connected_pipeline,
-                **kwargs,
-            )
-        else:
-            cached_folder = pretrained_model_name_or_path
-
-        config_dict = cls.load_config(cached_folder)
-
-        # pop out "_ignore_files" as it is only needed for download
-        config_dict.pop("_ignore_files", None)
-
-        # 2. Define which model components should load variants
-        # We retrieve the information by matching whether variant
-        # model checkpoints exist in the subfolders
-        model_variants = {}
-        if variant is not None:
-            for folder in os.listdir(cached_folder):
-                folder_path = os.path.join(cached_folder, folder)
-                is_folder = os.path.isdir(folder_path) and folder in config_dict
-                variant_exists = is_folder and any(
-                    p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
-                )
-                if variant_exists:
-                    model_variants[folder] = variant
-
-        # 3. Load the pipeline class, if using custom module then load it from the hub
-        # if we load from explicit class, let's use it
-        custom_class_name = None
-        if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
-            custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
-        elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
-            os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
-        ):
-            custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
-            custom_class_name = config_dict["_class_name"][1]
-
-        pipeline_class = _get_pipeline_class(
-            cls,
-            config_dict,
-            load_connected_pipeline=load_connected_pipeline,
-            custom_pipeline=custom_pipeline,
-            class_name=custom_class_name,
-            cache_dir=cache_dir,
-            revision=custom_revision,
-        )
-
-        # DEPRECATED: To be removed in 1.0.0
-        if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
-            version.parse(config_dict["_diffusers_version"]).base_version
-        ) <= version.parse("0.5.1"):
-            from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
-
-            pipeline_class = StableDiffusionInpaintPipelineLegacy
-
-            deprecation_message = (
-                "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
-                f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
-                " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
-                " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
-                f" checkpoint {pretrained_model_name_or_path} to the format of"
-                " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
-                " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
-            )
-            deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
-
-        # 4. Define expected modules given pipeline signature
-        # and define non-None initialized modules (=`init_kwargs`)
-
-        # some modules can be passed directly to the init
-        # in this case they are already instantiated in `kwargs`
-        # extract them here
-        expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
-        passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
-        passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
-
-        init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
-
-        # define init kwargs and make sure that optional component modules are filtered out
-        init_kwargs = {
-            k: init_dict.pop(k)
-            for k in optional_kwargs
-            if k in init_dict and k not in pipeline_class._optional_components
-        }
-        init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
-
-        # remove `null` components
-        def load_module(name, value):
-            if value[0] is None:
-                return False
-            if name in passed_class_obj and passed_class_obj[name] is None:
-                return False
-            return True
-
-        init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
-
-        # Special case: safety_checker must be loaded separately when using `from_flax`
-        if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
-            raise NotImplementedError(
-                "The safety checker cannot be automatically loaded when loading weights `from_flax`."
-                " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
-                " separately if you need it."
-            )
-
-        # 5. Throw nice warnings / errors for fast accelerate loading
-        if len(unused_kwargs) > 0:
-            logger.warning(
-                f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
-            )
-
-        # import it here to avoid circular import
-        from diffusers import pipelines
-
-        # 6. Load each module in the pipeline
-        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
-            # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
-            class_name = class_name[4:] if class_name.startswith("Flax") else class_name
-
-            # 6.2 Define all importable classes
-            is_pipeline_module = hasattr(pipelines, library_name)
-            importable_classes = ALL_IMPORTABLE_CLASSES
-            loaded_sub_model = None
-
-            # 6.3 Use passed sub model or load class_name from library_name
-            if name in passed_class_obj:
-                # if the model is in a pipeline module, then we load it from the pipeline
-                # check that passed_class_obj has correct parent class
-                maybe_raise_or_warn(
-                    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
-                )
-
-                loaded_sub_model = passed_class_obj[name]
-            else:
-                # load sub model
-                loaded_sub_model = load_sub_model(
-                    library_name=library_name,
-                    class_name=class_name,
-                    importable_classes=importable_classes,
-                    pipelines=pipelines,
-                    is_pipeline_module=is_pipeline_module,
-                    pipeline_class=pipeline_class,
-                    torch_dtype=torch_dtype,
-                    provider=provider,
-                    sess_options=sess_options,
-                    device_map=device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    model_variants=model_variants,
-                    name=name,
-                    from_flax=from_flax,
-                    variant=variant,
-                    low_cpu_mem_usage=low_cpu_mem_usage,
-                    cached_folder=cached_folder,
-                )
-                logger.info(
-                    f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
-                )
-
-            init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
-
-        if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
-            modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
-            connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
-            load_kwargs = {
-                "cache_dir": cache_dir,
-                "resume_download": resume_download,
-                "force_download": force_download,
-                "proxies": proxies,
-                "local_files_only": local_files_only,
-                "token": token,
-                "revision": revision,
-                "torch_dtype": torch_dtype,
-                "custom_pipeline": custom_pipeline,
-                "custom_revision": custom_revision,
-                "provider": provider,
-                "sess_options": sess_options,
-                "device_map": device_map,
-                "max_memory": max_memory,
-                "offload_folder": offload_folder,
-                "offload_state_dict": offload_state_dict,
-                "low_cpu_mem_usage": low_cpu_mem_usage,
-                "variant": variant,
-                "use_safetensors": use_safetensors,
-            }
-
-            def get_connected_passed_kwargs(prefix):
-                connected_passed_class_obj = {
-                    k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
-                }
-                connected_passed_pipe_kwargs = {
-                    k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
-                }
-
-                connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
-                return connected_passed_kwargs
-
-            connected_pipes = {
-                prefix: DiffusionPipeline.from_pretrained(
-                    repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
-                )
-                for prefix, repo_id in connected_pipes.items()
-                if repo_id is not None
-            }
-
-            for prefix, connected_pipe in connected_pipes.items():
-                # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
-                init_kwargs.update(
-                    {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
-                )
-
-        # 7. Potentially add passed objects if expected
-        missing_modules = set(expected_modules) - set(init_kwargs.keys())
-        passed_modules = list(passed_class_obj.keys())
-        optional_modules = pipeline_class._optional_components
-        if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
-            for module in missing_modules:
-                init_kwargs[module] = passed_class_obj.get(module, None)
-        elif len(missing_modules) > 0:
-            passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
-            raise ValueError(
-                f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
-            )
-
-        # 8. Instantiate the pipeline
-        model = pipeline_class(**init_kwargs)
-
-        # 9. Save where the model was instantiated from
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
-        return model
+        if is_single_file_checkpoint(pretrained_model_name_or_path):
+            if cls.__name__ not in SINGLE_FILE_LOADABLE_CLASSES:
+                raise ValueError(
+                    f"{cls.__name__} is not supported. Supported classes are: {' '.join(list(SINGLE_FILE_LOADABLE_CLASSES))}."
+                )
+            logger.info("Single file checkpoint detected...")
+            model = cls.from_single_file(pretrained_model_name_or_path, **kwargs)
+            return model
+        else:
+            cache_dir = kwargs.pop("cache_dir", None)
+            resume_download = kwargs.pop("resume_download", False)
+            force_download = kwargs.pop("force_download", False)
+            proxies = kwargs.pop("proxies", None)
+            local_files_only = kwargs.pop("local_files_only", None)
+            token = kwargs.pop("token", None)
+            revision = kwargs.pop("revision", None)
+            from_flax = kwargs.pop("from_flax", False)
+            torch_dtype = kwargs.pop("torch_dtype", None)
+            custom_pipeline = kwargs.pop("custom_pipeline", None)
+            custom_revision = kwargs.pop("custom_revision", None)
+            provider = kwargs.pop("provider", None)
+            sess_options = kwargs.pop("sess_options", None)
+            device_map = kwargs.pop("device_map", None)
+            max_memory = kwargs.pop("max_memory", None)
+            offload_folder = kwargs.pop("offload_folder", None)
+            offload_state_dict = kwargs.pop("offload_state_dict", False)
+            low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+            variant = kwargs.pop("variant", None)
+            use_safetensors = kwargs.pop("use_safetensors", None)
+            use_onnx = kwargs.pop("use_onnx", None)
+            load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+
+            if low_cpu_mem_usage and not is_accelerate_available():
+                low_cpu_mem_usage = False
+                logger.warning(
+                    "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                    " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                    " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                    " install accelerate\n```\n."
+                )
+
+            if device_map is not None and not is_torch_version(">=", "1.9.0"):
+                raise NotImplementedError(
+                    "Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                    " `device_map=None`."
+                )
+
+            if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
+                raise NotImplementedError(
+                    "Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
+                    " `low_cpu_mem_usage=False`."
+                )
+
+            if low_cpu_mem_usage is False and device_map is not None:
+                raise ValueError(
+                    f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
+                    " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
+                )
+
+            # 1. Download the checkpoints and configs
+            # use snapshot download here to get it working from from_pretrained
+            if not os.path.isdir(pretrained_model_name_or_path):
+                if pretrained_model_name_or_path.count("/") > 1:
+                    raise ValueError(
+                        f'The provided pretrained_model_name_or_path "{pretrained_model_name_or_path}"'
+                        " is neither a valid local path nor a valid repo id. Please check the parameter."
+                    )
+                cached_folder = cls.download(
+                    pretrained_model_name_or_path,
+                    cache_dir=cache_dir,
+                    resume_download=resume_download,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    from_flax=from_flax,
+                    use_safetensors=use_safetensors,
+                    use_onnx=use_onnx,
+                    custom_pipeline=custom_pipeline,
+                    custom_revision=custom_revision,
+                    variant=variant,
+                    load_connected_pipeline=load_connected_pipeline,
+                    **kwargs,
+                )
+            else:
+                cached_folder = pretrained_model_name_or_path
+
+            config_dict = cls.load_config(cached_folder)
+
+            # pop out "_ignore_files" as it is only needed for download
+            config_dict.pop("_ignore_files", None)
+
+            # 2. Define which model components should load variants
+            # We retrieve the information by matching whether variant
+            # model checkpoints exist in the subfolders
+            model_variants = {}
+            if variant is not None:
+                for folder in os.listdir(cached_folder):
+                    folder_path = os.path.join(cached_folder, folder)
+                    is_folder = os.path.isdir(folder_path) and folder in config_dict
+                    variant_exists = is_folder and any(
+                        p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
+                    )
+                    if variant_exists:
+                        model_variants[folder] = variant
+
+            # 3. Load the pipeline class, if using custom module then load it from the hub
+            # if we load from explicit class, let's use it
+            custom_class_name = None
+            if os.path.isfile(os.path.join(cached_folder, f"{custom_pipeline}.py")):
+                custom_pipeline = os.path.join(cached_folder, f"{custom_pipeline}.py")
+            elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
+                os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
+            ):
+                custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
+                custom_class_name = config_dict["_class_name"][1]
+
+            pipeline_class = _get_pipeline_class(
+                cls,
+                config_dict,
+                load_connected_pipeline=load_connected_pipeline,
+                custom_pipeline=custom_pipeline,
+                class_name=custom_class_name,
+                cache_dir=cache_dir,
+                revision=custom_revision,
+            )
+
+            # DEPRECATED: To be removed in 1.0.0
+            if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
+                version.parse(config_dict["_diffusers_version"]).base_version
+            ) <= version.parse("0.5.1"):
+                from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
+
+                pipeline_class = StableDiffusionInpaintPipelineLegacy
+
+                deprecation_message = (
+                    "You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
+                    f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
+                    " better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
+                    " checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
+                    f" checkpoint {pretrained_model_name_or_path} to the format of"
+                    " https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
+                    " the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
+                )
+                deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
+
+            # 4. Define expected modules given pipeline signature
+            # and define non-None initialized modules (=`init_kwargs`)
+
+            # some modules can be passed directly to the init
+            # in this case they are already instantiated in `kwargs`
+            # extract them here
+            expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+            passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
+            passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
+
+            init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
+
+            # define init kwargs and make sure that optional component modules are filtered out
+            init_kwargs = {
+                k: init_dict.pop(k)
+                for k in optional_kwargs
+                if k in init_dict and k not in pipeline_class._optional_components
+            }
+            init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
+
+            # remove `null` components
+            def load_module(name, value):
+                if value[0] is None:
+                    return False
+                if name in passed_class_obj and passed_class_obj[name] is None:
+                    return False
+                return True
+
+            init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+
+            # Special case: safety_checker must be loaded separately when using `from_flax`
+            if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
+                raise NotImplementedError(
+                    "The safety checker cannot be automatically loaded when loading weights `from_flax`."
+                    " Please, pass `safety_checker=None` to `from_pretrained`, and load the safety checker"
+                    " separately if you need it."
+                )
+
+            # 5. Throw nice warnings / errors for fast accelerate loading
+            if len(unused_kwargs) > 0:
+                logger.warning(
+                    f"Keyword arguments {unused_kwargs} are not expected by {pipeline_class.__name__} and will be ignored."
+                )
+
+            # import it here to avoid circular import
+            from diffusers import pipelines
+
+            # 6. Load each module in the pipeline
+            for name, (library_name, class_name) in logging.tqdm(
+                init_dict.items(), desc="Loading pipeline components..."
+            ):
+                # 6.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
+                class_name = class_name[4:] if class_name.startswith("Flax") else class_name
+
+                # 6.2 Define all importable classes
+                is_pipeline_module = hasattr(pipelines, library_name)
+                importable_classes = ALL_IMPORTABLE_CLASSES
+                loaded_sub_model = None
+
+                # 6.3 Use passed sub model or load class_name from library_name
+                if name in passed_class_obj:
+                    # if the model is in a pipeline module, then we load it from the pipeline
+                    # check that passed_class_obj has correct parent class
+                    maybe_raise_or_warn(
+                        library_name,
+                        library,
+                        class_name,
+                        importable_classes,
+                        passed_class_obj,
+                        name,
+                        is_pipeline_module,
+                    )
+
+                    loaded_sub_model = passed_class_obj[name]
+                else:
+                    # load sub model
+                    loaded_sub_model = load_sub_model(
+                        library_name=library_name,
+                        class_name=class_name,
+                        importable_classes=importable_classes,
+                        pipelines=pipelines,
+                        is_pipeline_module=is_pipeline_module,
+                        pipeline_class=pipeline_class,
+                        torch_dtype=torch_dtype,
+                        provider=provider,
+                        sess_options=sess_options,
+                        device_map=device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        model_variants=model_variants,
+                        name=name,
+                        from_flax=from_flax,
+                        variant=variant,
+                        low_cpu_mem_usage=low_cpu_mem_usage,
+                        cached_folder=cached_folder,
+                    )
+                    logger.info(
+                        f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
+                    )
+
+                init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
+
+            if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
+                modelcard = ModelCard.load(os.path.join(cached_folder, "README.md"))
+                connected_pipes = {
+                    prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS
+                }
+                load_kwargs = {
+                    "cache_dir": cache_dir,
+                    "resume_download": resume_download,
+                    "force_download": force_download,
+                    "proxies": proxies,
+                    "local_files_only": local_files_only,
+                    "token": token,
+                    "revision": revision,
+                    "torch_dtype": torch_dtype,
+                    "custom_pipeline": custom_pipeline,
+                    "custom_revision": custom_revision,
+                    "provider": provider,
+                    "sess_options": sess_options,
+                    "device_map": device_map,
+                    "max_memory": max_memory,
+                    "offload_folder": offload_folder,
+                    "offload_state_dict": offload_state_dict,
+                    "low_cpu_mem_usage": low_cpu_mem_usage,
+                    "variant": variant,
+                    "use_safetensors": use_safetensors,
+                }
+
+                def get_connected_passed_kwargs(prefix):
+                    connected_passed_class_obj = {
+                        k.replace(f"{prefix}_", ""): w
+                        for k, w in passed_class_obj.items()
+                        if k.split("_")[0] == prefix
+                    }
+                    connected_passed_pipe_kwargs = {
+                        k.replace(f"{prefix}_", ""): w
+                        for k, w in passed_pipe_kwargs.items()
+                        if k.split("_")[0] == prefix
+                    }
+
+                    connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
+                    return connected_passed_kwargs
+
+                connected_pipes = {
+                    prefix: DiffusionPipeline.from_pretrained(
+                        repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
+                    )
+                    for prefix, repo_id in connected_pipes.items()
+                    if repo_id is not None
+                }
+
+                for prefix, connected_pipe in connected_pipes.items():
+                    # add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
+                    init_kwargs.update(
+                        {"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
+                    )
+
+            # 7. Potentially add passed objects if expected
+            missing_modules = set(expected_modules) - set(init_kwargs.keys())
+            passed_modules = list(passed_class_obj.keys())
+            optional_modules = pipeline_class._optional_components
+            if len(missing_modules) > 0 and missing_modules <= set(passed_modules + optional_modules):
+                for module in missing_modules:
+                    init_kwargs[module] = passed_class_obj.get(module, None)
+            elif len(missing_modules) > 0:
+                passed_modules = set(list(init_kwargs.keys()) + list(passed_class_obj.keys())) - optional_kwargs
+                raise ValueError(
+                    f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
+                )
+
+            # 8. Instantiate the pipeline
+            model = pipeline_class(**init_kwargs)
+
+            # 9. Save where the model was instantiated from
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+            return model
 
     @property
     def name_or_path(self) -> str:
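The practical effect of the pipeline hunk above, sketched with a hypothetical checkpoint filename — before this change only the explicit entry point accepted single files:

```python
from diffusers import StableDiffusionPipeline

# Previously, single-file checkpoints required the dedicated entry point:
pipe = StableDiffusionPipeline.from_single_file("./v1-5-pruned-emaonly.safetensors")

# With this change, `from_pretrained` detects the single-file checkpoint
# itself and forwards the remaining kwargs to `from_single_file`:
pipe = StableDiffusionPipeline.from_pretrained("./v1-5-pruned-emaonly.safetensors")
```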
@@ -19,6 +19,7 @@ from packaging import version
 
 from .. import __version__
 from .constants import (
+    _ACCEPTED_SINGLE_FILE_FORMATS,
     CONFIG_NAME,
     DEPRECATED_REVISION_ARGS,
     DIFFUSERS_DYNAMIC_MODULE_NAME,
@@ -83,7 +84,7 @@ from .import_utils import (
     is_xformers_available,
     requires_backends,
 )
-from .loading_utils import load_image
+from .loading_utils import is_single_file_checkpoint, load_image
 from .logging import get_logger
 from .outputs import BaseOutput
 from .peft_utils import (
@@ -37,6 +37,7 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://hugging
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(HF_HOME, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
+_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")
 
 # Below should be `True` if the current version of `peft` and `transformers` are compatible with
 # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
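The tuple form matters because `str.endswith` accepts a tuple of suffixes, which is what keeps the extension check in `is_single_file_checkpoint` to a single call:

```python
_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")

# `str.endswith` tries each suffix in the tuple:
print("model.safetensors".endswith(_ACCEPTED_SINGLE_FILE_FORMATS))  # True
print("stabilityai/sd-vae-ft-mse".endswith(_ACCEPTED_SINGLE_FILE_FORMATS))  # False
```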
@@ -1,10 +1,28 @@
 import os
 from typing import Callable, Union
+from urllib.parse import urlparse
 
 import PIL.Image
 import PIL.ImageOps
 import requests
 
+from ..utils.constants import _ACCEPTED_SINGLE_FILE_FORMATS
+
+
+def is_single_file_checkpoint(filepath):
+    def is_valid_url(url):
+        result = urlparse(url)
+        if result.scheme and result.netloc:
+            return True
+
+    filepath = str(filepath)
+    if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
+        if is_valid_url(filepath):
+            return True
+        elif os.path.isfile(filepath):
+            return True
+    return False
+
+
 def load_image(
     image: Union[str, PIL.Image.Image], convert_method: Callable[[PIL.Image.Image], PIL.Image.Image] = None
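A self-contained copy of the new helper, usable for experimenting with how inputs are classified (the example inputs are illustrative):

```python
import os
from urllib.parse import urlparse

_ACCEPTED_SINGLE_FILE_FORMATS = (".safetensors", ".ckpt", ".bin", ".pth", ".pt")


def is_single_file_checkpoint(filepath):
    # URLs count as "valid" only when they carry both a scheme and a host.
    def is_valid_url(url):
        result = urlparse(url)
        if result.scheme and result.netloc:
            return True

    filepath = str(filepath)
    if filepath.endswith(_ACCEPTED_SINGLE_FILE_FORMATS):
        if is_valid_url(filepath):
            return True
        elif os.path.isfile(filepath):
            return True
    return False


print(is_single_file_checkpoint("https://example.com/model.safetensors"))  # True
print(is_single_file_checkpoint("stabilityai/sd-vae-ft-mse"))  # False: no accepted extension
print(is_single_file_checkpoint("missing.ckpt"))  # False unless the file exists locally
```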