mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-14 16:34:27 +08:00
Compare commits
1 Commits
enable-cp-
...
muellerzr-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ea44f69be3 |
@@ -165,20 +165,10 @@ def load_model_dict_into_meta(
|
|||||||
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
|
def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
|
||||||
# Convert old format to new format if needed from a PyTorch state_dict
|
# Convert old format to new format if needed from a PyTorch state_dict
|
||||||
# copy state_dict so _load_from_state_dict can modify it
|
# copy state_dict so _load_from_state_dict can modify it
|
||||||
state_dict = state_dict.copy()
|
# state_dict = state_dict.copy()
|
||||||
error_msgs = []
|
error_msgs = []
|
||||||
|
|
||||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
model_to_load.load_state_dict(state_dict, assign=True, strict=False)
|
||||||
# so we need to apply the function recursively.
|
|
||||||
def load(module: torch.nn.Module, prefix: str = ""):
|
|
||||||
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
|
||||||
module._load_from_state_dict(*args)
|
|
||||||
|
|
||||||
for name, child in module._modules.items():
|
|
||||||
if child is not None:
|
|
||||||
load(child, prefix + name + ".")
|
|
||||||
|
|
||||||
load(model_to_load)
|
|
||||||
|
|
||||||
return error_msgs
|
return error_msgs
|
||||||
|
|
||||||
|
|||||||
@@ -734,13 +734,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
commit_hash=commit_hash,
|
commit_hash=commit_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
if low_cpu_mem_usage:
|
# Instantiate model with empty weights for faster init
|
||||||
# Instantiate model with empty weights
|
with accelerate.init_empty_weights():
|
||||||
with accelerate.init_empty_weights():
|
model = cls.from_config(config, **unused_kwargs)
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
# If the device_map is None, we load everything on the CPU lazily
|
||||||
|
if device_map is None:
|
||||||
|
state_dict = load_state_dict(model_file, variant=variant)
|
||||||
|
model._convert_deprecated_attention_blocks(state_dict)
|
||||||
|
|
||||||
# if device_map is None, load the state dict and move the params from meta device to the cpu
|
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||||
if device_map is None and not is_sharded:
|
model,
|
||||||
|
state_dict,
|
||||||
|
model_file,
|
||||||
|
pretrained_model_name_or_path,
|
||||||
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
|
)
|
||||||
|
|
||||||
|
loading_info = {
|
||||||
|
"missing_keys": missing_keys,
|
||||||
|
"unexpected_keys": unexpected_keys,
|
||||||
|
"mismatched_keys": mismatched_keys,
|
||||||
|
"error_msgs": error_msgs,
|
||||||
|
}
|
||||||
|
|
||||||
|
else:
|
||||||
|
if not is_sharded:
|
||||||
param_device = "cpu"
|
param_device = "cpu"
|
||||||
state_dict = load_state_dict(model_file, variant=variant)
|
state_dict = load_state_dict(model_file, variant=variant)
|
||||||
model._convert_deprecated_attention_blocks(state_dict)
|
model._convert_deprecated_attention_blocks(state_dict)
|
||||||
@@ -832,26 +850,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
"mismatched_keys": [],
|
"mismatched_keys": [],
|
||||||
"error_msgs": [],
|
"error_msgs": [],
|
||||||
}
|
}
|
||||||
else:
|
|
||||||
model = cls.from_config(config, **unused_kwargs)
|
|
||||||
|
|
||||||
state_dict = load_state_dict(model_file, variant=variant)
|
|
||||||
model._convert_deprecated_attention_blocks(state_dict)
|
|
||||||
|
|
||||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
|
||||||
model,
|
|
||||||
state_dict,
|
|
||||||
model_file,
|
|
||||||
pretrained_model_name_or_path,
|
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
|
||||||
)
|
|
||||||
|
|
||||||
loading_info = {
|
|
||||||
"missing_keys": missing_keys,
|
|
||||||
"unexpected_keys": unexpected_keys,
|
|
||||||
"mismatched_keys": mismatched_keys,
|
|
||||||
"error_msgs": error_msgs,
|
|
||||||
}
|
|
||||||
|
|
||||||
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
Reference in New Issue
Block a user