Compare commits

..

7 Commits

Author SHA1 Message Date
Sayak Paul
70067734a2 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-26 11:29:51 +05:30
Sayak Paul
6125a4f540 Merge branch 'main' into fix-torchao-groupoffloading 2026-03-25 08:07:01 +05:30
Sayak Paul
d2666a9d0a Merge branch 'main' into fix-torchao-groupoffloading 2026-03-24 09:06:42 +05:30
sayakpaul
9b9e2e17a6 up 2026-03-23 11:22:36 +05:30
sayakpaul
1a959dc26f switch to swap_tensors. 2026-03-23 10:56:16 +05:30
Sayak Paul
8797398d3b Merge branch 'main' into fix-torchao-groupoffloading 2026-03-23 09:05:37 +05:30
sayakpaul
019a9deafb fix group offloading when using torchao 2026-03-17 10:40:03 +05:30
7 changed files with 242 additions and 231 deletions

View File

@@ -10,7 +10,6 @@ permissions:
contents: write
pull-requests: write
issues: read
id-token: write
jobs:
claude-review:

View File

@@ -1105,7 +1105,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None

View File

@@ -1251,7 +1251,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,6 +35,54 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
if not is_torchao_available():
return False
from torchao.utils import TorchAOBaseTensor
return isinstance(tensor, TorchAOBaseTensor)
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
cls = type(tensor)
names = list(getattr(cls, "tensor_data_names", []))
for attr_name in getattr(cls, "optional_tensor_data_names", []):
if getattr(tensor, attr_name, None) is not None:
names.append(attr_name)
return names
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
that any dict keyed by `id(param)` remains valid.
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
"""
torch.utils.swap_tensors(param, source)
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
`cpu_param_dict`).
"""
for attr_name in _get_torchao_inner_tensor_names(source):
setattr(param, attr_name, getattr(source, attr_name))
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
"""Record stream for all internal tensors of a TorchAO parameter."""
for attr_name in _get_torchao_inner_tensor_names(param):
getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -157,9 +205,16 @@ class ModuleGroup:
pinned_dict = None
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
if _is_torchao_tensor(tensor):
_swap_torchao_tensor(tensor, moved)
else:
tensor.data = moved
if self.record_stream:
tensor.data.record_stream(default_stream)
if _is_torchao_tensor(tensor):
_record_stream_torchao_tensor(tensor, default_stream)
else:
tensor.data.record_stream(default_stream)
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
for group_module in self.modules:
@@ -245,18 +300,35 @@ class ModuleGroup:
for group_module in self.modules:
for param in group_module.parameters():
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for param in self.parameters:
param.data = self.cpu_param_dict[param]
if _is_torchao_tensor(param):
_restore_torchao_tensor(param, self.cpu_param_dict[param])
else:
param.data = self.cpu_param_dict[param]
for buffer in self.buffers:
buffer.data = self.cpu_param_dict[buffer]
if _is_torchao_tensor(buffer):
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
else:
buffer.data = self.cpu_param_dict[buffer]
else:
for group_module in self.modules:
group_module.to(self.offload_device, non_blocking=False)
for param in self.parameters:
param.data = param.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(param):
moved = param.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(param, moved)
else:
param.data = param.data.to(self.offload_device, non_blocking=False)
for buffer in self.buffers:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
if _is_torchao_tensor(buffer):
moved = buffer.data.to(self.offload_device, non_blocking=False)
_swap_torchao_tensor(buffer, moved)
else:
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
@torch.compiler.disable()
def onload_(self):

View File

@@ -13,53 +13,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
# ======================== CogVideoX ========================
class CogVideoXTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogVideoXTransformer3DModel
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.8]
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
@property
def model_split_percents(self) -> list:
return [0.7, 0.7, 0.8]
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
@@ -75,66 +81,50 @@ class CogVideoXTransformerTesterConfig(BaseModelTesterConfig):
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogVideoXTransformer(CogVideoXTransformerTesterConfig, ModelTesterMixin):
pass
class TestCogVideoXTransformerTraining(CogVideoXTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogVideoXTransformerCompile(CogVideoXTransformerTesterConfig, TorchCompileTesterMixin):
pass
# ======================== CogVideoX 1.5 ========================
class CogVideoX15TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogVideoXTransformer3DModel
class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
@@ -151,29 +141,9 @@ class CogVideoX15TransformerTesterConfig(BaseModelTesterConfig):
"max_text_seq_length": 8,
"use_rotary_positional_embeddings": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 2
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogVideoX15Transformer(CogVideoX15TransformerTesterConfig, ModelTesterMixin):
pass
class TestCogVideoX15TransformerCompile(CogVideoX15TransformerTesterConfig, TorchCompileTesterMixin):
pass
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -13,50 +13,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogView3PlusTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
from ...testing_utils import (
enable_full_determinism,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogView3PlusTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogView3PlusTransformer2DModel
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView3PlusTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"original_size": original_size,
"target_size": target_size,
"crop_coords": crop_coords,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
@@ -69,37 +82,9 @@ class CogView3PlusTransformerTesterConfig(BaseModelTesterConfig):
"pos_embed_max_size": 8,
"sample_size": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogView3PlusTransformer(CogView3PlusTransformerTesterConfig, ModelTesterMixin):
pass
class TestCogView3PlusTransformerTraining(CogView3PlusTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView3PlusTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogView3PlusTransformerCompile(CogView3PlusTransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -12,46 +12,59 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogView4Transformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogView4TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogView4Transformer2DModel
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView4Transformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def main_input_name(self) -> str:
return "hidden_states"
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
@property
def output_shape(self) -> tuple:
return (4, 8, 8)
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def input_shape(self) -> tuple:
return (4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"original_size": original_size,
"target_size": target_size,
"crop_coords": crop_coords,
}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
@@ -62,37 +75,9 @@ class CogView4TransformerTesterConfig(BaseModelTesterConfig):
"time_embed_dim": 8,
"condition_dim": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
}
class TestCogView4Transformer(CogView4TransformerTesterConfig, ModelTesterMixin):
pass
class TestCogView4TransformerTraining(CogView4TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView4Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogView4TransformerCompile(CogView4TransformerTesterConfig, TorchCompileTesterMixin):
pass