Compare commits

...

2 Commits

Author SHA1 Message Date
Sayak Paul
f9ef564a73 Merge branch 'main' into auto-offload-improv 2025-11-26 15:23:01 +05:30
sayakpaul
7dad173147 error early in auto_cpu_offload 2025-11-03 11:35:20 +05:30

View File

@@ -160,7 +160,10 @@ class AutoOffloadStrategy:
if len(hooks) == 0:
return []
current_module_size = model.get_memory_footprint()
try:
current_module_size = model.get_memory_footprint()
except AttributeError:
raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}.")
device_type = execution_device.type
device_module = getattr(torch, device_type, torch.cuda)
@@ -703,7 +706,20 @@ class ComponentsManager:
if not is_accelerate_available():
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
# TODO: add a warning if mem_get_info isn't available on `device`.
if device is None:
device = get_device()
if not isinstance(device, torch.device):
device = torch.device(device)
device_type = device.type
device_module = getattr(torch, device_type, torch.cuda)
if not hasattr(device_module, "mem_get_info"):
raise NotImplementedError(
f"`enable_auto_cpu_offload() relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
for name, component in self.components.items():
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
@@ -711,11 +727,7 @@ class ComponentsManager:
self.disable_auto_cpu_offload()
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
if device is None:
device = get_device()
device = torch.device(device)
if device.index is None:
device = torch.device(f"{device.type}:{0}")
all_hooks = []
for name, component in self.components.items():
if isinstance(component, torch.nn.Module):