Compare commits

...

2 Commits

Author SHA1 Message Date
Dhruv Nair
9de5739fb4 Merge branch 'main' into disable-mmap-pipeline 2025-12-05 19:06:54 +05:30
DN6
287331d9c1 update 2025-11-17 23:29:18 +05:30
2 changed files with 7 additions and 1 deletion

View File

@@ -354,8 +354,9 @@ def _load_shard_file(
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
+ disable_mmap=False,
):
- state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+ state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
@@ -401,6 +402,7 @@ def _load_shard_files_with_threadpool(
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
+ disable_mmap=False,
):
# Do not spawn anymore workers than you need
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
@@ -427,6 +429,7 @@ def _load_shard_files_with_threadpool(
state_dict_folder=state_dict_folder,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
+ disable_mmap=disable_mmap,
)
with ThreadPoolExecutor(max_workers=num_workers) as executor:

View File

@@ -1309,6 +1309,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
keep_in_fp32_modules=keep_in_fp32_modules,
dduf_entries=dduf_entries,
is_parallel_loading_enabled=is_parallel_loading_enabled,
+ disable_mmap=disable_mmap,
)
loading_info = {
"missing_keys": missing_keys,
@@ -1595,6 +1596,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
offload_folder: Optional[Union[str, os.PathLike]] = None,
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
is_parallel_loading_enabled: Optional[bool] = False,
+ disable_mmap: bool = False,
):
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
@@ -1663,6 +1665,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
state_dict_folder=state_dict_folder,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
+ disable_mmap=disable_mmap,
)
if is_parallel_loading_enabled: