Compare commits

...

12 Commits

| Author | SHA1 | Message | Date |
| --- | --- | --- | --- |
| DN6 | d9915a7d65 | update | 2025-03-12 11:44:40 +05:30 |
| DN6 | b7a795dbeb | update | 2025-03-12 11:40:40 +05:30 |
| DN6 | 438905d63e | update | 2025-03-12 11:37:27 +05:30 |
| DN6 | 904f24de5a | update | 2025-03-12 11:35:18 +05:30 |
| DN6 | e123bbcbc4 | memmap | 2025-03-12 11:23:14 +05:30 |
| DN6 | b3fa8c695d | remove cpu param dict | 2025-03-12 09:02:04 +05:30 |
| DN6 | 720be2bac5 | update | 2025-03-12 08:49:45 +05:30 |
| DN6 | e74b782aac | update | 2025-03-12 08:45:09 +05:30 |
| DN6 | d6392b4b49 | update | 2025-03-12 08:18:19 +05:30 |
| DN6 | 1475026960 | sliding-window | 2025-03-11 13:56:39 +05:30 |
| DN6 | 878eb4ce35 | update | 2025-03-11 13:21:09 +05:30 |
| Dhruv Nair | 9add071592 | [Quantization] Allow loading TorchAO serialized Tensor objects with torch>=2.6 (#11018) (commit body: nine "* update" entries) | 2025-03-11 10:52:01 +05:30 |
6 changed files with 202 additions and 49 deletions

View File

@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```
Some quantization methods, such as `uint4wo`, cannot be loaded directly and may raise an `UnpicklingError` when loading the models, even though saving them works as expected. To work around this, you can load the state dict manually into the model. Note, however, that this requires passing `weights_only=False` to `torch.load`, so it should only be done if the weights were obtained from a trusted source.
If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may raise an `UnpicklingError` when loading the models, even though saving them works as expected. To work around this, you can load the state dict manually into the model. Note, however, that this requires passing `weights_only=False` to `torch.load`, so it should only be done if the weights were obtained from a trusted source.
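A minimal sketch of this manual-loading workaround, assuming a `uint4wo`-quantized Flux transformer that was saved with `safe_serialization=False`; the paths and model class are placeholders, not taken from the diff:

```python
import torch
from accelerate import init_empty_weights

from diffusers import FluxTransformer2DModel

# Placeholder directory containing a torchao-serialized checkpoint and its config.
ckpt_dir = "/path/to/flux_uint4wo"

# weights_only=False bypasses the safe unpickler, so only load checkpoints
# obtained from a trusted source this way.
state_dict = torch.load(
    f"{ckpt_dir}/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu"
)

# Build an empty model skeleton, then assign the already-quantized tensors into it.
config = FluxTransformer2DModel.load_config(ckpt_dir)
with init_empty_weights():
    transformer = FluxTransformer2DModel.from_config(config)
transformer.load_state_dict(state_dict, strict=True, assign=True)
```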
```python
import torch

View File

@@ -2,20 +2,14 @@ __version__ = "0.33.0.dev0"
from typing import TYPE_CHECKING
from diffusers.quantizers import quantization_config
from diffusers.utils import dummy_gguf_objects
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_optimum_quanto_version,
is_torchao_available,
)
from .utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
is_accelerate_available,
is_bitsandbytes_available,
is_flax_available,
is_gguf_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
@@ -24,6 +18,7 @@ from .utils import (
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
is_torchao_available,
is_torchsde_available,
is_transformers_available,
)
@@ -65,7 +60,7 @@ _import_structure = {
}
try:
if not is_bitsandbytes_available():
if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_bitsandbytes_objects
@@ -77,7 +72,7 @@ else:
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")
try:
if not is_gguf_available():
if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_gguf_objects
@@ -89,7 +84,7 @@ else:
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")
try:
if not is_torchao_available():
if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchao_objects
@@ -101,7 +96,7 @@ else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")
try:
if not is_optimum_quanto_available():
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_optimum_quanto_objects
@@ -112,7 +107,6 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")
try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()

View File

@@ -29,11 +29,16 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
# Always use memory-efficient CPU offloading to minimize RAM usage
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -56,7 +61,6 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = True,
) -> None:
self.modules = modules
@@ -68,12 +72,8 @@ class ModuleGroup:
self.buffers = buffers
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self
if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
@@ -82,23 +82,125 @@ class ModuleGroup:
self.stream.synchronize()
with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
# Use the most efficient module-level transfer when possible
# This approach mirrors how PyTorch handles full model transfers
if self.modules:
for group_module in self.modules:
# Only onload if some parameters are not on the target device
if any(p.device != self.onload_device for p in group_module.parameters()):
try:
# Try the most efficient approach using _apply
if hasattr(group_module, "_apply"):
# This is what module.to() uses internally
def to_device(t):
if t.device != self.onload_device:
if self.onload_device.type == "cuda":
return t.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
return t.to(self.onload_device,
non_blocking=self.non_blocking)
return t
# Apply to all tensors without unnecessary copies
group_module._apply(to_device)
else:
# Fallback to direct parameter transfer
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
except Exception as e:
# If optimization fails, fall back to direct parameter transfer
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle explicit parameters
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if buffer.device != self.onload_device:
if self.onload_device.type == "cuda":
buffer.data = buffer.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
buffer.data = buffer.data.to(self.onload_device,
non_blocking=self.non_blocking)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
# For CPU offloading
if self.offload_device.type == "cpu":
# Synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# Empty GPU cache before offloading to reduce memory fragmentation
if torch.cuda.is_available():
torch.cuda.empty_cache()
# For module groups, use a single, unified approach that is closest to
# the behavior of model.to("cpu")
if self.modules:
for group_module in self.modules:
# Check if we need to offload this module
if any(p.device.type != "cpu" for p in group_module.parameters()):
# Use PyTorch's built-in to() method directly, which preserves
# memory mapping when moving to CPU
try:
# Non-blocking=False for CPU transfers, as it ensures memory is
# immediately available and potentially preserves memory mapping
group_module.to("cpu", non_blocking=False)
except Exception as e:
# If there's any error, fall back to parameter-level offloading
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
for param in group_module.parameters():
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle explicit parameters - move directly to CPU with non-blocking=False
# which can preserve memory mapping in some PyTorch versions
if self.parameters is not None:
for param in self.parameters:
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
if buffer.device.type != "cpu":
buffer.data = buffer.data.to("cpu", non_blocking=False)
# Let Python's normal reference counting handle cleanup
# We don't force garbage collection to avoid slowing down inference
else:
# For non-CPU offloading, synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# For non-CPU offloading, use the regular approach
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
@@ -108,6 +210,9 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
# Nothing to unpin after offloading: pinned CPU memory is no longer used by this path
class GroupOffloadingHook(ModelHook):
r"""
@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
# Offload to CPU
self.group.offload_()
return module
@@ -313,7 +419,8 @@ def apply_group_offloading(
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
Example:
```python
@@ -344,12 +451,19 @@ def apply_group_offloading(
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
# We no longer need a pinned group manager as we're not using pinned memory
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
module,
num_blocks_per_group,
offload_device,
onload_device,
non_blocking,
stream,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
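For orientation, a hedged usage sketch of this entry point after the changes above; the model id, devices, and block count are illustrative and assume the public `diffusers.hooks.apply_group_offloading` import:

```python
import torch

from diffusers import FluxTransformer2DModel
from diffusers.hooks import apply_group_offloading

# Illustrative model; any diffusers model built from ModuleList/Sequential blocks works.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Keep groups of two transformer blocks on CPU and onload them to CUDA on demand.
# With these commits, stream-based prefetching no longer stages copies through a
# pinned cpu_param_dict.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
)
```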
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=stream is None,
)
matched_module_groups.append(group)
@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
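Taken together, these hunks drop the pinned `cpu_param_dict` staging in favor of plain module-level transfers. A simplified sketch of the two approaches, not the PR's exact code, to illustrate the trade-off:

```python
import torch

block = torch.nn.Linear(4096, 4096)

# Previous approach (removed above): pin each parameter's CPU copy so that async
# stream copies have stable staging buffers. The pinned copies stay resident and
# can spike CPU RAM for large models.
cpu_param_dict = {}
for param in block.parameters():
    param.data = param.data.cpu().pin_memory()
    cpu_param_dict[param] = param.data

# New approach (as in the hunks above): move the module wholesale with .to(),
# letting PyTorch reuse the original CPU storage (including memory-mapped weights)
# instead of maintaining a separate pinned dictionary.
if torch.cuda.is_available():
    block.to("cuda", non_blocking=True)   # onload for compute
    block.to("cpu", non_blocking=False)   # offload; blocking so memory is ready
```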

View File

@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
from packaging import version
from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
from ...utils import (
get_module_from_name,
is_torch_available,
is_torch_version,
is_torchao_available,
is_torchao_version,
logging,
)
from ..base import DiffusersQuantizer
@@ -62,6 +69,43 @@ if is_torchao_available():
from torchao.quantization import quantize_
def _update_torch_safe_globals():
safe_globals = [
(torch.uint1, "torch.uint1"),
(torch.uint2, "torch.uint2"),
(torch.uint3, "torch.uint3"),
(torch.uint4, "torch.uint4"),
(torch.uint5, "torch.uint5"),
(torch.uint6, "torch.uint6"),
(torch.uint7, "torch.uint7"),
]
try:
from torchao.dtypes import NF4Tensor
from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor
safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
"Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
)
logger.debug(e)
finally:
torch.serialization.add_safe_globals(safe_globals=safe_globals)
if (
is_torch_available()
and is_torch_version(">=", "2.6.0")
and is_torchao_available()
and is_torchao_version(">=", "0.7.0")
):
_update_torch_safe_globals()
logger = logging.get_logger(__name__)
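For reference, a hedged sketch of what this registration enables: once the torchao tensor classes are allow-listed with `torch.serialization.add_safe_globals`, a checkpoint containing them can be deserialized under the default `weights_only=True` unpickler (the file name is a placeholder):

```python
import torch

# With torch>=2.6.0 and torchao>=0.7.0, the classes registered above are accepted
# by the weights_only unpickler, so no weights_only=False escape hatch is needed.
state_dict = torch.load(
    "diffusion_pytorch_model.bin",  # placeholder path to a torchao-serialized checkpoint
    weights_only=True,
    map_location="cpu",
)
```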

View File

@@ -94,6 +94,7 @@ from .import_utils import (
is_torch_xla_available,
is_torch_xla_version,
is_torchao_available,
is_torchao_version,
is_torchsde_available,
is_torchvision_available,
is_transformers_available,

View File

@@ -868,6 +868,21 @@ def is_gguf_version(operation: str, version: str):
return compare_versions(parse(_gguf_version), operation, version)
def is_torchao_version(operation: str, version: str):
"""
Compares the current torchao version to a given reference with an operation.
Args:
operation (`str`):
A string representation of an operator, such as `">"` or `"<="`
version (`str`):
A version string
"""
if not _is_torchao_available:
return False
return compare_versions(parse(_torchao_version), operation, version)
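A short usage sketch of the new helper, assuming the re-export added to `diffusers.utils` in the hunk above; the version string is illustrative:

```python
from diffusers.utils import is_torchao_version

# Returns False when torchao is not installed; otherwise compares the installed
# version against the reference using the given operator.
if is_torchao_version(">=", "0.7.0"):
    print("torchao is new enough for weights_only loading of its tensor subclasses")
```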
def is_k_diffusion_version(operation: str, version: str):
"""
Compares the current k-diffusion version to a given reference with an operation.