Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-19 19:04:49 +08:00)

Compare commits: 10 commits on branch `qwenimage-...progress-b`

| SHA1 |
|---|
| 4c096949af |
| 9cf5edd4db |
| 086a770174 |
| 91c73bb137 |
| 8193f38e19 |
| 2529fdf6af |
| 94d671d03a |
| 8df3fbcc14 |
| 97e380573c |
| f207108643 |
```diff
@@ -47,6 +47,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 
 
 logger = logging.get_logger(__name__)
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
         low_cpu_mem_usage=low_cpu_mem_usage,
     )
 
+    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
+    if not is_torch_dist_rank_zero():
+        tqdm_kwargs["disable"] = True
+
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+        with logging.tqdm(**tqdm_kwargs) as pbar:
            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
            for future in as_completed(futures):
                result = future.result()
```
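The pattern in the hunk above, build the tqdm kwargs first and only add `disable` off rank zero, is easy to reproduce outside diffusers. A minimal sketch follows, assuming `torch` and `tqdm` are installed; `_on_rank_zero`, `load_shards`, and `fake_shard_files` are stand-ins, not names from the diff:

```python
from tqdm.auto import tqdm

import torch.distributed as dist


def _on_rank_zero() -> bool:
    # Same idea as is_torch_dist_rank_zero(): default to True unless an
    # initialized process group reports a non-zero rank.
    return not (dist.is_available() and dist.is_initialized() and dist.get_rank() != 0)


def load_shards(shard_files):
    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
    if not _on_rank_zero():
        tqdm_kwargs["disable"] = True

    with tqdm(**tqdm_kwargs) as pbar:
        for shard_file in shard_files:
            # the actual shard loading would happen here
            pbar.update(1)


fake_shard_files = ["shard-00001.safetensors", "shard-00002.safetensors"]
load_shards(fake_shard_files)
```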
```diff
@@ -59,11 +59,8 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import (
-    PushToHubMixin,
-    load_or_create_model_card,
-    populate_model_card,
-)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache
 from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
@@ -1672,7 +1669,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             else:
                 shard_files = resolved_model_file
                 if len(resolved_model_file) > 1:
-                    shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+                    shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
+                    if not is_torch_dist_rank_zero():
+                        shard_tqdm_kwargs["disable"] = True
+                    shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
 
                 for shard_file in shard_files:
                     offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
```
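A rough way to observe the `ModelMixin` change is a two-process launch: with the patch applied, only rank 0 should render the "Loading checkpoint shards" bar when the checkpoint is split into several shards. A sketch, not part of the diff; the repo id and subfolder are placeholders, substitute a sharded checkpoint you actually have:

```python
# save as load_model_ddp.py, then run:  torchrun --nproc_per_node=2 load_model_ddp.py
import torch
import torch.distributed as dist

from diffusers import UNet2DConditionModel

dist.init_process_group(backend="gloo")  # gloo keeps the sketch runnable on CPU-only machines

model = UNet2DConditionModel.from_pretrained(
    "some-org/some-sharded-model",  # placeholder repo id, pick one with multiple shard files
    subfolder="unet",
    torch_dtype=torch.float32,
)
print(f"rank {dist.get_rank()}: loaded {sum(p.numel() for p in model.parameters()):,} parameters")

dist.destroy_process_group()
```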
```diff
@@ -67,6 +67,7 @@ from ..utils import (
     logging,
     numpy_to_pil,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
 
```
```diff
@@ -982,7 +983,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # 7. Load each module in the pipeline
         current_device_map = None
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
-        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+        logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
+        if not is_torch_dist_rank_zero():
+            logging_tqdm_kwargs["disable"] = True
+
+        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
             # 7.1 device_map shenanigans
             if final_device_map is not None:
                 if isinstance(final_device_map, dict) and len(final_device_map) > 0:
```
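The same three lines (build the kwargs, check the rank, set `disable`) now repeat at each call site. A hypothetical helper such as `_rank_aware_tqdm_kwargs` below could fold them together; it is not part of this diff, just a sketch of how the pattern could be centralized on top of the new `is_torch_dist_rank_zero` helper:

```python
from typing import Any, Dict, Optional

from diffusers.utils.distributed_utils import is_torch_dist_rank_zero


def _rank_aware_tqdm_kwargs(desc: str, total: Optional[int] = None) -> Dict[str, Any]:
    """Return tqdm kwargs that silence the progress bar on non-zero ranks."""
    kwargs: Dict[str, Any] = {"desc": desc}
    if total is not None:
        kwargs["total"] = total
    if not is_torch_dist_rank_zero():
        kwargs["disable"] = True
    return kwargs


# A call site would then shrink to something like:
#     with logging.tqdm(**_rank_aware_tqdm_kwargs("Loading checkpoint shards", total=len(shard_files))) as pbar:
#         ...
```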
```diff
@@ -1908,10 +1913,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
             )
 
+        progress_bar_config = dict(self._progress_bar_config)
+        if "disable" not in progress_bar_config:
+            progress_bar_config["disable"] = not is_torch_dist_rank_zero()
+
         if iterable is not None:
-            return tqdm(iterable, **self._progress_bar_config)
+            return tqdm(iterable, **progress_bar_config)
         elif total is not None:
-            return tqdm(total=total, **self._progress_bar_config)
+            return tqdm(total=total, **progress_bar_config)
         else:
             raise ValueError("Either `total` or `iterable` has to be defined.")
 
```
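Note that `progress_bar` only fills in `disable` when the caller has not already set it, so an explicit `set_progress_bar_config` still wins on every rank. A short sketch; the repo id is a placeholder:

```python
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("some-org/some-pipeline")  # placeholder repo id

# Force the denoising progress bar on for every rank, e.g. while debugging:
pipe.set_progress_bar_config(disable=False)

# Or silence it everywhere, including rank 0:
pipe.set_progress_bar_config(disable=True)
```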
`src/diffusers/utils/distributed_utils.py` (new file, 36 lines)

```diff
@@ -0,0 +1,36 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+def is_torch_dist_rank_zero() -> bool:
+    if torch is None:
+        return True
+
+    dist_module = getattr(torch, "distributed", None)
+    if dist_module is None or not dist_module.is_available():
+        return True
+
+    if not dist_module.is_initialized():
+        return True
+
+    try:
+        return dist_module.get_rank() == 0
+    except (RuntimeError, ValueError):
+        return True
```
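The helper deliberately falls back to `True` whenever torch, `torch.distributed`, or an initialized process group is missing, so single-process runs keep their progress bars and log output. A quick sketch to exercise it under `torchrun`, assuming torch with the gloo backend is available:

```python
# save as rank_check.py, then run:  torchrun --nproc_per_node=2 rank_check.py
import torch.distributed as dist

from diffusers.utils.distributed_utils import is_torch_dist_rank_zero

# Before a process group exists, every process reports True (the safe default).
print("before init:", is_torch_dist_rank_zero())

dist.init_process_group(backend="gloo")
print(f"rank {dist.get_rank()}: is_torch_dist_rank_zero() -> {is_torch_dist_rank_zero()}")

dist.destroy_process_group()
```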
```diff
@@ -32,6 +32,8 @@ from typing import Dict, Optional
 
 from tqdm import auto as tqdm_lib
 
+from .distributed_utils import is_torch_dist_rank_zero
+
 
 _lock = threading.Lock()
 _default_handler: Optional[logging.Handler] = None
@@ -47,6 +49,23 @@ log_levels = {
 _default_log_level = logging.WARNING
 
 _tqdm_active = True
+_rank_zero_filter = None
+
+
+class _RankZeroFilter(logging.Filter):
+    def filter(self, record):
+        # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
+        return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
+
+
+def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
+    global _rank_zero_filter
+
+    if _rank_zero_filter is None:
+        _rank_zero_filter = _RankZeroFilter()
+
+    if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
+        logger.addFilter(_rank_zero_filter)
 
 
 def _get_default_logging_level() -> int:
```
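`_RankZeroFilter` relies on standard `logging.Filter` semantics: a filter attached to a logger drops records before any handler sees them, and records at `DEBUG` or below are let through regardless of rank. A standalone illustration of that rule, with a `fake_rank` variable standing in for `torch.distributed.get_rank()` (this is not diffusers code):

```python
import logging
import sys

fake_rank = 1  # stand-in for torch.distributed.get_rank(); pretend this is a non-zero rank


class RankZeroFilter(logging.Filter):
    def filter(self, record):
        # Same rule as the diff: pass everything on rank zero, otherwise only DEBUG and below.
        return fake_rank == 0 or record.levelno <= logging.DEBUG


logger = logging.getLogger("demo")
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
logger.addFilter(RankZeroFilter())

logger.info("suppressed on non-zero ranks")   # filtered out because fake_rank != 0
logger.debug("still emitted on every rank")   # passes: levelno <= DEBUG
```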
```diff
@@ -90,6 +109,7 @@ def _configure_library_root_logger() -> None:
         library_root_logger.addHandler(_default_handler)
         library_root_logger.setLevel(_get_default_logging_level())
         library_root_logger.propagate = False
+        _ensure_rank_zero_filter(library_root_logger)
 
 
 def _reset_library_root_logger() -> None:
@@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
         name = _get_library_name()
 
     _configure_library_root_logger()
-    return logging.getLogger(name)
+    logger = logging.getLogger(name)
+    _ensure_rank_zero_filter(logger)
+    return logger
 
 
 def get_verbosity() -> int:
```
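Putting the logging changes together, a sketch of the intended end-to-end behaviour, assuming the patch is applied and torch with the gloo backend is available: warnings from the diffusers library logger surface only from rank 0, while debug records (once the verbosity allows them) still come from every rank.

```python
# save as log_demo.py, then run:  torchrun --nproc_per_node=2 log_demo.py
import torch.distributed as dist

from diffusers.utils import logging

dist.init_process_group(backend="gloo")
rank = dist.get_rank()

logging.set_verbosity_debug()  # let DEBUG records through the level check
logger = logging.get_logger()  # library root logger; the rank-zero filter is attached here

logger.warning(f"rank {rank}: shown by rank 0 only")
logger.debug(f"rank {rank}: shown by every rank")

dist.destroy_process_group()
```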