mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-22 12:24:39 +08:00
Compare commits
2 Commits
cp-fixes-a
...
disable-mm
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9de5739fb4 | ||
|
|
287331d9c1 |
@@ -354,8 +354,9 @@ def _load_shard_file(
|
|||||||
state_dict_folder=None,
|
state_dict_folder=None,
|
||||||
ignore_mismatched_sizes=False,
|
ignore_mismatched_sizes=False,
|
||||||
low_cpu_mem_usage=False,
|
low_cpu_mem_usage=False,
|
||||||
|
disable_mmap=False,
|
||||||
):
|
):
|
||||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
|
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
|
||||||
mismatched_keys = _find_mismatched_keys(
|
mismatched_keys = _find_mismatched_keys(
|
||||||
state_dict,
|
state_dict,
|
||||||
model_state_dict,
|
model_state_dict,
|
||||||
@@ -401,6 +402,7 @@ def _load_shard_files_with_threadpool(
|
|||||||
state_dict_folder=None,
|
state_dict_folder=None,
|
||||||
ignore_mismatched_sizes=False,
|
ignore_mismatched_sizes=False,
|
||||||
low_cpu_mem_usage=False,
|
low_cpu_mem_usage=False,
|
||||||
|
disable_mmap=False,
|
||||||
):
|
):
|
||||||
# Do not spawn anymore workers than you need
|
# Do not spawn anymore workers than you need
|
||||||
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
|
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
|
||||||
@@ -427,6 +429,7 @@ def _load_shard_files_with_threadpool(
|
|||||||
state_dict_folder=state_dict_folder,
|
state_dict_folder=state_dict_folder,
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
|
disable_mmap=disable_mmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||||
|
|||||||
@@ -1309,6 +1309,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
keep_in_fp32_modules=keep_in_fp32_modules,
|
keep_in_fp32_modules=keep_in_fp32_modules,
|
||||||
dduf_entries=dduf_entries,
|
dduf_entries=dduf_entries,
|
||||||
is_parallel_loading_enabled=is_parallel_loading_enabled,
|
is_parallel_loading_enabled=is_parallel_loading_enabled,
|
||||||
|
disable_mmap=disable_mmap,
|
||||||
)
|
)
|
||||||
loading_info = {
|
loading_info = {
|
||||||
"missing_keys": missing_keys,
|
"missing_keys": missing_keys,
|
||||||
@@ -1595,6 +1596,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
||||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
||||||
is_parallel_loading_enabled: Optional[bool] = False,
|
is_parallel_loading_enabled: Optional[bool] = False,
|
||||||
|
disable_mmap: bool = False,
|
||||||
):
|
):
|
||||||
model_state_dict = model.state_dict()
|
model_state_dict = model.state_dict()
|
||||||
expected_keys = list(model_state_dict.keys())
|
expected_keys = list(model_state_dict.keys())
|
||||||
@@ -1663,6 +1665,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
|||||||
state_dict_folder=state_dict_folder,
|
state_dict_folder=state_dict_folder,
|
||||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||||
|
disable_mmap=disable_mmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_parallel_loading_enabled:
|
if is_parallel_loading_enabled:
|
||||||
|
|||||||
Reference in New Issue
Block a user