Compare commits

...

10 Commits

Author      SHA1        Message                                      Date
Sayak Paul  4c096949af  Merge branch 'main' into progress-bar-dist   2025-12-11 13:11:39 +08:00
Sayak Paul  9cf5edd4db  Merge branch 'main' into progress-bar-dist   2025-12-10 11:24:44 +08:00
sayakpaul   086a770174  up                                           2025-12-10 08:45:42 +05:30
Sayak Paul  91c73bb137  Merge branch 'main' into progress-bar-dist   2025-12-09 11:09:48 +08:00
sayakpaul   8193f38e19  up                                           2025-12-09 08:37:06 +05:30
sayakpaul   2529fdf6af  up                                           2025-12-09 08:31:10 +05:30
sayakpaul   94d671d03a  up                                           2025-12-08 19:01:22 +05:30
sayakpaul   8df3fbcc14  up                                           2025-12-08 18:47:26 +05:30
sayakpaul   97e380573c  up                                           2025-12-08 18:37:52 +05:30
sayakpaul   f207108643  disable progressbar in distributed.          2025-12-08 18:13:38 +05:30
5 changed files with 83 additions and 11 deletions

View File

@@ -47,6 +47,7 @@ from ..utils import (
     is_torch_version,
     logging,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero


 logger = logging.get_logger(__name__)
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
         low_cpu_mem_usage=low_cpu_mem_usage,
     )

+    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
+    if not is_torch_dist_rank_zero():
+        tqdm_kwargs["disable"] = True
+
     with ThreadPoolExecutor(max_workers=num_workers) as executor:
-        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+        with logging.tqdm(**tqdm_kwargs) as pbar:
             futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
             for future in as_completed(futures):
                 result = future.result()
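
For reference, a minimal standalone sketch of the pattern this hunk introduces, with plain tqdm standing in for diffusers' logging.tqdm and the rank check stubbed out (shard names are made up):

    from tqdm.auto import tqdm

    def is_rank_zero() -> bool:
        # Stub for is_torch_dist_rank_zero(); assume a single-process run here.
        return True

    shard_files = ["shard-00001.safetensors", "shard-00002.safetensors"]  # hypothetical

    # Only non-zero ranks get disable=True, so only global rank zero renders a bar.
    tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
    if not is_rank_zero():
        tqdm_kwargs["disable"] = True

    with tqdm(**tqdm_kwargs) as pbar:
        for shard_file in shard_files:
            # ... load the shard here ...
            pbar.update(1)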

View File

@@ -59,11 +59,8 @@ from ..utils import (
     is_torch_version,
     logging,
 )
-from ..utils.hub_utils import (
-    PushToHubMixin,
-    load_or_create_model_card,
-    populate_model_card,
-)
+from ..utils.distributed_utils import is_torch_dist_rank_zero
+from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache
 from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
 from .model_loading_utils import (
@@ -1672,7 +1669,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             shard_files = resolved_model_file
             if len(resolved_model_file) > 1:
-                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
+                shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
+                if not is_torch_dist_rank_zero():
+                    shard_tqdm_kwargs["disable"] = True
+                shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)

             for shard_file in shard_files:
                 offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
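
A quick check of the behavior this hunk relies on: passing disable=True to tqdm suppresses the bar but leaves iteration untouched (plain tqdm used for illustration, file names made up):

    from tqdm.auto import tqdm

    resolved_model_file = ["shard-00001.safetensors", "shard-00002.safetensors"]  # hypothetical

    # Nothing is printed, but every element is still yielded in order.
    shard_files = tqdm(resolved_model_file, desc="Loading checkpoint shards", disable=True)
    assert list(shard_files) == resolved_model_file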

View File

@@ -67,6 +67,7 @@ from ..utils import (
     logging,
     numpy_to_pil,
 )
+from ..utils.distributed_utils import is_torch_dist_rank_zero
 from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
 from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
@@ -982,7 +983,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # 7. Load each module in the pipeline
         current_device_map = None
         _maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
-        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
+        logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
+        if not is_torch_dist_rank_zero():
+            logging_tqdm_kwargs["disable"] = True
+
+        for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
             # 7.1 device_map shenanigans
             if final_device_map is not None:
                 if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1908,10 +1913,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
                 f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
             )

+        progress_bar_config = dict(self._progress_bar_config)
+        if "disable" not in progress_bar_config:
+            progress_bar_config["disable"] = not is_torch_dist_rank_zero()
+
         if iterable is not None:
-            return tqdm(iterable, **self._progress_bar_config)
+            return tqdm(iterable, **progress_bar_config)
         elif total is not None:
-            return tqdm(total=total, **self._progress_bar_config)
+            return tqdm(total=total, **progress_bar_config)
         else:
             raise ValueError("Either `total` or `iterable` has to be defined.")

View File

@@ -0,0 +1,36 @@
+# Copyright 2025 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+try:
+    import torch
+except ImportError:
+    torch = None
+
+
+def is_torch_dist_rank_zero() -> bool:
+    if torch is None:
+        return True
+
+    dist_module = getattr(torch, "distributed", None)
+    if dist_module is None or not dist_module.is_available():
+        return True
+
+    if not dist_module.is_initialized():
+        return True
+
+    try:
+        return dist_module.get_rank() == 0
+    except (RuntimeError, ValueError):
+        return True
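
A small sketch of how the new helper is expected to behave under torchrun (assumes a launch like torchrun --nproc_per_node=2 demo.py; the script name is made up and the CPU-friendly gloo backend is used):

    # demo.py -- hypothetical file name
    import torch.distributed as dist

    from diffusers.utils.distributed_utils import is_torch_dist_rank_zero

    # Before a process group exists, every process reports True.
    print("before init:", is_torch_dist_rank_zero())

    dist.init_process_group(backend="gloo")  # torchrun supplies RANK / WORLD_SIZE / MASTER_ADDR

    # Afterwards, only global rank 0 reports True.
    print(f"rank {dist.get_rank()}:", is_torch_dist_rank_zero())

    dist.destroy_process_group()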

View File

@@ -32,6 +32,8 @@ from typing import Dict, Optional
 
 from tqdm import auto as tqdm_lib
 
+from .distributed_utils import is_torch_dist_rank_zero
+
 
 _lock = threading.Lock()
 _default_handler: Optional[logging.Handler] = None
@@ -47,6 +49,23 @@ log_levels = {
 _default_log_level = logging.WARNING
 
 _tqdm_active = True
 
+_rank_zero_filter = None
+
+
+class _RankZeroFilter(logging.Filter):
+    def filter(self, record):
+        # Always allow rank-zero logs, but keep debug-level messages from all ranks for troubleshooting.
+        return is_torch_dist_rank_zero() or record.levelno <= logging.DEBUG
+
+
+def _ensure_rank_zero_filter(logger: logging.Logger) -> None:
+    global _rank_zero_filter
+
+    if _rank_zero_filter is None:
+        _rank_zero_filter = _RankZeroFilter()
+    if not any(isinstance(f, _RankZeroFilter) for f in logger.filters):
+        logger.addFilter(_rank_zero_filter)
+
 
 def _get_default_logging_level() -> int:
@@ -90,6 +109,7 @@ def _configure_library_root_logger() -> None:
         library_root_logger.addHandler(_default_handler)
         library_root_logger.setLevel(_get_default_logging_level())
         library_root_logger.propagate = False
+        _ensure_rank_zero_filter(library_root_logger)
 
 
 def _reset_library_root_logger() -> None:
@@ -120,7 +140,9 @@ def get_logger(name: Optional[str] = None) -> logging.Logger:
         name = _get_library_name()
 
     _configure_library_root_logger()
-    return logging.getLogger(name)
+    logger = logging.getLogger(name)
+    _ensure_rank_zero_filter(logger)
+    return logger
 
 
 def get_verbosity() -> int:
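
A self-contained sketch of the filtering rule _RankZeroFilter encodes, using only the standard logging module with the rank check stubbed out:

    import logging

    def is_rank_zero() -> bool:
        # Stub for is_torch_dist_rank_zero(); flip to False to simulate a non-zero rank.
        return False

    class RankZeroFilter(logging.Filter):
        def filter(self, record):
            # Non-zero ranks only let DEBUG-level (and lower) records through.
            return is_rank_zero() or record.levelno <= logging.DEBUG

    logger = logging.getLogger("demo")
    logger.setLevel(logging.DEBUG)
    logger.addHandler(logging.StreamHandler())
    logger.addFilter(RankZeroFilter())

    logger.warning("dropped on non-zero ranks")  # filtered out because is_rank_zero() is False
    logger.debug("still emitted on every rank")  # always passes the level check in the filter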