mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 13:34:27 +08:00
Compare commits
1 Commits
fix-timeou
...
muellerzr-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea44f69be3 |
@@ -165,20 +165,10 @@ def load_model_dict_into_meta(
|
||||
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
state_dict = state_dict.copy()
|
||||
# state_dict = state_dict.copy()
|
||||
error_msgs = []
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: torch.nn.Module, prefix: str = ""):
|
||||
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(model_to_load)
|
||||
model_to_load.load_state_dict(state_dict, assign=True, strict=False)
|
||||
|
||||
return error_msgs
|
||||
|
||||
|
||||
@@ -734,13 +734,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
commit_hash=commit_hash,
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage:
|
||||
# Instantiate model with empty weights
|
||||
with accelerate.init_empty_weights():
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
# Instantiate model with empty weights for faster init
|
||||
with accelerate.init_empty_weights():
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
# If the device_map is None, we load everything on the CPU lazily
|
||||
if device_map is None:
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
||||
if device_map is None and not is_sharded:
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
model_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"mismatched_keys": mismatched_keys,
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
|
||||
else:
|
||||
if not is_sharded:
|
||||
param_device = "cpu"
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
@@ -832,26 +850,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
"mismatched_keys": [],
|
||||
"error_msgs": [],
|
||||
}
|
||||
else:
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file, variant=variant)
|
||||
model._convert_deprecated_attention_blocks(state_dict)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
model_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"mismatched_keys": mismatched_keys,
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
|
||||
|
||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user