Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul
2bba42a7e5 Merge branch 'main' into cosmos-test-refactor 2026-03-27 16:12:14 +05:30
DN6
4f82a6f9a2 update 2026-03-26 09:36:10 +05:30
5 changed files with 124 additions and 95 deletions

View File

@@ -32,9 +32,6 @@ jobs:
)
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 1
- uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}

View File

@@ -470,8 +470,8 @@ class TorchAoConfig(QuantizationConfigMixin):
self.post_init()
def post_init(self):
if is_torchao_version("<", "0.15.0"):
raise ValueError("TorchAoConfig requires torchao >= 0.15.0. Please upgrade with `pip install -U torchao`.")
if is_torchao_version("<=", "0.9.0"):
raise ValueError("TorchAoConfig requires torchao > 0.9.0. Please upgrade with `pip install -U torchao`.")
from torchao.quantization.quant_api import AOBaseConfig
@@ -495,8 +495,8 @@ class TorchAoConfig(QuantizationConfigMixin):
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
if not is_torchao_version(">=", "0.15.0"):
raise NotImplementedError("TorchAoConfig requires torchao >= 0.15.0 for construction from dict")
if not is_torchao_version(">", "0.9.0"):
raise NotImplementedError("TorchAoConfig requires torchao > 0.9.0 for construction from dict")
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")

View File

@@ -113,7 +113,7 @@ if (
is_torch_available()
and is_torch_version(">=", "2.6.0")
and is_torchao_available()
and is_torchao_version(">=", "0.15.0")
and is_torchao_version(">=", "0.7.0")
):
_update_torch_safe_globals()
@@ -168,10 +168,10 @@ class TorchAoHfQuantizer(DiffusersQuantizer):
raise ImportError(
"Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
)
torchao_version = version.parse(importlib.metadata.version("torchao"))
if torchao_version < version.parse("0.15.0"):
torchao_version = version.parse(importlib.metadata.version("torch"))
if torchao_version < version.parse("0.7.0"):
raise RuntimeError(
f"The minimum required version of `torchao` is 0.15.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
)
self.offload = False

View File

@@ -12,60 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import CosmosTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
from ..testing_utils import (
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
class CosmosTransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return CosmosTransformer3DModel
@property
def dummy_input(self):
batch_size = 1
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
fps = 30
def output_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
@property
def input_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"fps": fps,
"padding_mask": padding_mask,
}
@property
def input_shape(self):
return (4, 1, 16, 16)
@property
def output_shape(self):
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"num_attention_heads": 2,
@@ -80,57 +66,68 @@ class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
model_class = CosmosTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 1
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
fps = 30
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
"attention_mask": attention_mask,
"fps": fps,
"condition_mask": condition_mask,
"padding_mask": padding_mask,
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
),
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"fps": 30,
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
}
class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
"""Core model tests for Cosmos Transformer."""
class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Cosmos Transformer."""
class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Cosmos Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
@property
def input_shape(self):
def model_class(self):
return CosmosTransformer3DModel
@property
def output_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, ...]:
return (4, 1, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
return {
"in_channels": 4 + 1,
"out_channels": 4,
"num_attention_heads": 2,
@@ -145,8 +142,40 @@ class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestC
"concat_padding_mask": True,
"extra_pos_embed_type": "learnable",
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_channels = 4
num_frames = 1
height = 16
width = 16
text_embed_dim = 16
sequence_length = 12
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
),
"attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
"fps": 30,
"condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
"padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
}
class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
"""Core model tests for Cosmos Transformer (Video-to-World)."""
class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Cosmos Transformer (Video-to-World)."""
class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
"""Training tests for Cosmos Transformer (Video-to-World)."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"CosmosTransformer3DModel"}

View File

@@ -14,11 +14,13 @@
# limitations under the License.
import gc
import importlib.metadata
import tempfile
import unittest
from typing import List
import numpy as np
from packaging import version
from parameterized import parameterized
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
@@ -80,17 +82,18 @@ if is_torchao_available():
Float8WeightOnlyConfig,
Int4WeightOnlyConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
)
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
from torchao.utils import get_model_size_in_bytes
if version.parse(importlib.metadata.version("torchao")) >= version.Version("0.10.0"):
from torchao.quantization import Int8DynamicActivationIntxWeightConfig, IntxWeightOnlyConfig
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.15.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
@@ -125,7 +128,7 @@ class TorchAoConfigTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.15.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
@@ -524,7 +527,7 @@ class TorchAoTest(unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
_ = pipe(**inputs)
@require_torchao_version_greater_or_equal("0.15.0")
@require_torchao_version_greater_or_equal("0.9.0")
def test_aobase_config(self):
quantization_config = TorchAoConfig(Int8WeightOnlyConfig())
components = self.get_dummy_components(quantization_config)
@@ -537,7 +540,7 @@ class TorchAoTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.15.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
@@ -647,7 +650,7 @@ class TorchAoSerializationTest(unittest.TestCase):
self._check_serialization_expected_slice(quant_type, expected_slice, device)
@require_torchao_version_greater_or_equal("0.15.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
@@ -693,7 +696,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.15.0")
@require_torchao_version_greater_or_equal("0.14.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
@@ -851,7 +854,7 @@ class SlowTorchAoTests(unittest.TestCase):
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.15.0")
@require_torchao_version_greater_or_equal("0.14.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):