Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-19 10:54:34 +08:00)

Compare commits: qwenimage-… ... progress-b… (10 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 4c096949af |  |
|  | 9cf5edd4db |  |
|  | 086a770174 |  |
|  | 91c73bb137 |  |
|  | 8193f38e19 |  |
|  | 2529fdf6af |  |
|  | 94d671d03a |  |
|  | 8df3fbcc14 |  |
|  | 97e380573c |  |
|  | f207108643 |  |
```diff
@@ -47,6 +47,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero


 logger = logging.get_logger(__name__)
```
```diff
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
             low_cpu_mem_usage=low_cpu_mem_usage,
         )

+    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
+    if not is_torch_dist_rank_zero():
+        tqdm_kwargs["disable"] = True
+
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+        with logging.tqdm(**tqdm_kwargs) as pbar:
             futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
             for future in as_completed(futures):
                 result = future.result()
```
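The shard-loading change above builds the tqdm keyword arguments up front so the bar can be switched off on non-zero ranks without touching the loading loop itself. Below is a minimal, self-contained sketch of that pattern using plain `tqdm` rather than diffusers' `logging.tqdm` wrapper; `shard_files` and the `is_rank_zero` flag are placeholders, not the library's actual objects.

```python
# Minimal sketch of the disable-by-rank pattern, using plain tqdm instead of
# diffusers' logging.tqdm wrapper. `shard_files` / `is_rank_zero` are stand-ins.
from tqdm.auto import tqdm


def load_shards_with_bar(shard_files, is_rank_zero: bool):
    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
    if not is_rank_zero:
        # disable=True silences the bar entirely; the loop still runs as before.
        tqdm_kwargs["disable"] = True

    with tqdm(**tqdm_kwargs) as pbar:
        for shard_file in shard_files:
            ...  # load one shard here
            pbar.update(1)
```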
```diff
@@ -59,11 +59,8 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import (
-    PushToHubMixin,
-    load_or_create_model_card,
-    populate_model_card,
-)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache
 from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
```
```diff
@@ -1672,7 +1669,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             shard_files = resolved_model_file
             if len(resolved_model_file) > 1:
-                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+                shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
+                if not is_torch_dist_rank_zero():
+                    shard_tqdm_kwargs["disable"] = True
+                shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)

             for shard_file in shard_files:
                 offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
```
```diff
@@ -67,6 +67,7 @@ from ..utils import (
     logging,
     numpy_to_pil,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module

```
```diff
@@ -982,7 +983,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # 7. Load each module in the pipeline
         current_device_map = None
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
-        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+        logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
+        if not is_torch_dist_rank_zero():
+            logging_tqdm_kwargs["disable"] = True
+
+        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
             # 7.1 device_map shenanigans
             if final_device_map is not None:
                 if isinstance(final_device_map, dict) and len(final_device_map) > 0:
```
```diff
@@ -1908,10 +1913,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
             )

+        progress_bar_config = dict(self._progress_bar_config)
+        if "disable" not in progress_bar_config:
+            progress_bar_config["disable"] = not is_torch_dist_rank_zero()
+
         if iterable is not None:
-            return tqdm(iterable, **self._progress_bar_config)
+            return tqdm(iterable, **progress_bar_config)
         elif total is not None:
-            return tqdm(total=total, **self._progress_bar_config)
+            return tqdm(total=total, **progress_bar_config)
         else:
             raise ValueError("Either `total` or `iterable` has to be defined.")
```
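Because `progress_bar` now copies `self._progress_bar_config` and only fills in `disable` when the caller has not, an explicit user setting still takes precedence. A hedged usage sketch, assuming `set_progress_bar_config` is used the usual way (the checkpoint name below is only an example):

```python
# Sketch: opt back in to denoising progress bars on every rank.
# An explicit disable=False is stored in `_progress_bar_config`, so the new
# rank-zero default (applied only when "disable" is absent) never overrides it.
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1")
pipe.set_progress_bar_config(disable=False)  # show the bar even off rank zero
```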
src/diffusers/utils/distributed_utils.py (new file, 36 lines)
```diff
@@ -0,0 +1,36 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+def is_torch_dist_rank_zero() -> bool:
+    if torch is None:
+        return True
+
+    dist_module = getattr(torch, "distributed", None)
+    if dist_module is None or not dist_module.is_available():
+        return True
+
+    if not dist_module.is_initialized():
+        return True
+
+    try:
+        return dist_module.get_rank() == 0
+    except (RuntimeError, ValueError):
+        return True
```
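The helper is deliberately conservative: it reports rank zero whenever torch is missing, `torch.distributed` is unavailable, or no process group has been initialized, so single-process runs keep their progress bars and logs unchanged. A small sketch of how it behaves in a script launched with `torchrun` (the `gloo` backend and the script name are just examples):

```python
# Hedged sketch: run with `torchrun --nproc_per_node=2 demo.py`.
# Before init_process_group the helper returns True on every process;
# afterwards only global rank 0 reports True.
import torch.distributed as dist

from diffusers.utils.distributed_utils import is_torch_dist_rank_zero

print("before init:", is_torch_dist_rank_zero())  # True on all ranks

dist.init_process_group(backend="gloo")  # torchrun supplies rank / world size via env vars
print(f"rank {dist.get_rank()}:", is_torch_dist_rank_zero())  # True only on rank 0

dist.destroy_process_group()
```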
```diff
@@ -32,6 +32,8 @@ from typing import Dict, Optional

 from tqdm import auto as tqdm_lib

+from .distributed_utils import is_torch_dist_rank_zero
+

 _lock = threading.Lock()
 _default_handler: Optional[logging.Handler] = None
```
```diff
@@ -47,6 +49,23 @@ log_levels = {
 _default_log_level = logging.WARNING

 _tqdm_active = True
+_rank_zero_filter = None
+
+
+class _RankZeroFilter(logging.Filter):
+    def filter(self, record):
+        # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
+        return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
+
+
+def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
+    global _rank_zero_filter
+
+    if _rank_zero_filter is None:
+        _rank_zero_filter = _RankZeroFilter()
+
+    if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
+        logger.addFilter(_rank_zero_filter)


 def _get_default_logging_level() -> int:
```
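`_RankZeroFilter` drops INFO-and-above records on non-zero ranks but deliberately lets DEBUG records through everywhere, so per-rank troubleshooting stays possible. The following standalone sketch reproduces that behaviour with only the stdlib `logging` module, replacing the distributed rank check with a plain boolean:

```python
# Hedged sketch of the filter semantics, stdlib logging only.
import logging


class RankZeroFilter(logging.Filter):
    def __init__(self, is_rank_zero: bool):
        super().__init__()
        self.is_rank_zero = is_rank_zero

    def filter(self, record):
        # Rank zero emits everything; other ranks keep only DEBUG records.
        return self.is_rank_zero or record.levelno <= logging.DEBUG


logger = logging.getLogger("demo")
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())
logger.addFilter(RankZeroFilter(is_rank_zero=False))  # pretend we are rank 1

logger.info("hidden on non-zero ranks")   # filtered out
logger.debug("still visible everywhere")  # passes the filter
```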
```diff
@@ -90,6 +109,7 @@ def _configure_library_root_logger() -> None:
         library_root_logger.addHandler(_default_handler)
         library_root_logger.setLevel(_get_default_logging_level())
         library_root_logger.propagate = False
+        _ensure_rank_zero_filter(library_root_logger)


 def _reset_library_root_logger() -> None:
```
```diff
@@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
         name = _get_library_name()

     _configure_library_root_logger()
-    return logging.getLogger(name)
+    logger = logging.getLogger(name)
+    _ensure_rank_zero_filter(logger)
+    return logger


 def get_verbosity() -> int:
```
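Since the filter is attached both in `_configure_library_root_logger` and in `get_logger`, any logger obtained through diffusers' logging helpers inherits the rank-zero behaviour. A brief usage sketch, assuming the default verbosity has been raised to INFO:

```python
# Sketch: loggers returned by diffusers' get_logger() carry the rank-zero filter.
from diffusers.utils import logging

logging.set_verbosity_info()
logger = logging.get_logger(__name__)

logger.info("weights loaded")         # emitted on rank 0 only
logger.debug("per-rank diagnostics")  # passes the filter on every rank,
                                      # but is still subject to the INFO verbosity
```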