Compare commits

...

1 Commit

Author     SHA1        Message                        Date
sayakpaul  51d1436a16  faster model loading on cuda.  2025-04-21 18:05:39 +05:30
4 changed files with 85 additions and 1 deletion

View File

@@ -18,7 +18,7 @@ import importlib
import inspect
import os
from array import array
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
from zipfile import is_zipfile
@@ -38,6 +38,7 @@ from ..utils import (
    _get_model_file,
    deprecate,
    is_accelerate_available,
    is_accelerator_device,
    is_gguf_available,
    is_torch_available,
    is_torch_version,
@@ -304,6 +305,51 @@ def load_model_dict_into_meta(
    return offload_index, state_dict_index


# Taken from
# https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5852C1-L5861C26
def _expand_device_map(device_map, param_names):
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


# Adapted from https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5874
# We don't incorporate the `tp_plan` stuff as we don't support it yet.
def _caching_allocator_warmup(model, device_map: Dict, factor=2) -> Dict:
    # Remove disk, cpu and meta devices, and cast to proper torch.device
    accelerator_device_map = {
        param: torch.device(device) for param, device in device_map.items() if is_accelerator_device(device)
    }
    if not len(accelerator_device_map):
        return

    total_byte_count = defaultdict(lambda: 0)
    for param_name, device in accelerator_device_map.items():
        param = model.get_parameter_or_buffer(param_name)
        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
        param_byte_count = param.numel() * param.element_size()
        total_byte_count[device] += param_byte_count

    # This will kick off the caching allocator to avoid having to Malloc afterwards
    for device, byte_count in total_byte_count.items():
        if device.type == "cuda":
            index = device.index if device.index is not None else torch.cuda.current_device()
            device_memory = torch.cuda.mem_get_info(index)[0]
            # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve
            # more than that amount might sometimes lead to unnecessary CUDA OOM: if the last parameter to be loaded on the
            # device is large and the remaining reserved memory portion is smaller than the param size, torch will then try
            # to fully re-allocate the whole param size instead of using the remaining reserved part and allocating only the
            # difference, which can lead to OOM. See
            # https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
            # Note that we use an absolute value instead of a device proportion here, as an 8 GiB device could still allocate
            # too much if using e.g. 90% of device size, while a 140 GiB device would allocate too little.
            byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
        # Allocate memory
        _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)


def _load_state_dict_into_model(
    model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
) -> List[str]:
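
To make the new flow concrete, here is a minimal, self-contained sketch (not part of the commit; the toy model, device map, and factor value are illustrative) of what `_expand_device_map` and `_caching_allocator_warmup` do together: expand a module-level device map to per-parameter entries, sum the bytes destined for each accelerator, and make one large allocation so the CUDA caching allocator is already warm when the individual weights are copied over.

import torch
from collections import defaultdict

# Toy setup; in the commit, `model` is the model being loaded and `device_map`
# comes from `from_pretrained(..., device_map=...)`.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 1024))
device_map = {"": "cuda:0"}  # everything on the first GPU
param_names = [name for name, _ in model.named_parameters()]

# Same expansion rule as _expand_device_map: "" matches every parameter.
expanded = {
    p: device
    for module, device in device_map.items()
    for p in param_names
    if p == module or p.startswith(f"{module}.") or module == ""
}

# Sum the bytes that will land on each accelerator, as _caching_allocator_warmup does.
total_byte_count = defaultdict(int)
named_params = dict(model.named_parameters())
for name, device in expanded.items():
    param = named_params[name]
    total_byte_count[torch.device(device)] += param.numel() * param.element_size()

if torch.cuda.is_available():
    for device, byte_count in total_byte_count.items():
        index = device.index if device.index is not None else torch.cuda.current_device()
        free_bytes = torch.cuda.mem_get_info(index)[0]
        # Same clamp as the commit: never try to reserve more than (free memory - 1.2 GiB).
        byte_count = min(byte_count, max(0, int(free_bytes - 1.2 * 1024**3)))
        # One big allocation primes the caching allocator; factor=2 halves it (the commit uses 4 for quantized models).
        _ = torch.empty(byte_count // 2, dtype=torch.float16, device=device, requires_grad=False)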

View File

@@ -63,7 +63,9 @@ from ..utils.hub_utils import (
    populate_model_card,
)
from .model_loading_utils import (
    _caching_allocator_warmup,
    _determine_device_map,
    _expand_device_map,
    _fetch_index_file,
    _fetch_index_file_legacy,
    _load_state_dict_into_model,
@@ -1374,6 +1376,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        else:
            return super().float(*args)

    # Taken from `transformers`.
    # https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5351C5-L5365C81
    def get_parameter_or_buffer(self, target: str):
        """
        Return the parameter or buffer given by `target` if it exists, otherwise throw an error. This combines
        `get_parameter()` and `get_buffer()` in a single handy function. Note that it only works if `target` is a leaf
        of the model.
        """
        try:
            return self.get_parameter(target)
        except AttributeError:
            pass
        try:
            return self.get_buffer(target)
        except AttributeError:
            pass
        raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")

    @classmethod
    def _load_pretrained_model(
        cls,
@@ -1410,6 +1430,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        assign_to_params_buffers = None
        error_msgs = []

        # Optionally, warm up CUDA to load the weights much faster onto the target devices
        if device_map is not None:
            expanded_device_map = _expand_device_map(device_map, expected_keys)
            _caching_allocator_warmup(model, expanded_device_map, factor=2 if hf_quantizer is None else 4)

        # Deal with offload
        if device_map is not None and "disk" in device_map.values():
            if offload_folder is None:
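
The warmup needs to size buffers as well as parameters, which is why the commit adds `get_parameter_or_buffer` rather than relying on `get_parameter()` alone. As a quick illustration of that fallback behavior, here is a standalone sketch (not part of the commit; the `Toy` module and the free function are hypothetical) that reproduces the same parameter-then-buffer lookup with plain `torch.nn.Module` APIs:

import torch


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)                   # yields "linear.weight" and "linear.bias"
        self.register_buffer("running_mean", torch.zeros(4))  # a buffer, not a parameter


def get_parameter_or_buffer(module: torch.nn.Module, target: str) -> torch.Tensor:
    # Same lookup order as the new ModelMixin.get_parameter_or_buffer: parameters first, then buffers.
    try:
        return module.get_parameter(target)
    except AttributeError:
        pass
    try:
        return module.get_buffer(target)
    except AttributeError:
        pass
    raise AttributeError(f"`{target}` is neither a parameter nor a buffer.")


toy = Toy()
print(get_parameter_or_buffer(toy, "linear.weight").shape)  # torch.Size([4, 4])
print(get_parameter_or_buffer(toy, "running_mean").shape)   # torch.Size([4]); get_parameter() alone would raise here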

View File

@@ -129,6 +129,7 @@ from .state_dict_utils import (
    convert_unet_state_dict_to_peft,
    state_dict_all_zero,
)
from .testing_utils import is_accelerator_device
from .typing_utils import _get_detailed_type, _is_valid_type

View File

@@ -1289,6 +1289,18 @@ if is_torch_available():
    update_mapping_from_spec(BACKEND_MAX_MEMORY_ALLOCATED, "MAX_MEMORY_ALLOCATED_FN")


if is_torch_available():
    # Taken from
    # https://github.com/huggingface/transformers/blob/6daa3eeba582facb57cd71db8efb66998b12942f/src/transformers/modeling_utils.py#L5864C1-L5871C64
    def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
        """Check if the device is an accelerator. We need this function, as device_map can be "disk" as well, which is not
        a proper `torch.device`.
        """
        if device == "disk":
            return False
        else:
            return torch.device(device).type not in ["meta", "cpu"]


# Modified from https://github.com/huggingface/transformers/blob/cdfb018d0300fef3b07d9220f3efe9c2a9974662/src/transformers/testing_utils.py#L3090
# Type definition of key used in `Expectations` class.
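
This helper is what lets `_caching_allocator_warmup` skip entries offloaded to CPU or disk. For reference, a standalone sketch (not part of the commit) of how such a check classifies typical `device_map` values; the printed results assume a CUDA-enabled PyTorch build:

import torch
from typing import Union


def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
    # "disk" is a valid device_map value but not a valid torch.device, hence the early check.
    if device == "disk":
        return False
    return torch.device(device).type not in ["meta", "cpu"]


print(is_accelerator_device("cuda:0"))  # True
print(is_accelerator_device(0))         # True -- bare integers resolve to a GPU index on CUDA builds
print(is_accelerator_device("cpu"))     # False
print(is_accelerator_device("meta"))    # False
print(is_accelerator_device("disk"))    # False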