Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 07:24:32 +08:00)

Compare commits: edit-pypi-... → group-memo (11 commits)
Commits:

- d9915a7d65
- b7a795dbeb
- 438905d63e
- 904f24de5a
- e123bbcbc4
- b3fa8c695d
- 720be2bac5
- e74b782aac
- d6392b4b49
- 1475026960
- 878eb4ce35
@@ -29,11 +29,16 @@ if is_accelerate_available():
 logger = get_logger(__name__)  # pylint: disable=invalid-name


+# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes


 # fmt: off
 _GROUP_OFFLOADING = "group_offloading"
 _LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
 _LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"

+# Always use memory-efficient CPU offloading to minimize RAM usage

 _SUPPORTED_PYTORCH_LAYERS = (
     torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
     torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -56,7 +61,6 @@ class ModuleGroup:
         buffers: Optional[List[torch.Tensor]] = None,
         non_blocking: bool = False,
         stream: Optional[torch.cuda.Stream] = None,
-        cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
         onload_self: bool = True,
     ) -> None:
         self.modules = modules
@@ -68,12 +72,8 @@ class ModuleGroup:
         self.buffers = buffers
         self.non_blocking = non_blocking or stream is not None
         self.stream = stream
-        self.cpu_param_dict = cpu_param_dict
         self.onload_self = onload_self

-        if self.stream is not None and self.cpu_param_dict is None:
-            raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
-
     def onload_(self):
         r"""Onloads the group of modules to the onload_device."""
         context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
@@ -82,23 +82,125 @@ class ModuleGroup:
             self.stream.synchronize()

         with context:
-            for group_module in self.modules:
-                group_module.to(self.onload_device, non_blocking=self.non_blocking)
+            # Use the most efficient module-level transfer when possible
+            # This approach mirrors how PyTorch handles full model transfers
+            if self.modules:
+                for group_module in self.modules:
+                    # Only onload if some parameters are not on the target device
+                    if any(p.device != self.onload_device for p in group_module.parameters()):
+                        try:
+                            # Try the most efficient approach using _apply
+                            if hasattr(group_module, "_apply"):
+                                # This is what module.to() uses internally
+                                def to_device(t):
+                                    if t.device != self.onload_device:
+                                        if self.onload_device.type == "cuda":
+                                            return t.cuda(self.onload_device.index,
+                                                          non_blocking=self.non_blocking)
+                                        else:
+                                            return t.to(self.onload_device,
+                                                        non_blocking=self.non_blocking)
+                                    return t
+
+                                # Apply to all tensors without unnecessary copies
+                                group_module._apply(to_device)
+                            else:
+                                # Fallback to direct parameter transfer
+                                for param in group_module.parameters():
+                                    if param.device != self.onload_device:
+                                        if self.onload_device.type == "cuda":
+                                            param.data = param.data.cuda(self.onload_device.index,
+                                                                         non_blocking=self.non_blocking)
+                                        else:
+                                            param.data = param.data.to(self.onload_device,
+                                                                       non_blocking=self.non_blocking)
+                        except Exception as e:
+                            # If optimization fails, fall back to direct parameter transfer
+                            logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
+                            for param in group_module.parameters():
+                                if param.device != self.onload_device:
+                                    if self.onload_device.type == "cuda":
+                                        param.data = param.data.cuda(self.onload_device.index,
+                                                                     non_blocking=self.non_blocking)
+                                    else:
+                                        param.data = param.data.to(self.onload_device,
+                                                                   non_blocking=self.non_blocking)
+
+            # Handle explicit parameters
             if self.parameters is not None:
                 for param in self.parameters:
-                    param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if param.device != self.onload_device:
+                        if self.onload_device.type == "cuda":
+                            param.data = param.data.cuda(self.onload_device.index,
+                                                         non_blocking=self.non_blocking)
+                        else:
+                            param.data = param.data.to(self.onload_device,
+                                                       non_blocking=self.non_blocking)
+
+            # Handle buffers
             if self.buffers is not None:
                 for buffer in self.buffers:
-                    buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
+                    if buffer.device != self.onload_device:
+                        if self.onload_device.type == "cuda":
+                            buffer.data = buffer.data.cuda(self.onload_device.index,
+                                                           non_blocking=self.non_blocking)
+                        else:
+                            buffer.data = buffer.data.to(self.onload_device,
+                                                         non_blocking=self.non_blocking)

     def offload_(self):
         r"""Offloads the group of modules to the offload_device."""
-        if self.stream is not None:
-            torch.cuda.current_stream().synchronize()
-            for group_module in self.modules:
-                for param in group_module.parameters():
-                    param.data = self.cpu_param_dict[param]
+        # For CPU offloading
+        if self.offload_device.type == "cpu":
+            # Synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()
+
+            # Empty GPU cache before offloading to reduce memory fragmentation
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            # For module groups, use a single, unified approach that is closest to
+            # the behavior of model.to("cpu")
+            if self.modules:
+                for group_module in self.modules:
+                    # Check if we need to offload this module
+                    if any(p.device.type != "cpu" for p in group_module.parameters()):
+                        # Use PyTorch's built-in to() method directly, which preserves
+                        # memory mapping when moving to CPU
+                        try:
+                            # Non-blocking=False for CPU transfers, as it ensures memory is
+                            # immediately available and potentially preserves memory mapping
+                            group_module.to("cpu", non_blocking=False)
+                        except Exception as e:
+                            # If there's any error, fall back to parameter-level offloading
+                            logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
+                            for param in group_module.parameters():
+                                if param.device.type != "cpu":
+                                    param.data = param.data.to("cpu", non_blocking=False)
+
+            # Handle explicit parameters - move directly to CPU with non-blocking=False
+            # which can preserve memory mapping in some PyTorch versions
+            if self.parameters is not None:
+                for param in self.parameters:
+                    if param.device.type != "cpu":
+                        param.data = param.data.to("cpu", non_blocking=False)
+
+            # Handle buffers
+            if self.buffers is not None:
+                for buffer in self.buffers:
+                    if buffer.device.type != "cpu":
+                        buffer.data = buffer.data.to("cpu", non_blocking=False)
+
+            # Let Python's normal reference counting handle cleanup
+            # We don't force garbage collection to avoid slowing down inference
+
         else:
+            # For non-CPU offloading, synchronize if using stream
+            if self.stream is not None:
+                torch.cuda.current_stream().synchronize()
+
+            # For non-CPU offloading, use the regular approach
             for group_module in self.modules:
                 group_module.to(self.offload_device, non_blocking=self.non_blocking)
             if self.parameters is not None:
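A minimal, standalone sketch of the `_apply`-based onload path used in the hunk above may help when reviewing. The helper name `onload_module` and the toy `nn.Sequential` are illustrative only; `Module._apply` is a private PyTorch hook (the same one `Module.to()` uses internally), so this mirrors the diff's approach rather than a public API.

```python
import torch
import torch.nn as nn


def onload_module(module: nn.Module, device: torch.device, non_blocking: bool = False) -> None:
    """Move every parameter/buffer of `module` to `device`, skipping tensors already there."""

    def to_device(t: torch.Tensor) -> torch.Tensor:
        if t.device == device:
            return t  # already on the target device, avoid an extra copy
        if device.type == "cuda":
            return t.cuda(device.index, non_blocking=non_blocking)
        return t.to(device, non_blocking=non_blocking)

    # _apply walks parameters and buffers in place, like Module.to() does internally.
    module._apply(to_device)


if __name__ == "__main__":
    block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))
    target = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    onload_module(block, target, non_blocking=True)
    print(next(block.parameters()).device)
```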
@@ -108,6 +210,9 @@ class ModuleGroup:
                 for buffer in self.buffers:
                     buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)

+            # After offloading, we can unpin the memory if configured to do so
+            # We'll keep it pinned by default for better performance
+

 class GroupOffloadingHook(ModelHook):
     r"""
@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):

     def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
         if self.group.offload_leader == module:
+            # Offload to CPU
             self.group.offload_()
         return module

@@ -313,7 +419,8 @@ def apply_group_offloading(
             If True, offloading and onloading is done with non-blocking data transfer.
         use_stream (`bool`, defaults to `False`):
             If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
-            overlapping computation and data transfer.
+            overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
+            to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.

     Example:
     ```python
@@ -344,12 +451,19 @@

     _raise_error_if_accelerate_model_or_sequential_hook_present(module)

+    # We no longer need a pinned group manager as we're not using pinned memory
+
     if offload_type == "block_level":
         if num_blocks_per_group is None:
             raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")

         _apply_group_offloading_block_level(
-            module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
+            module,
+            num_blocks_per_group,
+            offload_device,
+            onload_device,
+            non_blocking,
+            stream,
         )
     elif offload_type == "leaf_level":
         _apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
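A rough usage sketch for the dispatch shown above. The keyword names follow the docstring and the call sites in this diff, but the toy `TinyModel`, the import path, and the exact defaults are assumptions; a real diffusers transformer or UNet would normally be passed, and a CUDA device is assumed.

```python
import torch
from diffusers.hooks import apply_group_offloading  # import path assumed from the repo layout


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # block_level grouping looks for ModuleList/Sequential children like this one
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(64, 64) for _ in range(4))

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
apply_group_offloading(
    model,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",   # or "leaf_level"
    num_blocks_per_group=2,       # required for block_level, per the ValueError above
    use_stream=True,              # overlap H2D/D2H copies with compute on a CUDA stream
)
```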
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
         for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # We no longer need a CPU parameter dictionary

     # Create module groups for ModuleList and Sequential blocks
     modules_with_group_offloading = set()
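For reference, a small sketch of what `pin_memory()` does and why dropping it (as in the hunk above) reduces host RAM pressure. Tensor sizes here are arbitrary, and the final comment only paraphrases the removed lines.

```python
import torch

t = torch.empty(1024, 1024)   # ordinary pageable CPU tensor
pinned = t.pin_memory()       # page-locked copy; cannot be swapped out by the OS

print(t.is_pinned(), pinned.is_pinned())  # False True

# The removed code did roughly: param.data = param.data.cpu().pin_memory()
# i.e. every parameter was re-materialized as pinned host memory up front. Pinning enables
# fast async H2D copies, but it keeps the full model resident in CPU RAM, which is the
# memory spike the new code avoids by keeping ordinary (possibly mmap-backed) CPU tensors.
```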
@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
             onload_leader=current_modules[0],
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=stream is None,
         )
         matched_module_groups.append(group)
@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
         buffers=buffers,
         non_blocking=False,
         stream=None,
-        cpu_param_dict=None,
         onload_self=True,
     )
     next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
         for overlapping computation and data transfer.
     """

-    # Create a pinned CPU parameter dict for async data transfer if streams are to be used
-    cpu_param_dict = None
-    if stream is not None:
-        for param in module.parameters():
-            param.data = param.data.cpu().pin_memory()
-        cpu_param_dict = {param: param.data for param in module.parameters()}
+    # We no longer need a CPU parameter dictionary

     # Create module groups for leaf modules and apply group offloading hooks
     modules_with_group_offloading = set()
@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
             onload_leader=submodule,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
         _apply_group_offloading_hook(submodule, group, None)
@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
             buffers=buffers,
             non_blocking=non_blocking,
             stream=stream,
-            cpu_param_dict=cpu_param_dict,
             onload_self=True,
         )
         _apply_group_offloading_hook(parent_module, group, None)
@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
             buffers=None,
             non_blocking=False,
             stream=None,
-            cpu_param_dict=None,
             onload_self=True,
         )
         _apply_lazy_group_offloading_hook(module, unmatched_group, None)