Compare commits

..

4 Commits

Author SHA1 Message Date
Sayak Paul
d8f6063c27 Merge branch 'main' into cog-tests 2026-03-30 09:01:58 +05:30
Howard Zhang
f2be8bd6b3 change minimum version guard for torchao to 0.15.0 (#13355) 2026-03-28 09:11:51 +05:30
Sayak Paul
7da22b9db5 [ci] include checkout step in claude review workflow (#13352)
up
2026-03-27 17:28:31 +05:30
DN6
f7405f2b44 update 2026-03-26 16:41:25 +05:30
8 changed files with 316 additions and 285 deletions

View File

@@ -32,6 +32,9 @@ jobs:
)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 1
- uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}

View File

@@ -470,8 +470,8 @@ class TorchAoConfig(QuantizationConfigMixin):
self.post_init()
def post_init(self):
if is_torchao_version("<=", "0.9.0"):
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
if is_torchao_version("<", "0.15.0"):
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
from torchao.quantization.quant_api import AOBaseConfig
@@ -495,8 +495,8 @@ class TorchAoConfig(QuantizationConfigMixin):
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
if not is_torchao_version(">", "0.9.0"):
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
if not is_torchao_version(">=", "0.15.0"):
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")

View File

@@ -113,7 +113,7 @@ if (
is_torch_available()
and is_torch_version(">=", "2.6.0")
and is_torchao_available()
and is_torchao_version(">=", "0.7.0")
and is_torchao_version(">=", "0.15.0")
):
_update_torch_safe_globals()
@@ -168,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
torchao_version = version.parse(importlib.metadata.version("torchao"))
if torchao_version < version.parse("0.15.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)
self.offload = False

View File

@@ -13,59 +13,53 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.7, 0.8]
# ======================== CogVideoX ========================
class CogVideoXTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogVideoXTransformer3DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def model_split_percents(self) -> list:
return [0.7, 0.7, 0.8]
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
@@ -81,50 +75,66 @@ class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase):
"temporal_compression_ratio": 4,
"max_text_seq_length": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogVideoXTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 2
num_frames = 1
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
hidden_states = torch.randn((batch_size, num_frames, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"hidden_states": randn_tensor(
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogVideoXTransformer(CogVideoXTransformerTesterConfig, ModelTesterMixin):
pass
class TestCogVideoXTransformerTraining(CogVideoXTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogVideoXTransformerCompile(CogVideoXTransformerTesterConfig, TorchCompileTesterMixin):
pass
# ======================== CogVideoX 1.5 ========================
class CogVideoX15TransformerTesterConfig(BaseModelTesterConfig):
@property
def input_shape(self):
def model_class(self):
return CogVideoXTransformer3DModel
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def output_shape(self):
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
# Product of num_attention_heads * attention_head_dim must be divisible by 16 for 3D positional embeddings.
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"num_attention_heads": 2,
"attention_head_dim": 8,
"in_channels": 4,
@@ -141,9 +151,29 @@ class CogVideoX1_5TransformerTests(ModelTesterMixin, unittest.TestCase):
"max_text_seq_length": 8,
"use_rotary_positional_embeddings": True,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogVideoXTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 2
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_frames, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogVideoX15Transformer(CogVideoX15TransformerTesterConfig, ModelTesterMixin):
pass
class TestCogVideoX15TransformerCompile(CogVideoX15TransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -13,63 +13,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogView3PlusTransformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
enable_full_determinism,
torch_device,
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView3PlusTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
model_split_percents = [0.7, 0.6, 0.6]
class CogView3PlusTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogView3PlusTransformer2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def output_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def input_shape(self) -> tuple:
return (1, 4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"original_size": original_size,
"target_size": target_size,
"crop_coords": crop_coords,
"timestep": timestep,
}
@property
def input_shape(self):
return (1, 4, 8, 8)
@property
def output_shape(self):
return (1, 4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
@@ -82,9 +69,37 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
"pos_embed_max_size": 8,
"sample_size": 8,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestCogView3PlusTransformer(CogView3PlusTransformerTesterConfig, ModelTesterMixin):
pass
class TestCogView3PlusTransformerTraining(CogView3PlusTransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView3PlusTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogView3PlusTransformerCompile(CogView3PlusTransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -12,59 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CogView4Transformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = CogView4Transformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class CogView4TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CogView4Transformer2DModel
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
def main_input_name(self) -> str:
return "hidden_states"
hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
@property
def output_shape(self) -> tuple:
return (4, 8, 8)
@property
def input_shape(self) -> tuple:
return (4, 8, 8)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"original_size": original_size,
"target_size": target_size,
"crop_coords": crop_coords,
}
@property
def input_shape(self):
return (4, 8, 8)
@property
def output_shape(self):
return (4, 8, 8)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"in_channels": 4,
"num_layers": 2,
@@ -75,9 +62,37 @@ class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
"time_embed_dim": 8,
"condition_dim": 4,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = 8
width = 8
embedding_dim = 8
sequence_length = 8
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"original_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"target_size": torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
"crop_coords": torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device),
}
class TestCogView4Transformer(CogView4TransformerTesterConfig, ModelTesterMixin):
pass
class TestCogView4TransformerTraining(CogView4TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CogView4Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestCogView4TransformerCompile(CogView4TransformerTesterConfig, TorchCompileTesterMixin):
pass

View File

@@ -12,46 +12,60 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CosmosTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class CosmosTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CosmosTransformer3DModel
class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
fps = 30
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"fps": fps,
"padding_mask": padding_mask,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -66,68 +80,57 @@ class CosmosTransformerTesterConfig(BaseModelTesterConfig):
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
),
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"fps": 30,
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
}
class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
"""Core model tests for Cosmos Transformer."""
class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Cosmos Transformer."""
class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Cosmos Transformer."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CosmosTransformer3DModel
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
fps = 30
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"fps": fps,
"condition_mask": condition_mask,
"padding_mask": padding_mask,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4 + 1,
"out_channels": 4,
"num_attention_heads": 2,
@@ -142,40 +145,8 @@ class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
),
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"fps": 30,
"condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
}
class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
"""Core model tests for Cosmos Transformer (Video-to-World)."""
class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Cosmos Transformer (Video-to-World)."""
class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
"""Training tests for Cosmos Transformer (Video-to-World)."""
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosTransformer3DModel"}

View File

@@ -14,13 +14,11 @@
# limitations under the License.
import gc
import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -82,18 +80,17 @@ if is_torchao_available():
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
)
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.utils import get_model_size_in_bytes
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
@@ -128,7 +125,7 @@ class TorchAoConfigTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
@@ -527,7 +524,7 @@ class TorchAoTest(unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
@require_torchao_version_greater_or_equal("0.9.0")
@require_torchao_version_greater_or_equal("0.15.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
@@ -540,7 +537,7 @@ class TorchAoTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
@@ -650,7 +647,7 @@ class TorchAoSerializationTest(unittest.TestCase):
self._check_serialization_expected_slice(quant_type, expected_slice, device)
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
@@ -696,7 +693,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
@@ -854,7 +851,7 @@ class SlowTorchAoTests(unittest.TestCase):
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.14.0")
@require_torchao_version_greater_or_equal("0.15.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):