Compare commits

...

10 Commits

Author       SHA1        Message                                                 Date
Sayak Paul   bbe044acd9  Merge branch 'main' into autoencoderkl-tests-refactor  2026-04-06 15:39:01 +02:00
Sayak Paul   357b681890  [tests] refactor autoencoderdc tests (#13369)          2026-04-06 11:10:21 +02:00
                         * refactor autoencoderdc tests
                         * fix
                         * propagate new changes.
Sayak Paul   4529468ca8  Merge branch 'main' into autoencoderkl-tests-refactor  2026-04-06 10:24:31 +02:00
Dhruv Nair   065e36937a  [CI] Refactor Cosmos Transformer Tests (#13335)        2026-04-06 10:05:37 +05:30
                         update
                         Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
sayakpaul    5e9da5e998  up                                                     2026-04-03 07:49:08 +02:00
Sayak Paul   bd0af5096e  Merge branch 'main' into autoencoderkl-tests-refactor  2026-04-03 11:11:35 +05:30
sayakpaul    771174ac68  confirm coverage                                       2026-03-30 15:41:14 +05:30
Sayak Paul   d37402866c  Merge branch 'main' into autoencoderkl-tests-refactor  2026-03-30 15:19:52 +05:30
sayakpaul    d10190b1b7  fix tests                                              2026-03-30 14:58:08 +05:30
sayakpaul    a1c3e6ccbb  refactor autoencoderkl tests                           2026-03-30 13:33:26 +05:30
3 changed files with 175 additions and 146 deletions

View File

@@ -13,24 +13,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest
+import pytest
 import torch
 from diffusers import AutoencoderDC
+from diffusers.utils.torch_utils import randn_tensor
-from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin
-from .testing_utils import AutoencoderTesterMixin
+from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
+from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
+from .testing_utils import NewAutoencoderTesterMixin
 enable_full_determinism()
-class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
-    model_class = AutoencoderDC
-    main_input_name = "sample"
-    base_precision = 1e-2
+class AutoencoderDCTesterConfig(BaseModelTesterConfig):
+    @property
+    def model_class(self):
+        return AutoencoderDC
-    def get_autoencoder_dc_config(self):
+    @property
+    def output_shape(self):
+        return (3, 32, 32)
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+    def get_init_dict(self):
         return {
             "in_channels": 3,
             "latent_channels": 4,
@@ -56,33 +66,29 @@ class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
             "scaling_factor": 0.41407,
         }
-    @property
-    def dummy_input(self):
+    def get_dummy_inputs(self):
         batch_size = 4
         num_channels = 3
         sizes = (32, 32)
-        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
         return {"sample": image}
-    @property
-    def input_shape(self):
-        return (3, 32, 32)
-    @property
-    def output_shape(self):
-        return (3, 32, 32)
+class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
+    base_precision = 1e-2
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = self.get_autoencoder_dc_config()
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
-    def test_layerwise_casting_inference(self):
-        super().test_layerwise_casting_inference()
+class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
+    """Training tests for AutoencoderDC."""
-    @unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
+class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for AutoencoderDC."""
+    @pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
     def test_layerwise_casting_memory(self):
         super().test_layerwise_casting_memory()
+class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
    """Slicing and tiling tests for AutoencoderDC."""

View File

@@ -14,18 +14,18 @@
 # limitations under the License.
 import gc
-import unittest
+import pytest
 import torch
 from parameterized import parameterized
 from diffusers import AutoencoderKL
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import randn_tensor
 from ...testing_utils import (
     backend_empty_cache,
     enable_full_determinism,
-    floats_tensor,
     load_hf_numpy,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
@@ -35,22 +35,30 @@ from ...testing_utils import (
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin
-from .testing_utils import AutoencoderTesterMixin
+from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
+from .testing_utils import NewAutoencoderTesterMixin
 enable_full_determinism()
-class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
-    model_class = AutoencoderKL
-    main_input_name = "sample"
-    base_precision = 1e-2
+class AutoencoderKLTesterConfig(BaseModelTesterConfig):
+    @property
+    def model_class(self):
+        return AutoencoderKL
-    def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
+    @property
+    def output_shape(self):
+        return (3, 32, 32)
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+    def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
         block_out_channels = block_out_channels or [2, 4]
         norm_num_groups = norm_num_groups or 2
-        init_dict = {
+        return {
             "block_out_channels": block_out_channels,
             "in_channels": 3,
             "out_channels": 3,
@@ -59,42 +67,27 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
             "latent_channels": 4,
             "norm_num_groups": norm_num_groups,
         }
-        return init_dict
-    @property
-    def dummy_input(self):
+    def get_dummy_inputs(self):
         batch_size = 4
         num_channels = 3
         sizes = (32, 32)
-        image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
         return {"sample": image}
-    @property
-    def input_shape(self):
-        return (3, 32, 32)
-    @property
-    def output_shape(self):
-        return (3, 32, 32)
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = self.get_autoencoder_kl_config()
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
+class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
     def test_from_pretrained_hub(self):
         model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
-        self.assertIsNotNone(model)
-        self.assertEqual(len(loading_info["missing_keys"]), 0)
+        assert model is not None
+        assert len(loading_info["missing_keys"]) == 0
         model.to(torch_device)
-        image = model(**self.dummy_input)
+        image = model(**self.get_dummy_inputs())
         assert image is not None, "Make sure output is not None"
@@ -168,17 +161,24 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
             ]
         )
-        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
+        assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
+class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for AutoencoderKL."""
+class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, NewAutoencoderTesterMixin):
+    """Slicing and tiling tests for AutoencoderKL."""
 @slow
-class AutoencoderKLIntegrationTests(unittest.TestCase):
+class AutoencoderKLIntegrationTests:
     def get_file_format(self, seed, shape):
         return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
-    def tearDown(self):
+    def teardown_method(self):
         # clean up the VRAM after each test
-        super().tearDown()
         gc.collect()
         backend_empty_cache(torch_device)
@@ -341,10 +341,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
     @parameterized.expand([(13,), (16,), (27,)])
     @require_torch_gpu
-    @unittest.skipIf(
-        not is_xformers_available(),
-        reason="xformers is not required when using PyTorch 2.0.",
-    )
+    @pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
     def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
         model = self.get_sd_vae_model(fp16=True)
         encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -362,10 +359,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
     @parameterized.expand([(13,), (16,), (37,)])
     @require_torch_gpu
-    @unittest.skipIf(
-        not is_xformers_available(),
-        reason="xformers is not required when using PyTorch 2.0.",
-    )
+    @pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
     def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
         model = self.get_sd_vae_model()
         encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
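
A detail shared by both autoencoder files: floats_tensor is replaced with randn_tensor driven by a generator property that re-seeds itself on every access, so repeated calls to get_dummy_inputs return identical tensors. A minimal standalone sketch of that behaviour, with arbitrary shapes chosen only for illustration:

# Standalone illustration (not part of the PR): a freshly seeded generator per call
# makes randn_tensor deterministic, which is what the new get_dummy_inputs relies on.
import torch

from diffusers.utils.torch_utils import randn_tensor


def make_sample():
    generator = torch.Generator("cpu").manual_seed(0)
    return randn_tensor((4, 3, 32, 32), generator=generator)


assert torch.equal(make_sample(), make_sample())  # identical tensors on every call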

View File

@@ -12,60 +12,46 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import unittest
 import torch
 from diffusers import CosmosTransformer3DModel
+from diffusers.utils.torch_utils import randn_tensor
 from ...testing_utils import enable_full_determinism, torch_device
-from ..test_modeling_common import ModelTesterMixin
+from ..testing_utils import (
+    BaseModelTesterConfig,
+    MemoryTesterMixin,
+    ModelTesterMixin,
+    TrainingTesterMixin,
+)
 enable_full_determinism()
-class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
-    model_class = CosmosTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
+class CosmosTransformerTesterConfig(BaseModelTesterConfig):
+    @property
+    def model_class(self):
+        return CosmosTransformer3DModel
     @property
-    def dummy_input(self):
-        batch_size = 1
-        num_channels = 4
-        num_frames = 1
-        height = 16
-        width = 16
-        text_embed_dim = 16
-        sequence_length = 12
-        fps = 30
+    def output_shape(self) -> tuple[int, ...]:
+        return (4, 1, 16, 16)
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
-        attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
+    @property
+    def input_shape(self) -> tuple[int, ...]:
+        return (4, 1, 16, 16)
+    @property
+    def main_input_name(self) -> str:
+        return "hidden_states"
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+    def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
         return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "attention_mask": attention_mask,
-            "fps": fps,
-            "padding_mask": padding_mask,
-        }
-    @property
-    def input_shape(self):
-        return (4, 1, 16, 16)
-    @property
-    def output_shape(self):
-        return (4, 1, 16, 16)
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
             "in_channels": 4,
             "out_channels": 4,
             "num_attention_heads": 2,
@@ -80,57 +66,68 @@ class CosmosTransformer3DModelTests(ModelTesterMixin, unittest.TestCase):
             "concat_padding_mask": True,
             "extra_pos_embed_type": "learnable",
         }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
-    def test_gradient_checkpointing_is_applied(self):
-        expected_set = {"CosmosTransformer3DModel"}
-        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
-    model_class = CosmosTransformer3DModel
-    main_input_name = "hidden_states"
-    uses_custom_attn_processor = True
-    @property
-    def dummy_input(self):
-        batch_size = 1
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
         num_channels = 4
         num_frames = 1
         height = 16
         width = 16
         text_embed_dim = 16
         sequence_length = 12
-        fps = 30
-        hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
-        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
-        encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
-        attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
-        condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
-        padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
         return {
-            "hidden_states": hidden_states,
-            "timestep": timestep,
-            "encoder_hidden_states": encoder_hidden_states,
-            "attention_mask": attention_mask,
-            "fps": fps,
-            "condition_mask": condition_mask,
-            "padding_mask": padding_mask,
+            "hidden_states": randn_tensor(
+                (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
+            ),
+            "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
+            "encoder_hidden_states": randn_tensor(
+                (batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
+            ),
+            "attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
+            "fps": 30,
+            "padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
         }
+class TestCosmosTransformer(CosmosTransformerTesterConfig, ModelTesterMixin):
+    """Core model tests for Cosmos Transformer."""
+class TestCosmosTransformerMemory(CosmosTransformerTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for Cosmos Transformer."""
+class TestCosmosTransformerTraining(CosmosTransformerTesterConfig, TrainingTesterMixin):
+    """Training tests for Cosmos Transformer."""
+    def test_gradient_checkpointing_is_applied(self):
+        expected_set = {"CosmosTransformer3DModel"}
+        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
+class CosmosTransformerVideoToWorldTesterConfig(BaseModelTesterConfig):
     @property
-    def input_shape(self):
+    def model_class(self):
+        return CosmosTransformer3DModel
+    @property
+    def output_shape(self) -> tuple[int, ...]:
         return (4, 1, 16, 16)
     @property
-    def output_shape(self):
+    def input_shape(self) -> tuple[int, ...]:
         return (4, 1, 16, 16)
-    def prepare_init_args_and_inputs_for_common(self):
-        init_dict = {
+    @property
+    def main_input_name(self) -> str:
+        return "hidden_states"
+    @property
+    def generator(self):
+        return torch.Generator("cpu").manual_seed(0)
+    def get_init_dict(self) -> dict[str, int | list | tuple | float | bool | str]:
+        return {
             "in_channels": 4 + 1,
             "out_channels": 4,
             "num_attention_heads": 2,
@@ -145,8 +142,40 @@ class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestC
             "concat_padding_mask": True,
             "extra_pos_embed_type": "learnable",
         }
-        inputs_dict = self.dummy_input
-        return init_dict, inputs_dict
+    def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
+        num_channels = 4
+        num_frames = 1
+        height = 16
+        width = 16
+        text_embed_dim = 16
+        sequence_length = 12
+        return {
+            "hidden_states": randn_tensor(
+                (batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
+            ),
+            "timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
+            "encoder_hidden_states": randn_tensor(
+                (batch_size, sequence_length, text_embed_dim), generator=self.generator, device=torch_device
+            ),
+            "attention_mask": torch.ones((batch_size, sequence_length)).to(torch_device),
+            "fps": 30,
+            "condition_mask": torch.ones(batch_size, 1, num_frames, height, width).to(torch_device),
+            "padding_mask": torch.zeros(batch_size, 1, height, width).to(torch_device),
+        }
+class TestCosmosTransformerVideoToWorld(CosmosTransformerVideoToWorldTesterConfig, ModelTesterMixin):
+    """Core model tests for Cosmos Transformer (Video-to-World)."""
+class TestCosmosTransformerVideoToWorldMemory(CosmosTransformerVideoToWorldTesterConfig, MemoryTesterMixin):
+    """Memory optimization tests for Cosmos Transformer (Video-to-World)."""
+class TestCosmosTransformerVideoToWorldTraining(CosmosTransformerVideoToWorldTesterConfig, TrainingTesterMixin):
+    """Training tests for Cosmos Transformer (Video-to-World)."""
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"CosmosTransformer3DModel"}