Compare commits


1 Commit

Author SHA1 Message Date
ea44f69be3 checkpoint 2024-07-10 07:31:51 -04:00
2 changed files with 27 additions and 38 deletions


@@ -165,20 +165,10 @@ def load_model_dict_into_meta(
 def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
     # Convert old format to new format if needed from a PyTorch state_dict
     # copy state_dict so _load_from_state_dict can modify it
-    state_dict = state_dict.copy()
+    # state_dict = state_dict.copy()
     error_msgs = []
 
-    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-    # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix: str = ""):
-        args = (state_dict, prefix, {}, True, [], [], error_msgs)
-        module._load_from_state_dict(*args)
-
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, prefix + name + ".")
-
-    load(model_to_load)
+    model_to_load.load_state_dict(state_dict, assign=True, strict=False)
 
     return error_msgs
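
The change in this first file replaces the recursive `_load_from_state_dict` walk with a single call to `torch.nn.Module.load_state_dict(..., assign=True)`, which swaps the checkpoint tensors into the module instead of copying into existing parameter storage. A minimal sketch of why `assign=True` matters for models built on the meta device (the `TinyModel` below is illustrative, not part of the commit; the `assign` keyword requires PyTorch 2.1+):

import torch

class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

# Parameters created on the meta device hold no storage, so the default
# in-place copy_() performed by load_state_dict cannot write into them.
with torch.device("meta"):
    model = TinyModel()

# A state dict with real tensors, e.g. from a normally initialized model.
state_dict = TinyModel().state_dict()

# assign=True replaces each meta parameter with the checkpoint tensor
# instead of copying, materializing the model in one pass.
result = model.load_state_dict(state_dict, assign=True, strict=False)
print(result.missing_keys, result.unexpected_keys)  # [] []
print(model.linear.weight.device)  # cpu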


@@ -734,13 +734,31 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 commit_hash=commit_hash,
             )
 
         if low_cpu_mem_usage:
-            # Instantiate model with empty weights
-            with accelerate.init_empty_weights():
-                model = cls.from_config(config, **unused_kwargs)
+            # Instantiate model with empty weights for faster init
+            with accelerate.init_empty_weights():
+                model = cls.from_config(config, **unused_kwargs)
 
-            # If the device_map is None, we load everything on the CPU lazily
-            if device_map is None:
+            state_dict = load_state_dict(model_file, variant=variant)
+            model._convert_deprecated_attention_blocks(state_dict)
+
+            # if device_map is None, load the state dict and move the params from meta device to the cpu
+            if device_map is None and not is_sharded:
+                model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+                    model,
+                    state_dict,
+                    model_file,
+                    pretrained_model_name_or_path,
+                    ignore_mismatched_sizes=ignore_mismatched_sizes,
+                )
+                loading_info = {
+                    "missing_keys": missing_keys,
+                    "unexpected_keys": unexpected_keys,
+                    "mismatched_keys": mismatched_keys,
+                    "error_msgs": error_msgs,
+                }
+            else:
+                if not is_sharded:
                     param_device = "cpu"
                     state_dict = load_state_dict(model_file, variant=variant)
                     model._convert_deprecated_attention_blocks(state_dict)
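
Together with the `accelerate.init_empty_weights()` context above, the loading pattern this checkpoint moves toward is roughly the following sketch; `MyModel` and `checkpoint.pt` are hypothetical stand-ins for `cls.from_config(...)` and `model_file`, while `init_empty_weights` and `assign=True` are real accelerate/PyTorch APIs:

import torch
from accelerate import init_empty_weights

# Hypothetical stand-in for the model class built from a config.
class MyModel(torch.nn.Module):
    def __init__(self, hidden: int = 8):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

# 1. Build the model skeleton without allocating parameter memory.
with init_empty_weights():
    model = MyModel()

# 2. Load the checkpoint on CPU and materialize the meta parameters
#    directly from it, instead of copying into the module tree.
state_dict = torch.load("checkpoint.pt", map_location="cpu")  # hypothetical path
model.load_state_dict(state_dict, assign=True, strict=False)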
@@ -832,26 +850,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     "mismatched_keys": [],
                     "error_msgs": [],
                 }
-        else:
-            model = cls.from_config(config, **unused_kwargs)
-
-            state_dict = load_state_dict(model_file, variant=variant)
-            model._convert_deprecated_attention_blocks(state_dict)
-
-            model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
-                model,
-                state_dict,
-                model_file,
-                pretrained_model_name_or_path,
-                ignore_mismatched_sizes=ignore_mismatched_sizes,
-            )
-
-            loading_info = {
-                "missing_keys": missing_keys,
-                "unexpected_keys": unexpected_keys,
-                "mismatched_keys": mismatched_keys,
-                "error_msgs": error_msgs,
-            }
 
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             raise ValueError(