Compare commits

...

13 Commits

Author SHA1 Message Date
sayakpaul
137bf5af89 up 2025-10-17 11:11:01 +05:30
Sayak Paul
7ba8d9a238 Merge branch 'main' into vae-tests-mixin 2025-10-17 10:44:09 +05:30
sayakpaul
74160ed00f up 2025-10-17 10:43:52 +05:30
Sayak Paul
1ac55e7a7e Merge branch 'main' into vae-tests-mixin 2025-10-17 08:00:37 +05:30
sayakpaul
7b8817ec04 up 2025-09-24 10:44:15 +05:30
Sayak Paul
9a6ecfbcc4 Merge branch 'main' into vae-tests-mixin 2025-09-24 09:43:58 +05:30
sayakpaul
6a01c4681c u[ 2025-09-24 09:30:34 +05:30
sayakpaul
3a106f05ee up 2025-09-24 09:20:42 +05:30
Sayak Paul
378090705f Merge branch 'main' into vae-tests-mixin 2025-09-24 09:17:19 +05:30
sayakpaul
769b7452ed up 2025-09-24 09:15:50 +05:30
sayakpaul
01aa188d8d up 2025-09-23 14:05:51 +05:30
sayakpaul
490c4761b4 up 2025-09-23 13:45:27 +05:30
sayakpaul
cfe1e2e3fa up 2025-09-23 13:32:05 +05:30
17 changed files with 204 additions and 446 deletions

View File

@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AsymmetricAutoencoderKL
main_input_name = "sample"
base_precision = 1e-2

View File

@@ -17,13 +17,14 @@ import unittest
from diffusers import AutoencoderKLCosmos
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
@@ -80,7 +81,3 @@ class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Not sure why this test fails. Investigate later.")
def test_effective_gradient_checkpointing(self):
pass
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -22,13 +22,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
@@ -81,7 +82,3 @@ class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -20,18 +20,15 @@ import torch
from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -87,68 +84,6 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"HunyuanVideoDecoder3D",

View File

@@ -35,13 +35,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
@@ -83,68 +84,6 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

View File

@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCogVideoX
main_input_name = "sample"
base_precision = 1e-2
@@ -82,68 +83,6 @@ class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.Te
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CogVideoXDownBlock3D",

View File

@@ -22,13 +22,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2
@@ -67,7 +68,3 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -24,13 +24,14 @@ from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -99,7 +100,7 @@ class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.
pass
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTXVideo
main_input_name = "sample"
base_precision = 1e-2
@@ -167,34 +168,3 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)

View File

@@ -18,13 +18,14 @@ import unittest
from diffusers import AutoencoderKLMagvit
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMagvit
main_input_name = "sample"
base_precision = 1e-2
@@ -88,3 +89,9 @@ class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
pass
@unittest.skip(
"Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
)
def test_enable_disable_slicing(self):
pass

View File

