Compare commits

...

22 Commits

Author | SHA1 | Message | Date
sayakpaul | 418211f9e7 | merge | 2025-06-19 18:44:10 +05:30
sayakpaul | 33f30ef86e | updates | 2025-06-19 18:41:46 +05:30
sayakpaul | bcb71c9c8b | Merge branch 'group-offloading-with-disk' into go-disk-nvme | 2025-06-19 17:07:27 +05:30
Sayak Paul | 68b07da7bd | Merge branch 'main' into group-offloading-with-disk | 2025-06-19 17:04:59 +05:30
sayakpaul | b535b99e13 | nvme support | 2025-06-19 17:00:04 +05:30
sayakpaul | 9e9465646f | add nvme save | 2025-06-19 16:24:40 +05:30
sayakpaul | 90e546ada1 | update more docs. | 2025-06-19 16:16:16 +05:30
sayakpaul | 7d2295567f | update todos. | 2025-06-19 16:10:39 +05:30
sayakpaul | da11656af4 | updates | 2025-06-19 15:59:48 +05:30
sayakpaul | 68f7580656 | Merge branch 'main' into group-offloading-with-disk | 2025-06-19 15:54:54 +05:30
Sayak Paul | 357226327f | Merge branch 'main' into group-offloading-with-disk | 2025-06-19 09:29:44 +05:30
Sayak Paul | 3c69c5ecc0 | Merge branch 'main' into group-offloading-with-disk | 2025-06-18 09:43:35 +05:30
Sayak Paul | 2d3056180a | Merge branch 'main' into group-offloading-with-disk | 2025-06-14 07:31:38 +05:30
Sayak Paul | a018ee1e70 | Merge branch 'main' into group-offloading-with-disk | 2025-06-12 11:08:56 +05:30
sayakpaul | 8029cd7ef0 | add test and clarify. | 2025-06-12 11:08:19 +05:30
sayakpaul | 4e4842fb0b | check if safetensors already exist. | 2025-06-12 10:48:57 +05:30
sayakpaul | d8179b10d3 | offload_to_disk_path | 2025-06-12 10:27:48 +05:30
Sayak Paul | 0bf55a99d3 | Merge branch 'main' into group-offloading-with-disk | 2025-06-10 10:29:54 +05:30
Sayak Paul | d32a2c6879 | Merge branch 'main' into group-offloading-with-disk | 2025-06-09 14:36:43 +05:30
sayakpaul | 278cbc2e47 | updates.patch | 2025-06-09 12:03:01 +05:30
sayakpaul | 49ac665460 | delete diff file. | 2025-06-09 10:26:00 +05:30
sayakpaul | e0d5079f9c | start implementing disk offloading in group. | 2025-06-09 10:25:45 +05:30
5 changed files with 1958 additions and 9 deletions

View File

@@ -20,6 +20,7 @@ import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils.import_utils import is_deepspeed_available, is_deepspeed_version
from .hooks import HookRegistry, ModelHook
@@ -27,6 +28,8 @@ if is_accelerate_available():
from accelerate.hooks import AlignDevicesHook, CpuOffload
from accelerate.utils import send_to_device
if is_deepspeed_available() and is_deepspeed_version(">=", "0.16"):
from ..utils.state_dict_utils import _fast_aio_save
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -62,6 +65,7 @@ class ModuleGroup:
low_cpu_mem_usage: bool = False,
onload_self: bool = True,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
self.modules = modules
self.offload_device = offload_device
@@ -80,7 +84,9 @@ class ModuleGroup:
self._is_offloaded_to_disk = False
if self.offload_to_disk_path:
self.safetensors_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}.safetensors")
self._enable_deepnvme_disk_offloading = _enable_deepnvme_disk_offloading
ext = ".pt" if _enable_deepnvme_disk_offloading else ".safetensors"
self.param_file_path = os.path.join(self.offload_to_disk_path, f"group_{id(self)}{ext}")
all_tensors = []
for module in self.modules:
@@ -153,8 +159,11 @@ class ModuleGroup:
with context:
if self.stream is not None:
# Load to CPU, pin, and async copy to device for overlapping transfer and compute
loaded_cpu_tensors = safetensors.torch.load_file(self.safetensors_file_path, device="cpu")
# Load to CPU from disk, pin, and async copy to device for overlapping transfer and compute
if self._enable_deepnvme_disk_offloading:
loaded_cpu_tensors = torch.load(self.param_file_path, weights_only=True, map_location="cpu")
else:
loaded_cpu_tensors = safetensors.torch.load_file(self.param_file_path, device="cpu")
for key, tensor_obj in self.key_to_tensor.items():
pinned_tensor = loaded_cpu_tensors[key].pin_memory()
tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
@@ -165,7 +174,12 @@ class ModuleGroup:
onload_device = (
self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
)
loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=onload_device)
if self._enable_deepnvme_disk_offloading:
loaded_tensors = torch.load(
self.param_file_path, weights_only=True, map_location=onload_device
)
else:
loaded_tensors = safetensors.torch.load_file(self.param_file_path, device=onload_device)
for key, tensor_obj in self.key_to_tensor.items():
tensor_obj.data = loaded_tensors[key]
return
@@ -218,15 +232,18 @@ class ModuleGroup:
if self.offload_to_disk_path:
# TODO: we can potentially optimize this code path by checking if _all_ the desired
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
# overhead. Currently, we just check if the given `param_file_path` exists and if not
# we perform a write.
# Check if the file has been saved in this session or if it already exists on disk.
if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
if not self._is_offloaded_to_disk and not os.path.exists(self.param_file_path):
os.makedirs(os.path.dirname(self.param_file_path), exist_ok=True)
tensors_to_save = {
key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()
}
safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
if self._enable_deepnvme_disk_offloading:
_fast_aio_save(tensors_to_save, self.param_file_path)
else:
safetensors.torch.save_file(tensors_to_save, self.param_file_path)
# The group is now considered offloaded to disk for the rest of the session.
self._is_offloaded_to_disk = True
@@ -426,6 +443,7 @@ def apply_group_offloading(
record_stream: bool = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -484,6 +502,8 @@ def apply_group_offloading(
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
(TODO: include example with `offload_to_disk_path`)
Example:
```python
>>> from diffusers import CogVideoXTransformer3DModel
@@ -529,6 +549,7 @@ def apply_group_offloading(
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(
@@ -540,6 +561,7 @@ def apply_group_offloading(
stream=stream,
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
else:
raise ValueError(f"Unsupported offload_type: {offload_type}")
@@ -555,6 +577,7 @@ def _apply_group_offloading_block_level(
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
@@ -615,6 +638,7 @@ def _apply_group_offloading_block_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
matched_module_groups.append(group)
for j in range(i, i + len(current_modules)):
@@ -649,6 +673,7 @@ def _apply_group_offloading_block_level(
stream=None,
record_stream=False,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
if stream is None:
_apply_group_offloading_hook(module, unmatched_group, None)
@@ -665,6 +690,7 @@ def _apply_group_offloading_leaf_level(
record_stream: Optional[bool] = False,
low_cpu_mem_usage: bool = False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
@@ -715,6 +741,7 @@ def _apply_group_offloading_leaf_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
_apply_group_offloading_hook(submodule, group, None)
modules_with_group_offloading.add(name)
@@ -762,6 +789,7 @@ def _apply_group_offloading_leaf_level(
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -783,6 +811,7 @@ def _apply_group_offloading_leaf_level(
record_stream=False,
low_cpu_mem_usage=low_cpu_mem_usage,
onload_self=True,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)

View File

@@ -549,6 +549,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream: bool = False,
low_cpu_mem_usage=False,
offload_to_disk_path: Optional[str] = None,
_enable_deepnvme_disk_offloading: Optional[bool] = False,
) -> None:
r"""
Activates group offloading for the current model.
@@ -599,6 +600,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
record_stream=record_stream,
low_cpu_mem_usage=low_cpu_mem_usage,
offload_to_disk_path=offload_to_disk_path,
_enable_deepnvme_disk_offloading=_enable_deepnvme_disk_offloading,
)
def save_pretrained(
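As a usage sketch for the model-level API (assuming the hunk above extends `ModelMixin.enable_group_offload`; the method name is cut off by the diff context, and the checkpoint and path are placeholders):

```python
# Illustrative only: enable disk-backed group offloading directly on a model instance.
import torch
from diffusers import CogVideoXTransformer3DModel

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="leaf_level",
    offload_to_disk_path="/path/to/offload_dir",  # parameter groups are spilled to this directory
    _enable_deepnvme_disk_offloading=True,  # route the writes through DeepSpeed's DeepNVMe path
)
```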

File diff suppressed because it is too large.

View File

@@ -220,6 +220,11 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
_nltk_available, _nltk_version = _is_package_available("nltk")
_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
_deepspeed_available, _deepspeed_version = _is_package_available("deepspeed")
def is_deepspeed_available():
return _deepspeed_available
def is_torch_available():
@@ -655,6 +660,19 @@ def is_torch_version(operation: str, version: str):
return compare_versions(parse(_torch_version), operation, version)
def is_deepspeed_version(operation: str, version: str):
"""
Compares the current DeepSpeed version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A string version of DeepSpeed
"""
return compare_versions(parse(_deepspeed_version), operation, version)
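For illustration (not part of the diff), the two new helpers are meant to gate DeepSpeed-only imports, mirroring how `state_dict_utils.py` uses them below:

```python
# Sketch only: guard DeepNVMe imports behind the new availability/version helpers.
# The ">" / "0.16" threshold follows the usage in state_dict_utils.py.
if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
    from deepspeed.io import FastFileWriter, FastFileWriterConfig
```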
def is_torch_xla_version(operation: str, version: str):
"""
Compares the current torch_xla version to a given reference with an operation.

View File

@@ -18,13 +18,19 @@ State dict utilities: utility methods for converting state dicts easily
import enum
import json
from .import_utils import is_torch_available
from .import_utils import is_deepspeed_available, is_deepspeed_version, is_torch_available
from .logging import get_logger
if is_torch_available():
import torch
if is_deepspeed_available() and is_deepspeed_version(">", "0.16"):
from deepspeed.io import FastFileWriter, FastFileWriterConfig
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder
from .deep_nvme_utils import save as _nvme_save
logger = get_logger(__name__)
@@ -364,3 +370,54 @@ def _load_sft_state_dict_metadata(model_file: str):
return json.loads(raw) if raw else None
else:
return None
# Utilities below are taken from
# https://github.com/deepspeedai/DeepSpeedExamples/blob/28a984e77b8d096dadc6389b6d1440b823587e28/deepnvme/model_checkpoint/torch_save_utils.py#L16
def _load_io_ops(args):
if AsyncIOBuilder().is_compatible():
AsyncIOBuilder().load(verbose=False)
if args.gpu and GDSBuilder().is_compatible():
GDSBuilder().load(verbose=False)
def _get_aio_handle():
AIO_QUEUE_DEPTH = 8
AIO_BLOCK_SIZE = 8 * (1024**2)
AIO_INTRA_OP_PARALLEL = 1
AIO_SINGLE_SUBMIT = False
h = (
AsyncIOBuilder()
.load(verbose=False)
.aio_handle(
block_size=AIO_BLOCK_SIZE,
queue_depth=AIO_QUEUE_DEPTH,
single_submit=AIO_SINGLE_SUBMIT,
overlap_events=AIO_SINGLE_SUBMIT,
intra_op_parallelism=AIO_INTRA_OP_PARALLEL,
)
)
return h
def _get_aio_components():
PINNED_BUFFER_MB = 64
h = _get_aio_handle()
pinned_memory = torch.zeros(PINNED_BUFFER_MB * (1024**2), dtype=torch.uint8, device="cpu").pin_memory()
return h, pinned_memory
def _fast_aio_save(buffer, file, single_io_buffer=False):
h, pinned_memory = _get_aio_components()
fast_writer_config = FastFileWriterConfig(
dnvme_handle=h,
pinned_tensor=pinned_memory,
double_buffer=not single_io_buffer,
num_parallel_writers=1,
writer_rank=0,
)
ds_fast_writer = FastFileWriter(file_path=file, config=fast_writer_config)
_nvme_save(f=ds_fast_writer, obj=buffer, _use_new_zipfile_serialization=False)
ds_fast_writer.close()
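A rough driver for the helper above (illustrative only; it assumes deepspeed >= 0.16 with the AsyncIO op buildable on the machine and a writable NVMe-backed path):

```python
# Illustrative call, not part of the diff: persist a group of tensors the same way the
# offload path in group_offloading.py does when _enable_deepnvme_disk_offloading=True.
import torch

tensors_to_save = {"to_q.weight": torch.randn(1024, 1024), "to_q.bias": torch.randn(1024)}
_fast_aio_save(tensors_to_save, "/mnt/nvme/offload/group_example.pt")  # path is a placeholder
```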