Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: quantizer-... group-memo

12 Commits
| Author | SHA1 | Date |
|---|---|---|
|  | d9915a7d65 |  |
|  | b7a795dbeb |  |
|  | 438905d63e |  |
|  | 904f24de5a |  |
|  | e123bbcbc4 |  |
|  | b3fa8c695d |  |
|  | 720be2bac5 |  |
|  | e74b782aac |  |
|  | d6392b4b49 |  |
|  | 1475026960 |  |
|  | 878eb4ce35 |  |
|  | 9add071592 |  |
@@ -126,7 +126,7 @@ image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
image.save("output.png")
```

-Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
+If you are using `torch<=2.6.0`, some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.

```python
import torch
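The workaround example in this hunk is truncated after `import torch`. A minimal sketch of the manual load described above might look like the following; the model class and checkpoint paths are illustrative placeholders, not taken from the diff:

```python
import torch
from diffusers import FluxTransformer2DModel  # any diffusers model class with a saved config

# weights_only=False is what makes uint4wo-style checkpoints loadable on torch<=2.6.0,
# so only do this with checkpoints from a trusted source.
state_dict = torch.load(
    "/path/to/uint4wo-checkpoint/diffusion_pytorch_model.bin",
    map_location="cpu",
    weights_only=False,
)

# Rebuild the model from its saved config, then assign the quantized tensors.
config = FluxTransformer2DModel.load_config("/path/to/uint4wo-checkpoint")
model = FluxTransformer2DModel.from_config(config)
model.load_state_dict(state_dict, strict=True, assign=True)
```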
@@ -2,20 +2,14 @@ __version__ = "0.33.0.dev0"

from typing import TYPE_CHECKING

-from diffusers.quantizers import quantization_config
-from diffusers.utils import dummy_gguf_objects
-from diffusers.utils.import_utils import (
-    is_bitsandbytes_available,
-    is_gguf_available,
-    is_optimum_quanto_version,
-    is_torchao_available,
-)

from .utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    is_accelerate_available,
    is_bitsandbytes_available,
    is_flax_available,
    is_gguf_available,
    is_k_diffusion_available,
    is_librosa_available,
    is_note_seq_available,

@@ -24,6 +18,7 @@ from .utils import (
    is_scipy_available,
    is_sentencepiece_available,
    is_torch_available,
    is_torchao_available,
    is_torchsde_available,
    is_transformers_available,
)

@@ -65,7 +60,7 @@ _import_structure = {
}

try:
-    if not is_bitsandbytes_available():
+    if not is_torch_available() and not is_accelerate_available() and not is_bitsandbytes_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_bitsandbytes_objects

@@ -77,7 +72,7 @@ else:
    _import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")

try:
-    if not is_gguf_available():
+    if not is_torch_available() and not is_accelerate_available() and not is_gguf_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_gguf_objects

@@ -89,7 +84,7 @@ else:
    _import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")

try:
-    if not is_torchao_available():
+    if not is_torch_available() and not is_accelerate_available() and not is_torchao_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_torchao_objects

@@ -101,7 +96,7 @@ else:
    _import_structure["quantizers.quantization_config"].append("TorchAoConfig")

try:
-    if not is_optimum_quanto_available():
+    if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from .utils import dummy_optimum_quanto_objects

@@ -112,7 +107,6 @@ except OptionalDependencyNotAvailable:
else:
    _import_structure["quantizers.quantization_config"].append("QuantoConfig")


try:
    if not is_onnx_available():
        raise OptionalDependencyNotAvailable()
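All of the guards in this file follow the same optional-dependency pattern: probe for the extra backend, raise `OptionalDependencyNotAvailable` when it is missing, and register dummy objects so the public API stays importable. A stripped-down, self-contained sketch of that pattern (the `example_pkg` backend and `ExampleConfig` object are illustrative names, not real diffusers dependencies):

```python
import importlib.util


class OptionalDependencyNotAvailable(Exception):
    """Raised when an optional backend needed for a feature is not installed."""


def is_example_pkg_available() -> bool:
    # find_spec returns None when the package cannot be imported.
    return importlib.util.find_spec("example_pkg") is not None


try:
    if not is_example_pkg_available():
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Register a dummy so `ExampleConfig` stays importable; it only errors when used.
    class ExampleConfig:
        def __init__(self, *args, **kwargs):
            raise ImportError("ExampleConfig requires the `example_pkg` package to be installed.")
else:
    from example_pkg import ExampleConfig
```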
@@ -29,11 +29,16 @@ if is_accelerate_available():
logger = get_logger(__name__)  # pylint: disable=invalid-name


+# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes


# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"

+# Always use memory-efficient CPU offloading to minimize RAM usage

_SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -56,7 +61,6 @@ class ModuleGroup:
        buffers: Optional[List[torch.Tensor]] = None,
        non_blocking: bool = False,
        stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
        onload_self: bool = True,
    ) -> None:
        self.modules = modules

@@ -68,12 +72,8 @@ class ModuleGroup:
        self.buffers = buffers
        self.non_blocking = non_blocking or stream is not None
        self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
        self.onload_self = onload_self

-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")

    def onload_(self):
        r"""Onloads the group of modules to the onload_device."""
        context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
@@ -82,23 +82,125 @@ class ModuleGroup:
            self.stream.synchronize()

        with context:
-            for group_module in self.modules:
-                group_module.to(self.onload_device, non_blocking=self.non_blocking)
+            # Use the most efficient module-level transfer when possible
+            # This approach mirrors how PyTorch handles full model transfers
+            if self.modules:
+                for group_module in self.modules:
+                    # Only onload if some parameters are not on the target device
+                    if any(p.device != self.onload_device for p in group_module.parameters()):
+                        try:
+                            # Try the most efficient approach using _apply
+                            if hasattr(group_module, "_apply"):
+                                # This is what module.to() uses internally
+                                def to_device(t):
+                                    if t.device != self.onload_device:
+                                        if self.onload_device.type == "cuda":
+                                            return t.cuda(self.onload_device.index,
+                                                          non_blocking=self.non_blocking)
+                                        else:
+                                            return t.to(self.onload_device,
+                                                        non_blocking=self.non_blocking)
+                                    return t

+                                # Apply to all tensors without unnecessary copies
+                                group_module._apply(to_device)
+                            else:
+                                # Fallback to direct parameter transfer
+                                for param in group_module.parameters():
+                                    if param.device != self.onload_device:
+                                        if self.onload_device.type == "cuda":
+                                            param.data = param.data.cuda(self.onload_device.index,
+                                                                         non_blocking=self.non_blocking)
+                                        else:
+                                            param.data = param.data.to(self.onload_device,
+                                                                       non_blocking=self.non_blocking)
+                        except Exception as e:
+                            # If optimization fails, fall back to direct parameter transfer
+                            logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
+                            for param in group_module.parameters():
+                                if param.device != self.onload_device:
+                                    if self.onload_device.type == "cuda":
+                                        param.data = param.data.cuda(self.onload_device.index,
+                                                                     non_blocking=self.non_blocking)
+                                    else:
+                                        param.data = param.data.to(self.onload_device,
+                                                                   non_blocking=self.non_blocking)

+            # Handle explicit parameters
            if self.parameters is not None:
                for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if param.device != self.onload_device:
+                        if self.onload_device.type == "cuda":
+                            param.data = param.data.cuda(self.onload_device.index,
+                                                         non_blocking=self.non_blocking)
+                        else:
+                            param.data = param.data.to(self.onload_device,
+                                                       non_blocking=self.non_blocking)

+            # Handle buffers
            if self.buffers is not None:
                for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if buffer.device != self.onload_device:
+                        if self.onload_device.type == "cuda":
+                            buffer.data = buffer.data.cuda(self.onload_device.index,
+                                                           non_blocking=self.non_blocking)
+                        else:
+                            buffer.data = buffer.data.to(self.onload_device,
+                                                         non_blocking=self.non_blocking)

    def offload_(self):
        r"""Offloads the group of modules to the offload_device."""
-        if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
+        # For CPU offloading
+        if self.offload_device.type == "cpu":
+            # Synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()

+            # Empty GPU cache before offloading to reduce memory fragmentation
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()

+            # For module groups, use a single, unified approach that is closest to
+            # the behavior of model.to("cpu")
+            if self.modules:
+                for group_module in self.modules:
+                    # Check if we need to offload this module
+                    if any(p.device.type != "cpu" for p in group_module.parameters()):
+                        # Use PyTorch's built-in to() method directly, which preserves
+                        # memory mapping when moving to CPU
+                        try:
+                            # Non-blocking=False for CPU transfers, as it ensures memory is
+                            # immediately available and potentially preserves memory mapping
+                            group_module.to("cpu", non_blocking=False)
+                        except Exception as e:
+                            # If there's any error, fall back to parameter-level offloading
+                            logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
+                            for param in group_module.parameters():
+                                if param.device.type != "cpu":
+                                    param.data = param.data.to("cpu", non_blocking=False)

+            # Handle explicit parameters - move directly to CPU with non-blocking=False
+            # which can preserve memory mapping in some PyTorch versions
+            if self.parameters is not None:
+                for param in self.parameters:
+                    if param.device.type != "cpu":
+                        param.data = param.data.to("cpu", non_blocking=False)

+            # Handle buffers
+            if self.buffers is not None:
+                for buffer in self.buffers:
+                    if buffer.device.type != "cpu":
+                        buffer.data = buffer.data.to("cpu", non_blocking=False)

+            # Let Python's normal reference counting handle cleanup
+            # We don't force garbage collection to avoid slowing down inference

        else:
+            # For non-CPU offloading, synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()

+            # For non-CPU offloading, use the regular approach
            for group_module in self.modules:
                group_module.to(self.offload_device, non_blocking=self.non_blocking)
            if self.parameters is not None:
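The rewritten `onload_`/`offload_` above leans on two general PyTorch behaviors: `non_blocking=True` copies only overlap with compute when issued on a separate CUDA stream (and, strictly, from pinned host memory), and the consuming stream has to wait on or synchronize with the copy stream before the tensors are used or moved again. A small standalone sketch of that onload/compute/offload cycle, independent of the classes in this diff:

```python
import torch
import torch.nn as nn

block = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()

    # Onload: issue the host-to-device copies on a side stream so they can
    # overlap with compute running on the default stream.
    with torch.cuda.stream(copy_stream):
        block.to("cuda", non_blocking=True)

    # The default stream must wait for the copies before using the weights.
    torch.cuda.current_stream().wait_stream(copy_stream)
    with torch.no_grad():
        _ = block(torch.randn(8, 1024, device="cuda"))

    # Offload: synchronize first so no kernel is still reading the weights,
    # then move everything back to CPU with a plain blocking copy.
    torch.cuda.current_stream().synchronize()
    block.to("cpu", non_blocking=False)
```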
@@ -108,6 +210,9 @@ class ModuleGroup:
            for buffer in self.buffers:
                buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)

+        # After offloading, we can unpin the memory if configured to do so
+        # We'll keep it pinned by default for better performance


class GroupOffloadingHook(ModelHook):
    r"""

@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        if self.group.offload_leader == module:
+            # Offload to CPU
            self.group.offload_()
        return module

@@ -313,7 +419,8 @@ def apply_group_offloading(
            If True, offloading and onloading is done with non-blocking data transfer.
        use_stream (`bool`, defaults to `False`):
            If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
-            overlapping computation and data transfer.
+            overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
+            to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.

    Example:
        ```python
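The docstring's example block is cut off by the hunk boundary. A rough usage sketch assembled from the parameters visible in this diff; the model checkpoint and values below are illustrative, not taken from the diff:

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks import apply_group_offloading

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Keep weights on CPU and onload them block by block during the forward pass,
# using a CUDA stream to overlap transfers with compute.
apply_group_offloading(
    transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
)
```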
@@ -344,12 +451,19 @@ def apply_group_offloading(

    _raise_error_if_accelerate_model_or_sequential_hook_present(module)

+    # We no longer need a pinned group manager as we're not using pinned memory

    if offload_type == "block_level":
        if num_blocks_per_group is None:
            raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

        _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module,
+            num_blocks_per_group,
+            offload_device,
+            onload_device,
+            non_blocking,
+            stream,
        )
    elif offload_type == "leaf_level":
        _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
            for overlapping computation and data transfer.
    """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # We no longer need a CPU parameter dictionary

    # Create module groups for ModuleList and Sequential blocks
    modules_with_group_offloading = set()

@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
            onload_leader=current_modules[0],
            non_blocking=non_blocking,
            stream=stream,
-            cpu_param_dict=cpu_param_dict,
            onload_self=stream is None,
        )
        matched_module_groups.append(group)

@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
        buffers=buffers,
        non_blocking=False,
        stream=None,
-        cpu_param_dict=None,
        onload_self=True,
    )
    next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
            for overlapping computation and data transfer.
    """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # We no longer need a CPU parameter dictionary

    # Create module groups for leaf modules and apply group offloading hooks
    modules_with_group_offloading = set()

@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
            onload_leader=submodule,
            non_blocking=non_blocking,
            stream=stream,
-            cpu_param_dict=cpu_param_dict,
            onload_self=True,
        )
        _apply_group_offloading_hook(submodule, group, None)

@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
            buffers=buffers,
            non_blocking=non_blocking,
            stream=stream,
-            cpu_param_dict=cpu_param_dict,
            onload_self=True,
        )
        _apply_group_offloading_hook(parent_module, group, None)

@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
            buffers=None,
            non_blocking=False,
            stream=None,
-            cpu_param_dict=None,
            onload_self=True,
        )
        _apply_lazy_group_offloading_hook(module, unmatched_group, None)
@@ -23,7 +23,14 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union

from packaging import version

-from ...utils import get_module_from_name, is_torch_available, is_torch_version, is_torchao_available, logging
+from ...utils import (
+    get_module_from_name,
+    is_torch_available,
+    is_torch_version,
+    is_torchao_available,
+    is_torchao_version,
+    logging,
+)
from ..base import DiffusersQuantizer

@@ -62,6 +69,43 @@ if is_torchao_available():
    from torchao.quantization import quantize_


+def _update_torch_safe_globals():
+    safe_globals = [
+        (torch.uint1, "torch.uint1"),
+        (torch.uint2, "torch.uint2"),
+        (torch.uint3, "torch.uint3"),
+        (torch.uint4, "torch.uint4"),
+        (torch.uint5, "torch.uint5"),
+        (torch.uint6, "torch.uint6"),
+        (torch.uint7, "torch.uint7"),
+    ]
+    try:
+        from torchao.dtypes import NF4Tensor
+        from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
+        from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
+        from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor

+        safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])

+    except (ImportError, ModuleNotFoundError) as e:
+        logger.warning(
+            "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
+        )
+        logger.debug(e)

+    finally:
+        torch.serialization.add_safe_globals(safe_globals=safe_globals)


+if (
+    is_torch_available()
+    and is_torch_version(">=", "2.6.0")
+    and is_torchao_available()
+    and is_torchao_version(">=", "0.7.0")
+):
+    _update_torch_safe_globals()


logger = logging.get_logger(__name__)

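For background, `torch.serialization.add_safe_globals` is what allows `torch.load(..., weights_only=True)` to unpickle the registered classes on newer PyTorch releases, which is what the guarded `_update_torch_safe_globals()` call above arranges for torchao's tensor types. A minimal standalone sketch (the tensor subclass is illustrative):

```python
import torch


class MyTensor(torch.Tensor):
    """Illustrative tensor subclass standing in for torchao's tensor types."""


torch.save(torch.ones(2).as_subclass(MyTensor), "ckpt.pt")

# Without this registration, weights_only=True refuses to unpickle MyTensor.
torch.serialization.add_safe_globals([MyTensor])
reloaded = torch.load("ckpt.pt", weights_only=True)
print(type(reloaded))  # <class '__main__.MyTensor'>
```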
@@ -94,6 +94,7 @@ from .import_utils import (
    is_torch_xla_available,
    is_torch_xla_version,
    is_torchao_available,
+    is_torchao_version,
    is_torchsde_available,
    is_torchvision_available,
    is_transformers_available,
@@ -868,6 +868,21 @@ def is_gguf_version(operation: str, version: str):
    return compare_versions(parse(_gguf_version), operation, version)


+def is_torchao_version(operation: str, version: str):
+    """
+    Compares the current torchao version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _is_torchao_available:
+        return False
+    return compare_versions(parse(_torchao_version), operation, version)


def is_k_diffusion_version(operation: str, version: str):
    """
    Compares the current k-diffusion version to a given reference with an operation.