Compare commits

..

9 Commits

Author SHA1 Message Date
Sayak Paul
a8cef0740a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-27 21:16:15 +05:30
Sayak Paul
7da22b9db5 [ci] include checkout step in claude review workflow (#13352)
up
2026-03-27 17:28:31 +05:30
Sayak Paul
70067734a2 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-26 11:29:51 +05:30
Sayak Paul
6125a4f540 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-25 08:07:01 +05:30
Sayak Paul
d2666a9d0a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-24 09:06:42 +05:30
sayakpaul
9b9e2e17a6 up 2026-03-23 11:22:36 +05:30
sayakpaul
1a959dc26f switch to swap_tensors. 2026-03-23 10:56:16 +05:30
Sayak Paul
8797398d3b Merge branch 'main' into fix-torchao-groupoffloading 2026-03-23 09:05:37 +05:30
sayakpaul
019a9deafb fix group offloading when using torchao 2026-03-17 10:40:03 +05:30
3 changed files with 164 additions and 118 deletions

View File

@@ -32,6 +32,9 @@ jobs:
)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 1
- uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,6 +35,54 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
if not is_torchao_available():
return False
from torchao.utils import TorchAOBaseTensor
return isinstance(tensor, TorchAOBaseTensor)
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
that any dict keyed by `id(param)` remains valid.
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
"""
torch.utils.swap_tensors(param, source)
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
`cpu_param_dict`).
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
"""Record stream for all internal tensors of a TorchAO parameter."""
for attr_name in _get_torchao_inner_tensor_names(param):
getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -157,9 +205,16 @@ class ModuleGroup:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
if self.record_stream:
tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
@@ -245,18 +300,35 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(param):
moved = param.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(buffer):
moved = buffer.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):

View File

@@ -12,46 +12,60 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CosmosTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CosmosTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CosmosTransformer3DModel
class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
fps = 30
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"fps": fps,
"padding_mask": padding_mask,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -66,68 +80,57 @@ class CosmosTransformerTesterConfig(BaseModelTesterConfig):
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
),
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"fps": 30,
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
}
class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
"""Core model tests for Cosmos Transformer."""
class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Cosmos Transformer."""
class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Cosmos Transformer."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CosmosTransformer3DModel
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
fps = 30
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"fps": fps,
"condition_mask": condition_mask,
"padding_mask": padding_mask,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4 + 1,
"out_channels": 4,
"num_attention_heads": 2,
@@ -142,40 +145,8 @@ class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
),
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"fps": 30,
"condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
}
class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
"""Core model tests for Cosmos Transformer (Video-to-World)."""
class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Cosmos Transformer (Video-to-World)."""
class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
"""Training tests for Cosmos Transformer (Video-to-World)."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosTransformer3DModel"}