@@ -17,18 +17,15 @@ import unittest
from diffusers import AutoencoderKLMochi
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLMochi
main_input_name = "sample"
base_precision = 1e-2
@@ -79,14 +76,6 @@ class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_forward_with_norm_groups(self):
"""
tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
"""
pass
@unittest.skip("Unsupported test.")
def test_model_parallelism(self):
"""

View File

@@ -30,13 +30,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderOobleck
main_input_name = "sample"
base_precision = 1e-2
@@ -106,10 +107,6 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Test unsupported.")
def test_forward_with_norm_groups(self):
pass
@unittest.skip("No attention module used in this model")
def test_set_attn_processor_for_determinism(self):
return

View File

@@ -31,13 +31,14 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderTiny
main_input_name = "sample"
base_precision = 1e-2
@@ -81,37 +82,6 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
def test_enable_disable_tiling(self):
pass
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict)[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict)[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict)[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Test not supported.")
def test_outputs_equivalence(self):
pass

View File

@@ -15,18 +15,17 @@
import unittest
import torch
from diffusers import AutoencoderKLWan
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan
main_input_name = "sample"
base_precision = 1e-2
@@ -76,68 +75,6 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase
inputs_dict = self.dummy_input_tiling
return init_dict, inputs_dict
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling(96, 96, 64, 64)
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.05,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass

View File

@@ -31,12 +31,13 @@ from ...testing_utils import (
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = ConsistencyDecoderVAE
main_input_name = "sample"
base_precision = 1e-2
@@ -92,70 +93,6 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
def prepare_init_args_and_inputs_for_common(self):
return self.init_dict, self.inputs_dict()
def test_enable_disable_tiling(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator")
torch.manual_seed(0)
output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_tiling()
output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
0.5,
"VAE tiling should not affect the inference results",
)
torch.manual_seed(0)
model.disable_tiling()
output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
"Without tiling outputs should match with the outputs when tiling is manually disabled.",
)
def test_enable_disable_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator")
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
torch.manual_seed(0)
model.enable_slicing()
output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertLess(
(output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
0.5,
"VAE slicing should not affect the inference results",
)
torch.manual_seed(0)
model.disable_slicing()
output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
self.assertEqual(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
"Without slicing outputs should match with the outputs when slicing is manually disabled.",
)
@slow
class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):

View File

@@ -19,19 +19,15 @@ import torch
from diffusers import VQModel
from ...testing_utils import (
backend_manual_seed,
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = VQModel
main_input_name = "sample"

View File

@@ -0,0 +1,142 @@
import inspect
import numpy as np
import pytest
import torch
from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.utils.torch_utils import torch_device
class AutoencoderTesterMixin:
"""
Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
usually don't do slicing and tiling.
"""
@staticmethod
def _accepts_generator(model):
model_sig = inspect.signature(model.forward)
accepts_generator = "generator" in model_sig.parameters
return accepts_generator
@staticmethod
def _accepts_norm_num_groups(model_class):
model_sig = inspect.signature(model_class.__init__)
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
return accepts_norm_groups
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
torch.manual_seed(0)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling = model(**inputs_dict)[0]
# Mochi-1
if isinstance(output_without_tiling, DecoderOutput):
output_without_tiling = output_without_tiling.sample
torch.manual_seed(0)
model.enable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_tiling = model(**inputs_dict)[0]
if isinstance(output_with_tiling, DecoderOutput):
output_with_tiling = output_with_tiling.sample
assert (
output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
).max() < 0.5, "VAE tiling should not affect the inference results"
torch.manual_seed(0)
model.disable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling_2 = model(**inputs_dict)[0]
if isinstance(output_without_tiling_2, DecoderOutput):
output_without_tiling_2 = output_without_tiling_2.sample
assert np.allclose(
output_without_tiling.detach().cpu().numpy().all(),
output_without_tiling_2.detach().cpu().numpy().all(),
), "Without tiling outputs should match with the outputs when tiling is manually disabled."
def test_enable_disable_slicing(self):
if not hasattr(self.model_class, "enable_slicing"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict)[0]
# Mochi-1
if isinstance(output_without_slicing, DecoderOutput):
output_without_slicing = output_without_slicing.sample
torch.manual_seed(0)
model.enable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_slicing = model(**inputs_dict)[0]
if isinstance(output_with_slicing, DecoderOutput):
output_with_slicing = output_with_slicing.sample
assert (
output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
).max() < 0.5, "VAE slicing should not affect the inference results"
torch.manual_seed(0)
model.disable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_slicing_2 = model(**inputs_dict)[0]
if isinstance(output_without_slicing_2, DecoderOutput):
output_without_slicing_2 = output_without_slicing_2.sample
assert np.allclose(
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
), "Without slicing outputs should match with the outputs when slicing is manually disabled."

View File

@@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase):
class UNetTesterMixin:
@staticmethod
def _accepts_norm_num_groups(model_class):
model_sig = inspect.signature(model_class.__init__)
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
return accepts_norm_groups
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["norm_num_groups"] = 16