Compare commits

...

13 Commits

Author SHA1 Message Date
sayakpaul
f60afe5cba error out for the offload to disk option. 2026-03-30 13:19:12 +05:30
Sayak Paul
06509796dd Merge branch 'main' into fix-torchao-groupoffloading 2026-03-30 11:48:22 +05:30
Sayak Paul
59c1b2534a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-30 11:24:51 +05:30
sayakpaul
7eaeb99fcd address feedback. 2026-03-30 11:24:40 +05:30
Sayak Paul
867192364c Merge branch 'main' into fix-torchao-groupoffloading 2026-03-30 10:53:47 +05:30
Sayak Paul
a8cef0740a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-27 21:16:15 +05:30
Sayak Paul
70067734a2 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-26 11:29:51 +05:30
Sayak Paul
6125a4f540 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-25 08:07:01 +05:30
Sayak Paul
d2666a9d0a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-24 09:06:42 +05:30
sayakpaul
9b9e2e17a6 up 2026-03-23 11:22:36 +05:30
sayakpaul
1a959dc26f switch to swap_tensors. 2026-03-23 10:56:16 +05:30
Sayak Paul
8797398d3b Merge branch 'main' into fix-torchao-groupoffloading 2026-03-23 09:05:37 +05:30
sayakpaul
019a9deafb fix group offloading when using torchao 2026-03-17 10:40:03 +05:30

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,6 +35,54 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
if not is_torchao_available():
return False
from torchao.utils import TorchAOBaseTensor
return isinstance(tensor, TorchAOBaseTensor)
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
that any dict keyed by `id(param)` remains valid.
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
"""
torch.utils.swap_tensors(param, source)
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
`cpu_param_dict`).
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
"""Record stream for all internal tensors of a TorchAO parameter."""
for attr_name in _get_torchao_inner_tensor_names(param):
getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -124,6 +172,13 @@ class ModuleGroup:
else torch.cuda
)
@staticmethod
def _to_cpu(tensor, low_cpu_mem_usage):
# For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
# (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
return t if low_cpu_mem_usage else t.pin_memory()
def _init_cpu_param_dict(self):
cpu_param_dict = {}
if self.stream is None:
@@ -131,17 +186,15 @@ class ModuleGroup:
for module in self.modules:
for param in module.parameters():
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
for buffer in module.buffers():
cpu_param_dict[buffer] = (
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
)
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
for param in self.parameters:
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
for buffer in self.buffers:
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
return cpu_param_dict
@@ -157,9 +210,16 @@ class ModuleGroup:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
if self.record_stream:
tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
@@ -178,7 +238,19 @@ class ModuleGroup:
source = pinned_memory[buffer] if pinned_memory else buffer.data
self._transfer_tensor_to_device(buffer, source, default_stream)
def _check_disk_offload_torchao(self):
all_tensors = list(self.tensor_to_key.keys())
has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
if has_torchao:
raise ValueError(
"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
"setting `offload_to_disk_path`."
)
def _onload_from_disk(self):
self._check_disk_offload_torchao()
if self.stream is not None:
# Wait for previous Host->Device transfer to complete
self.stream.synchronize()
@@ -221,6 +293,8 @@ class ModuleGroup:
self._process_tensors_from_modules(None)
def _offload_to_disk(self):
self._check_disk_offload_torchao()
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
@@ -245,18 +319,35 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(param):
moved = param.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(buffer):
moved = buffer.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):