mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-31 21:06:45 +08:00
Compare commits
13 Commits
autoencode
...
fix-torcha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f60afe5cba | ||
|
|
06509796dd | ||
|
|
59c1b2534a | ||
|
|
7eaeb99fcd | ||
|
|
867192364c | ||
|
|
a8cef0740a | ||
|
|
70067734a2 | ||
|
|
6125a4f540 | ||
|
|
d2666a9d0a | ||
|
|
9b9e2e17a6 | ||
|
|
1a959dc26f | ||
|
|
8797398d3b | ||
|
|
019a9deafb |
@@ -22,7 +22,7 @@ from typing import Set
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from ..utils import get_logger, is_accelerate_available, is_torchao_available
|
||||
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
@@ -35,6 +35,54 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    """Return whether `tensor` is an instance of TorchAO's `TorchAOBaseTensor`.

    Returns `False` without importing `torchao` when `is_torchao_available()` reports that the package is not
    available.
    """
    if not is_torchao_available():
        return False
    # Imported lazily, only after the availability check above has passed.
    from torchao.utils import TorchAOBaseTensor

    return isinstance(tensor, TorchAOBaseTensor)
|
||||
|
||||
|
||||
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
|
||||
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
|
||||
cls = type(tensor)
|
||||
names = list(getattr(cls, "tensor_data_names", []))
|
||||
for attr_name in getattr(cls, "optional_tensor_data_names", []):
|
||||
if getattr(tensor, attr_name, None) is not None:
|
||||
names.append(attr_name)
|
||||
return names
|
||||
|
||||
|
||||
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
|
||||
|
||||
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
|
||||
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
|
||||
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
|
||||
that any dict keyed by `id(param)` remains valid.
|
||||
|
||||
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
|
||||
"""
|
||||
torch.utils.swap_tensors(param, source)
|
||||
|
||||
|
||||
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.

    Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
    modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
    `cpu_param_dict`).

    The attributes copied are those reported for `source` by `_get_torchao_inner_tensor_names`. Afterwards `param`
    and `source` reference the same inner tensor objects; no tensor data is copied.
    """
    for attr_name in _get_torchao_inner_tensor_names(source):
        setattr(param, attr_name, getattr(source, attr_name))
|
||||
|
||||
|
||||
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Record stream for all internal tensors of a TorchAO parameter.

    Calls `record_stream(stream)` on every inner tensor attribute reported for `param` by
    `_get_torchao_inner_tensor_names`, rather than on the outer wrapper.
    """
    for attr_name in _get_torchao_inner_tensor_names(param):
        inner_tensor = getattr(param, attr_name)
        inner_tensor.record_stream(stream)
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
@@ -124,6 +172,13 @@ class ModuleGroup:
|
||||
else torch.cuda
|
||||
)
|
||||
|
||||
@staticmethod
def _to_cpu(tensor, low_cpu_mem_usage):
    """Return a CPU copy of `tensor`, pinned unless `low_cpu_mem_usage` is truthy."""
    # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
    # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
    t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
    return t if low_cpu_mem_usage else t.pin_memory()
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
@@ -131,17 +186,15 @@ class ModuleGroup:
|
||||
|
||||
for module in self.modules:
|
||||
for param in module.parameters():
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
|
||||
for buffer in module.buffers():
|
||||
cpu_param_dict[buffer] = (
|
||||
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
)
|
||||
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
|
||||
|
||||
for param in self.parameters:
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
|
||||
|
||||
for buffer in self.buffers:
|
||||
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
|
||||
|
||||
return cpu_param_dict
|
||||
|
||||
@@ -157,9 +210,16 @@ class ModuleGroup:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_swap_torchao_tensor(tensor, moved)
|
||||
else:
|
||||
tensor.data = moved
|
||||
if self.record_stream:
|
||||
tensor.data.record_stream(default_stream)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_record_stream_torchao_tensor(tensor, default_stream)
|
||||
else:
|
||||
tensor.data.record_stream(default_stream)
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
|
||||
for group_module in self.modules:
|
||||
@@ -178,7 +238,19 @@ class ModuleGroup:
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, default_stream)
|
||||
|
||||
def _check_disk_offload_torchao(self):
    """Raise if any tensor tracked in `self.tensor_to_key` is a TorchAO tensor.

    Raises:
        ValueError: If at least one key of `self.tensor_to_key` is a TorchAO tensor.
    """
    # Iterate the key view directly; `any` short-circuits on the first TorchAO tensor, so no list copy is needed.
    if any(_is_torchao_tensor(t) for t in self.tensor_to_key.keys()):
        raise ValueError(
            "Disk offloading is not supported for TorchAO quantized tensors because safetensors "
            "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
            "setting `offload_to_disk_path`."
        )
|
||||
|
||||
def _onload_from_disk(self):
|
||||
self._check_disk_offload_torchao()
|
||||
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
@@ -221,6 +293,8 @@ class ModuleGroup:
|
||||
self._process_tensors_from_modules(None)
|
||||
|
||||
def _offload_to_disk(self):
|
||||
self._check_disk_offload_torchao()
|
||||
|
||||
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
||||
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
||||
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
||||
@@ -245,18 +319,35 @@ class ModuleGroup:
|
||||
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
if _is_torchao_tensor(buffer):
|
||||
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
|
||||
else:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=False)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
if _is_torchao_tensor(param):
|
||||
moved = param.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(param, moved)
|
||||
else:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
if _is_torchao_tensor(buffer):
|
||||
moved = buffer.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(buffer, moved)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
|
||||
Reference in New Issue
Block a user