Mirror of https://github.com/huggingface/diffusers.git
Synced 2025-12-06 20:44:33 +08:00

Compare commits: higgs...faster-loa (1 commit)

| Author | SHA1 | Date |
|---|---|---|
| | 51d1436a16 | |
@@ -18,7 +18,7 @@ import importlib
 import inspect
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -38,6 +38,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerator_device,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -304,6 +305,51 @@ def load_model_dict_into_meta(
     return offload_index, state_dict_index


+# Taken from
+# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
+def _expand_device_map(device_map, param_names):
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
+# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
+# We don't incorporate the `tp_plan` stuff as we don't support it yet.
+def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
+    # Remove disk, cpu and meta devices, and cast to proper torch.device
+    accelerator_device_map = {
+        param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
+    }
+    if not len(accelerator_device_map):
+        return
+
+    total_byte_count = defaultdict(lambda: 0)
+    for param_name, device in accelerator_device_map.items():
+        param = model.get_parameter_or_buffer(param_name)
+        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
+        param_byte_count = param.numel() * param.element_size()
+        total_byte_count[device] += param_byte_count
+
+    # This will kick off the caching allocator to avoid having to malloc afterwards
+    for device, byte_count in total_byte_count.items():
+        if device.type == "cuda":
+            index = device.index if device.index is not None else torch.cuda.current_device()
+            device_memory = torch.cuda.mem_get_info(index)[0]
+            # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
+            # than that amount might sometimes lead to unnecessary CUDA OOM, if the last parameter to be loaded on the device is large,
+            # and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
+            # the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
+            # to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
+            # Note that we use an absolute value instead of a device proportion here, as an 8 GiB device could still allocate too much
+            # if using e.g. 90% of device size, while a 140 GiB device would allocate too little.
+            byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
+        # Allocate memory
+        _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
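Editor's note: `_expand_device_map` turns module-level `device_map` entries into per-parameter entries, which is what the warmup then groups by device. A minimal, self-contained sketch of that expansion; the device map and parameter names below are invented for illustration and are not part of the commit:

```python
# Hypothetical module-level device map and checkpoint keys, for illustration only.
device_map = {"transformer_blocks": 0, "proj_out": "cpu"}
param_names = [
    "transformer_blocks.0.attn.to_q.weight",
    "transformer_blocks.0.attn.to_k.weight",
    "proj_out.weight",
]


def _expand_device_map(device_map, param_names):
    # Same logic as in the diff: a parameter inherits the device of the module whose
    # dotted prefix it matches; an empty module name ("") matches every parameter.
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


expanded = _expand_device_map(device_map, param_names)
print(expanded)
# {'transformer_blocks.0.attn.to_q.weight': 0,
#  'transformer_blocks.0.attn.to_k.weight': 0,
#  'proj_out.weight': 'cpu'}
```

The warmup then sums `numel() * element_size()` over the expanded entries that land on accelerator devices and asks the caching allocator for that amount up front, so the subsequent per-tensor copies reuse already-reserved memory instead of triggering many small allocations.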
@@ -63,7 +63,9 @@ from ..utils.hub_utils import (
     populate_model_card,
 )
 from .model_loading_utils import (
+    _caching_allocator_warmup,
     _determine_device_map,
+    _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
     _load_state_dict_into_model,
@@ -1374,6 +1376,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             return super().float(*args)

+    # Taken from `transformers`.
+    # https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
+    def get_parameter_or_buffer(self, target: str):
+        """
+        Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
+        `get_parameter()` and `get_buffer()` in a single handy function. Note that it only works if `target` is a leaf
+        of the model.
+        """
+        try:
+            return self.get_parameter(target)
+        except AttributeError:
+            pass
+        try:
+            return self.get_buffer(target)
+        except AttributeError:
+            pass
+        raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")
+
     @classmethod
     def _load_pretrained_model(
         cls,
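Editor's note: the warmup needs this method because a state-dict key can name either a parameter or a registered buffer. A small usage sketch with a toy `nn.Module` (the `Toy` class is an assumption for illustration, not a diffusers model); the lookup logic is the same as the method added above:

```python
import torch
from torch import nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.register_buffer("scale", torch.ones(4))

    def get_parameter_or_buffer(self, target: str):
        # Try parameters first, then buffers; both accessors raise AttributeError on a miss.
        try:
            return self.get_parameter(target)
        except AttributeError:
            pass
        try:
            return self.get_buffer(target)
        except AttributeError:
            pass
        raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")


toy = Toy()
print(toy.get_parameter_or_buffer("linear.weight").shape)  # parameter -> torch.Size([4, 4])
print(toy.get_parameter_or_buffer("scale").shape)          # buffer    -> torch.Size([4])
```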
@@ -1410,6 +1430,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         assign_to_params_buffers = None
         error_msgs = []

+        # Optionally, warm up CUDA to load the weights much faster on accelerator devices
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, expected_keys)
+            _caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)
+
         # Deal with offload
         if device_map is not None and "disk" in device_map.values():
             if offload_folder is None:
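Editor's note on the `factor` argument: the warmup allocates `byte_count // factor` `float16` elements, and each `float16` element is two bytes. A back-of-the-envelope check (numbers invented for illustration; the reading of `factor=4` for quantized models is an assumption, presumably because the on-device footprint of quantized weights is smaller than the checkpoint byte count):

```python
# Total size of the parameters destined for one accelerator device (hypothetical value).
byte_count = 10 * 1024**3  # 10 GiB
fp16_element_size = 2      # bytes per torch.float16 element

# factor=2 (no quantizer): the warmup tensor occupies (byte_count // 2) * 2 == byte_count bytes,
# i.e. roughly the full amount that will be loaded onto that device.
assert (byte_count // 2) * fp16_element_size == byte_count

# factor=4 (with hf_quantizer): only about half the checkpoint byte count is pre-reserved.
assert (byte_count // 4) * fp16_element_size == byte_count // 2
```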
@@ -129,6 +129,7 @@ from .state_dict_utils import (
     convert_unet_state_dict_to_peft,
     state_dict_all_zero,
 )
+from .testing_utils import is_accelerator_device
 from .typing_utils import _get_detailed_type, _is_valid_type


@@ -1289,6 +1289,18 @@ if is_torch_available():
     update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")


+if is_torch_available():
+    # Taken from
+    # https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5864C1-L5871C64
+    def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
+        """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
+        a proper `torch.device`.
+        """
+        if device == "disk":
+            return False
+        else:
+            return torch.device(device).type not in ["meta", "cpu"]
+
 # Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090

 # Type definition of key used in `Expectations` class.
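Editor's note: a quick sketch (illustrative values only, not part of the commit) of how `is_accelerator_device` classifies the kinds of entries a `device_map` can contain, so that only true accelerator targets take part in the warmup:

```python
import torch


def is_accelerator_device(device) -> bool:
    # "disk" is not a valid torch.device string, so it is handled before the cast.
    if device == "disk":
        return False
    return torch.device(device).type not in ["meta", "cpu"]


print(is_accelerator_device("disk"))    # False -> offloaded to disk, nothing to warm up
print(is_accelerator_device("cpu"))     # False
print(is_accelerator_device("meta"))    # False
print(is_accelerator_device("cuda:1"))  # True  -> participates in the allocator warmup
```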