Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-23 21:04:56 +08:00)
Compare commits: remove-unn ... group-offl (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | ff690a1324 |  |
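The single commit in this comparison touches the `ModuleGroup` class used for group offloading. Judging from the hunks below, the `+` side drops the `default_stream` argument that was threaded through `_transfer_tensor_to_device` and `_process_tensors_from_modules`, removes the per-tensor `record_stream` calls from the onload paths, and instead records the current stream (or synchronizes it) once, at offload time, in `_offload_to_disk` and `_offload_to_memory`.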
```diff
@@ -156,38 +156,33 @@ class ModuleGroup:
         finally:
             pinned_dict = None
 
-    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
+    def _transfer_tensor_to_device(self, tensor, source_tensor):
         tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-        if self.record_stream:
-            tensor.data.record_stream(default_stream)
 
-    def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
+    def _process_tensors_from_modules(self, pinned_memory=None):
         for group_module in self.modules:
             for param in group_module.parameters():
                 source = pinned_memory[param] if pinned_memory else param.data
-                self._transfer_tensor_to_device(param, source, default_stream)
+                self._transfer_tensor_to_device(param, source)
             for buffer in group_module.buffers():
                 source = pinned_memory[buffer] if pinned_memory else buffer.data
-                self._transfer_tensor_to_device(buffer, source, default_stream)
+                self._transfer_tensor_to_device(buffer, source)
 
         for param in self.parameters:
             source = pinned_memory[param] if pinned_memory else param.data
-            self._transfer_tensor_to_device(param, source, default_stream)
+            self._transfer_tensor_to_device(param, source)
 
         for buffer in self.buffers:
             source = pinned_memory[buffer] if pinned_memory else buffer.data
-            self._transfer_tensor_to_device(buffer, source, default_stream)
+            self._transfer_tensor_to_device(buffer, source)
 
     def _onload_from_disk(self):
         if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
-        current_stream = self._torch_accelerator_module.current_stream() if self.record_stream else None
 
         with context:
-            # Load to CPU (if using streams) or directly to target device, pin, and async copy to device
             device = str(self.onload_device) if self.stream is None else "cpu"
             loaded_tensors = safetensors.torch.load_file(self.safetensors_file_path, device=device)
 
```
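This first hunk removes the `default_stream` plumbing from the in-memory onload helpers: `_transfer_tensor_to_device` and `_process_tensors_from_modules` lose the extra parameter, the per-copy `record_stream` call goes away, and `_onload_from_disk` no longer looks up the current stream.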
```diff
@@ -195,8 +190,6 @@ class ModuleGroup:
                 for key, tensor_obj in self.key_to_tensor.items():
                     pinned_tensor = loaded_tensors[key].pin_memory()
                     tensor_obj.data = pinned_tensor.to(self.onload_device, non_blocking=self.non_blocking)
-                    if self.record_stream:
-                        tensor_obj.data.record_stream(current_stream)
             else:
                 onload_device = (
                     self.onload_device.type if isinstance(self.onload_device, torch.device) else self.onload_device
```
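The second hunk applies the same cleanup to the disk-onload path: the `if self.record_stream: ... record_stream(current_stream)` pair that followed each pinned-tensor copy is dropped.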
```diff
@@ -207,45 +200,57 @@ class ModuleGroup:
 
     def _onload_from_memory(self):
         if self.stream is not None:
-            # Wait for previous Host->Device transfer to complete
             self.stream.synchronize()
 
         context = nullcontext() if self.stream is None else self._torch_accelerator_module.stream(self.stream)
-        default_stream = self._torch_accelerator_module.current_stream() if self.stream is not None else None
 
         with context:
             if self.stream is not None:
                 with self._pinned_memory_tensors() as pinned_memory:
-                    self._process_tensors_from_modules(pinned_memory, default_stream=default_stream)
+                    self._process_tensors_from_modules(pinned_memory)
             else:
                 self._process_tensors_from_modules(None)
 
     def _offload_to_disk(self):
-        # TODO: we can potentially optimize this code path by checking if the _all_ the desired
-        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
-        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
-        # we perform a write.
-        # Check if the file has been saved in this session or if it already exists on disk.
         if not self._is_offloaded_to_disk and not os.path.exists(self.safetensors_file_path):
             os.makedirs(os.path.dirname(self.safetensors_file_path), exist_ok=True)
             tensors_to_save = {key: tensor.data.to(self.offload_device) for tensor, key in self.tensor_to_key.items()}
             safetensors.torch.save_file(tensors_to_save, self.safetensors_file_path)
 
-        # The group is now considered offloaded to disk for the rest of the session.
         self._is_offloaded_to_disk = True
 
-        # We do this to free up the RAM which is still holding the up tensor data.
+        if self.stream is not None:
+            if self.record_stream:
+                current_stream = self._torch_accelerator_module.current_stream()
+                for tensor_obj in self.tensor_to_key.keys():
+                    tensor_obj.data.record_stream(current_stream)
+            else:
+                self._torch_accelerator_module.current_stream().synchronize()
+
         for tensor_obj in self.tensor_to_key.keys():
             tensor_obj.data = torch.empty_like(tensor_obj.data, device=self.offload_device)
 
     def _offload_to_memory(self):
         if self.stream is not None:
-            if not self.record_stream:
+            if self.record_stream:
+                current_stream = self._torch_accelerator_module.current_stream()
+                for group_module in self.modules:
+                    for param in group_module.parameters():
+                        param.data.record_stream(current_stream)
+                    for buffer in group_module.buffers():
+                        buffer.data.record_stream(current_stream)
+                for param in self.parameters:
+                    param.data.record_stream(current_stream)
+                for buffer in self.buffers:
+                    buffer.data.record_stream(current_stream)
+            else:
                 self._torch_accelerator_module.current_stream().synchronize()
 
         for group_module in self.modules:
             for param in group_module.parameters():
                 param.data = self.cpu_param_dict[param]
+            for buffer in group_module.buffers():
+                buffer.data = self.cpu_param_dict[buffer]
         for param in self.parameters:
             param.data = self.cpu_param_dict[param]
         for buffer in self.buffers:
```
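The third hunk moves the stream bookkeeping to the offload side: before the on-device data is dropped, `_offload_to_disk` and `_offload_to_memory` now either record the current stream on every tensor (when `record_stream` is enabled) or synchronize the current stream. As a reminder of why one of the two is needed at all, here is a minimal, standalone sketch of the underlying PyTorch pattern; it is not taken from the diffusers codebase, and the helper name `onload` and the `copy_stream` argument are illustrative only.

```python
import torch


def onload(host_tensor: torch.Tensor, copy_stream: torch.cuda.Stream, record: bool) -> torch.Tensor:
    """Copy a CPU tensor to the GPU on a side stream and hand it to the current stream."""
    consumer_stream = torch.cuda.current_stream()  # stream that will later compute with the tensor
    pinned = host_tensor.pin_memory()  # pinned staging buffer, required for a truly async H2D copy

    with torch.cuda.stream(copy_stream):
        device_tensor = pinned.to("cuda", non_blocking=True)

    # The consumer must not read the tensor before the copy has finished.
    consumer_stream.wait_stream(copy_stream)

    if record:
        # Tell the caching allocator that `device_tensor` (allocated on `copy_stream`) is also
        # used by `consumer_stream`, so its memory is not handed out again while the consumer
        # may still be reading it after the tensor is freed or its `.data` is swapped.
        device_tensor.record_stream(consumer_stream)
    else:
        # Conservative alternative: block the host until the copy stream has drained, so any
        # later reuse or release of the memory is safe without per-tensor tracking.
        copy_stream.synchronize()
    return device_tensor


if __name__ == "__main__" and torch.cuda.is_available():
    stream = torch.cuda.Stream()
    weight = onload(torch.randn(1024, 1024), stream, record=True)
    print(weight.sum().item())
```

`Tensor.record_stream` avoids a blocking wait but can keep memory alive longer, since the caching allocator must assume the recorded stream may still be using it; the plain `synchronize()` branch is the conservative fallback the diff keeps for the `record_stream=False` case.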