Compare commits

...

11 Commits

Author SHA1 Message Date
DN6
d9915a7d65 update 2025-03-12 11:44:40 +05:30
DN6
b7a795dbeb update 2025-03-12 11:40:40 +05:30
DN6
438905d63e update 2025-03-12 11:37:27 +05:30
DN6
904f24de5a update 2025-03-12 11:35:18 +05:30
DN6
e123bbcbc4 memmap 2025-03-12 11:23:14 +05:30
DN6
b3fa8c695d remove cpu param dict 2025-03-12 09:02:04 +05:30
DN6
720be2bac5 update 2025-03-12 08:49:45 +05:30
DN6
e74b782aac update 2025-03-12 08:45:09 +05:30
DN6
d6392b4b49 update 2025-03-12 08:18:19 +05:30
DN6
1475026960 sliding-window 2025-03-11 13:56:39 +05:30
DN6
878eb4ce35 update 2025-03-11 13:21:09 +05:30

View File

@@ -29,11 +29,16 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
# Removed PinnedGroupManager - we no longer use pinned memory to avoid CPU memory spikes
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
_LAZY_PREFETCH_GROUP_OFFLOADING = "lazy_prefetch_group_offloading"
# Always use memory-efficient CPU offloading to minimize RAM usage
_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
@@ -56,7 +61,6 @@ class ModuleGroup:
buffers: Optional[List[torch.Tensor]] = None,
non_blocking: bool = False,
stream: Optional[torch.cuda.Stream] = None,
cpu_param_dict: Optional[Dict[torch.nn.Parameter, torch.Tensor]] = None,
onload_self: bool = True,
) -> None:
self.modules = modules
@@ -68,12 +72,8 @@ class ModuleGroup:
self.buffers = buffers
self.non_blocking = non_blocking or stream is not None
self.stream = stream
self.cpu_param_dict = cpu_param_dict
self.onload_self = onload_self
if self.stream is not None and self.cpu_param_dict is None:
raise ValueError("cpu_param_dict must be provided when using stream for data transfer.")
def onload_(self):
r"""Onloads the group of modules to the onload_device."""
context = nullcontext() if self.stream is None else torch.cuda.stream(self.stream)
@@ -82,23 +82,125 @@ class ModuleGroup:
self.stream.synchronize()
with context:
for group_module in self.modules:
group_module.to(self.onload_device, non_blocking=self.non_blocking)
# Use the most efficient module-level transfer when possible
# This approach mirrors how PyTorch handles full model transfers
if self.modules:
for group_module in self.modules:
# Only onload if some parameters are not on the target device
if any(p.device != self.onload_device for p in group_module.parameters()):
try:
# Try the most efficient approach using _apply
if hasattr(group_module, "_apply"):
# This is what module.to() uses internally
def to_device(t):
if t.device != self.onload_device:
if self.onload_device.type == "cuda":
return t.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
return t.to(self.onload_device,
non_blocking=self.non_blocking)
return t
# Apply to all tensors without unnecessary copies
group_module._apply(to_device)
else:
# Fallback to direct parameter transfer
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
except Exception as e:
# If optimization fails, fall back to direct parameter transfer
logger.warning(f"Optimized onloading failed: {e}, falling back to direct method")
for param in group_module.parameters():
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle explicit parameters
if self.parameters is not None:
for param in self.parameters:
param.data = param.data.to(self.onload_device, non_blocking=self.non_blocking)
if param.device != self.onload_device:
if self.onload_device.type == "cuda":
param.data = param.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
param.data = param.data.to(self.onload_device,
non_blocking=self.non_blocking)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.onload_device, non_blocking=self.non_blocking)
if buffer.device != self.onload_device:
if self.onload_device.type == "cuda":
buffer.data = buffer.data.cuda(self.onload_device.index,
non_blocking=self.non_blocking)
else:
buffer.data = buffer.data.to(self.onload_device,
non_blocking=self.non_blocking)
def offload_(self):
r"""Offloads the group of modules to the offload_device."""
if self.stream is not None:
torch.cuda.current_stream().synchronize()
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
# For CPU offloading
if self.offload_device.type == "cpu":
# Synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# Empty GPU cache before offloading to reduce memory fragmentation
if torch.cuda.is_available():
torch.cuda.empty_cache()
# For module groups, use a single, unified approach that is closest to
# the behavior of model.to("cpu")
if self.modules:
for group_module in self.modules:
# Check if we need to offload this module
if any(p.device.type != "cpu" for p in group_module.parameters()):
# Use PyTorch's built-in to() method directly, which preserves
# memory mapping when moving to CPU
try:
# Non-blocking=False for CPU transfers, as it ensures memory is
# immediately available and potentially preserves memory mapping
group_module.to("cpu", non_blocking=False)
except Exception as e:
# If there's any error, fall back to parameter-level offloading
logger.warning(f"Module-level CPU offloading failed: {e}, falling back to parameter-level")
for param in group_module.parameters():
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle explicit parameters - move directly to CPU with non-blocking=False
# which can preserve memory mapping in some PyTorch versions
if self.parameters is not None:
for param in self.parameters:
if param.device.type != "cpu":
param.data = param.data.to("cpu", non_blocking=False)
# Handle buffers
if self.buffers is not None:
for buffer in self.buffers:
if buffer.device.type != "cpu":
buffer.data = buffer.data.to("cpu", non_blocking=False)
# Let Python's normal reference counting handle cleanup
# We don't force garbage collection to avoid slowing down inference
else:
# For non-CPU offloading, synchronize if using stream
if self.stream is not None:
torch.cuda.current_stream().synchronize()
# For non-CPU offloading, use the regular approach
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=self.non_blocking)
if self.parameters is not None:
@@ -108,6 +210,9 @@ class ModuleGroup:
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=self.non_blocking)
# After offloading, we can unpin the memory if configured to do so
# We'll keep it pinned by default for better performance
class GroupOffloadingHook(ModelHook):
r"""
@@ -129,6 +234,7 @@ class GroupOffloadingHook(ModelHook):
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
if self.group.offload_leader == module:
# Offload to CPU
self.group.offload_()
return module
@@ -313,7 +419,8 @@ def apply_group_offloading(
If True, offloading and onloading is done with non-blocking data transfer.
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
overlapping computation and data transfer. Memory-efficient CPU offloading is automatically used
to minimize RAM usage by preserving memory mapping benefits and avoiding unnecessary copies.
Example:
```python
@@ -344,12 +451,19 @@ def apply_group_offloading(
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
# We no longer need a pinned group manager as we're not using pinned memory
if offload_type == "block_level":
if num_blocks_per_group is None:
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
_apply_group_offloading_block_level(
module, num_blocks_per_group, offload_device, onload_device, non_blocking, stream
module,
num_blocks_per_group,
offload_device,
onload_device,
non_blocking,
stream,
)
elif offload_type == "leaf_level":
_apply_group_offloading_leaf_level(module, offload_device, onload_device, non_blocking, stream)
@@ -384,12 +498,7 @@ def _apply_group_offloading_block_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for ModuleList and Sequential blocks
modules_with_group_offloading = set()
@@ -411,7 +520,6 @@ def _apply_group_offloading_block_level(
onload_leader=current_modules[0],
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=stream is None,
)
matched_module_groups.append(group)
@@ -448,7 +556,6 @@ def _apply_group_offloading_block_level(
buffers=buffers,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
next_group = matched_module_groups[0] if len(matched_module_groups) > 0 else None
@@ -483,12 +590,7 @@ def _apply_group_offloading_leaf_level(
for overlapping computation and data transfer.
"""
# Create a pinned CPU parameter dict for async data transfer if streams are to be used
cpu_param_dict = None
if stream is not None:
for param in module.parameters():
param.data = param.data.cpu().pin_memory()
cpu_param_dict = {param: param.data for param in module.parameters()}
# We no longer need a CPU parameter dictionary
# Create module groups for leaf modules and apply group offloading hooks
modules_with_group_offloading = set()
@@ -503,7 +605,6 @@ def _apply_group_offloading_leaf_level(
onload_leader=submodule,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(submodule, group, None)
@@ -548,7 +649,6 @@ def _apply_group_offloading_leaf_level(
buffers=buffers,
non_blocking=non_blocking,
stream=stream,
cpu_param_dict=cpu_param_dict,
onload_self=True,
)
_apply_group_offloading_hook(parent_module, group, None)
@@ -567,7 +667,6 @@ def _apply_group_offloading_leaf_level(
buffers=None,
non_blocking=False,
stream=None,
cpu_param_dict=None,
onload_self=True,
)
_apply_lazy_group_offloading_hook(module, unmatched_group, None)