Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-08 21:44:27 +08:00)

Compare commits: 13 commits, progress-b ... vae-tests-
| SHA1 |
|---|
| 137bf5af89 |
| 7ba8d9a238 |
| 74160ed00f |
| 1ac55e7a7e |
| 7b8817ec04 |
| 9a6ecfbcc4 |
| 6a01c4681c |
| 3a106f05ee |
| 378090705f |
| 769b7452ed |
| 01aa188d8d |
| 490c4761b4 |
| cfe1e2e3fa |
@@ -35,13 +35,14 @@ from ...testing_utils import (
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AsymmetricAutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AsymmetricAutoencoderKL
     main_input_name = "sample"
     base_precision = 1e-2
@@ -17,13 +17,14 @@ import unittest
 from diffusers import AutoencoderKLCosmos

 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLCosmos
     main_input_name = "sample"
     base_precision = 1e-2
@@ -80,7 +81,3 @@ class AutoencoderKLCosmosTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
     @unittest.skip("Not sure why this test fails. Investigate later.")
     def test_effective_gradient_checkpointing(self):
         pass
-
-    @unittest.skip("Unsupported test.")
-    def test_forward_with_norm_groups(self):
-        pass
@@ -22,13 +22,14 @@ from ...testing_utils import (
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderDC
     main_input_name = "sample"
     base_precision = 1e-2
@@ -81,7 +82,3 @@ class AutoencoderDCTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         init_dict = self.get_autoencoder_dc_config()
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
-
-    @unittest.skip("AutoencoderDC does not support `norm_num_groups` because it does not use GroupNorm.")
-    def test_forward_with_norm_groups(self):
-        pass
@@ -20,18 +20,15 @@ import torch
 from diffusers import AutoencoderKLHunyuanVideo
 from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask

-from ...testing_utils import (
-    enable_full_determinism,
-    floats_tensor,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLHunyuanVideo
     main_input_name = "sample"
     base_precision = 1e-2
@@ -87,68 +84,6 @@ class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, UNetTesterMixin, unittest
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {
             "HunyuanVideoDecoder3D",
@@ -35,13 +35,14 @@ from ...testing_utils import (
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKL
     main_input_name = "sample"
     base_precision = 1e-2
@@ -83,68 +84,6 @@ class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
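For context on what these now-centralized tests exercise: tiling and slicing are memory-saving inference switches exposed on the VAEs themselves. Below is a minimal, illustrative sketch against the public AutoencoderKL API; the tiny random-weight configuration is an assumption for the example, not something taken from this diff.

```python
import torch
from diffusers import AutoencoderKL

# Small random-weight VAE purely for illustration; in practice you would load
# pretrained weights with AutoencoderKL.from_pretrained(...).
vae = AutoencoderKL(block_out_channels=(32,), norm_num_groups=8, latent_channels=4)
vae.eval()

x = torch.randn(2, 3, 64, 64)

with torch.no_grad():
    baseline = vae(x, return_dict=False)[0]

    vae.enable_tiling()   # encode/decode in overlapping tiles to cap peak memory
    vae.enable_slicing()  # process the batch one sample at a time
    memory_friendly = vae(x, return_dict=False)[0]

    vae.disable_tiling()
    vae.disable_slicing()

# Tiling trades a little numerical exactness for memory, which is why the tests
# only assert that outputs stay close (max diff < 0.5) rather than identical.
print((baseline - memory_friendly).abs().max())
```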
@@ -24,13 +24,14 @@ from ...testing_utils import (
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLCogVideoXTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLCogVideoX
     main_input_name = "sample"
     base_precision = 1e-2
@@ -82,68 +83,6 @@ class AutoencoderKLCogVideoXTests(ModelTesterMixin, UNetTesterMixin, unittest.Te
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict

-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {
             "CogVideoXDownBlock3D",
@@ -22,13 +22,14 @@ from ...testing_utils import (
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLTemporalDecoder
     main_input_name = "sample"
     base_precision = 1e-2
@@ -67,7 +68,3 @@ class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, UNetTesterMixin, unitt
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @unittest.skip("Test unsupported.")
-    def test_forward_with_norm_groups(self):
-        pass
@@ -24,13 +24,14 @@ from ...testing_utils import (
     floats_tensor,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLLTXVideo
     main_input_name = "sample"
     base_precision = 1e-2
@@ -99,7 +100,7 @@ class AutoencoderKLLTXVideo090Tests(ModelTesterMixin, UNetTesterMixin, unittest.
         pass


-class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLLTXVideo
     main_input_name = "sample"
     base_precision = 1e-2
@@ -167,34 +168,3 @@ class AutoencoderKLLTXVideo091Tests(ModelTesterMixin, UNetTesterMixin, unittest.
     @unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
     def test_forward_with_norm_groups(self):
         pass
-
-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
@@ -18,13 +18,14 @@ import unittest
 from diffusers import AutoencoderKLMagvit

 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLMagvitTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLMagvit
     main_input_name = "sample"
     base_precision = 1e-2
@@ -88,3 +89,9 @@ class AutoencoderKLMagvitTests(ModelTesterMixin, UNetTesterMixin, unittest.TestC
     @unittest.skip("Unsupported test.")
     def test_forward_with_norm_groups(self):
         pass
+
+    @unittest.skip(
+        "Unsupported test. Error: RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 9 but got size 12 for tensor number 1 in the list."
+    )
+    def test_enable_disable_slicing(self):
+        pass
@@ -17,18 +17,15 @@ import unittest

 from diffusers import AutoencoderKLMochi

-from ...testing_utils import (
-    enable_full_determinism,
-    floats_tensor,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLMochiTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLMochi
     main_input_name = "sample"
     base_precision = 1e-2
@@ -79,14 +76,6 @@ class AutoencoderKLMochiTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
         }
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

-    @unittest.skip("Unsupported test.")
-    def test_forward_with_norm_groups(self):
-        """
-        tests/models/autoencoders/test_models_autoencoder_mochi.py::AutoencoderKLMochiTests::test_forward_with_norm_groups -
-        TypeError: AutoencoderKLMochi.__init__() got an unexpected keyword argument 'norm_num_groups'
-        """
-        pass
-
     @unittest.skip("Unsupported test.")
     def test_model_parallelism(self):
         """
@@ -30,13 +30,14 @@ from ...testing_utils import (
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderOobleckTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderOobleck
     main_input_name = "sample"
     base_precision = 1e-2
@@ -106,10 +107,6 @@ class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCa
             "Without slicing outputs should match with the outputs when slicing is manually disabled.",
         )

-    @unittest.skip("Test unsupported.")
-    def test_forward_with_norm_groups(self):
-        pass
-
     @unittest.skip("No attention module used in this model")
     def test_set_attn_processor_for_determinism(self):
         return
@@ -31,13 +31,14 @@ from ...testing_utils import (
     torch_all_close,
     torch_device,
 )
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderTinyTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderTiny
     main_input_name = "sample"
     base_precision = 1e-2
@@ -81,37 +82,6 @@ class AutoencoderTinyTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
     def test_enable_disable_tiling(self):
         pass

-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict)[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict)[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict)[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     @unittest.skip("Test not supported.")
     def test_outputs_equivalence(self):
         pass
@@ -15,18 +15,17 @@

 import unittest

 import torch

 from diffusers import AutoencoderKLWan

 from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = AutoencoderKLWan
     main_input_name = "sample"
     base_precision = 1e-2
@@ -76,68 +75,6 @@ class AutoencoderKLWanTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase
         inputs_dict = self.dummy_input_tiling
         return init_dict, inputs_dict

-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_tiling()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling(96, 96, 64, 64)
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.05,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-
     @unittest.skip("Gradient checkpointing has not been implemented yet")
     def test_gradient_checkpointing_is_applied(self):
         pass
@@ -31,12 +31,13 @@ from ...testing_utils import (
     torch_device,
 )
 from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
+class ConsistencyDecoderVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = ConsistencyDecoderVAE
     main_input_name = "sample"
     base_precision = 1e-2
@@ -92,70 +93,6 @@ class ConsistencyDecoderVAETests(ModelTesterMixin, unittest.TestCase):
     def prepare_init_args_and_inputs_for_common(self):
        return self.init_dict, self.inputs_dict()

-    def test_enable_disable_tiling(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-        _ = inputs_dict.pop("generator")
-
-        torch.manual_seed(0)
-        output_without_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_tiling()
-        output_with_tiling = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE tiling should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_tiling()
-        output_without_tiling_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_tiling.detach().cpu().numpy().all(),
-            output_without_tiling_2.detach().cpu().numpy().all(),
-            "Without tiling outputs should match with the outputs when tiling is manually disabled.",
-        )
-
-    def test_enable_disable_slicing(self):
-        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
-        torch.manual_seed(0)
-        model = self.model_class(**init_dict).to(torch_device)
-
-        inputs_dict.update({"return_dict": False})
-        _ = inputs_dict.pop("generator")
-
-        torch.manual_seed(0)
-        output_without_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        torch.manual_seed(0)
-        model.enable_slicing()
-        output_with_slicing = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertLess(
-            (output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()).max(),
-            0.5,
-            "VAE slicing should not affect the inference results",
-        )
-
-        torch.manual_seed(0)
-        model.disable_slicing()
-        output_without_slicing_2 = model(**inputs_dict, generator=torch.manual_seed(0))[0]
-
-        self.assertEqual(
-            output_without_slicing.detach().cpu().numpy().all(),
-            output_without_slicing_2.detach().cpu().numpy().all(),
-            "Without slicing outputs should match with the outputs when slicing is manually disabled.",
-        )
-

 @slow
 class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
@@ -19,19 +19,15 @@ import torch

 from diffusers import VQModel

-from ...testing_utils import (
-    backend_manual_seed,
-    enable_full_determinism,
-    floats_tensor,
-    torch_device,
-)
-from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin
+from ...testing_utils import backend_manual_seed, enable_full_determinism, floats_tensor, torch_device
+from ..test_modeling_common import ModelTesterMixin
+from .testing_utils import AutoencoderTesterMixin


 enable_full_determinism()


-class VQModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
+class VQModelTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
     model_class = VQModel
     main_input_name = "sample"
tests/models/autoencoders/testing_utils.py (new file, 142 lines)
@@ -0,0 +1,142 @@
import inspect

import numpy as np
import pytest
import torch

from diffusers.models.autoencoders.vae import DecoderOutput
from diffusers.utils.torch_utils import torch_device


class AutoencoderTesterMixin:
    """
    Test mixin class specific to VAEs to test for slicing and tiling. Diffusion networks
    usually don't do slicing and tiling.
    """

    @staticmethod
    def _accepts_generator(model):
        model_sig = inspect.signature(model.forward)
        accepts_generator = "generator" in model_sig.parameters
        return accepts_generator

    @staticmethod
    def _accepts_norm_num_groups(model_class):
        model_sig = inspect.signature(model_class.__init__)
        accepts_norm_groups = "norm_num_groups" in model_sig.parameters
        return accepts_norm_groups

    def test_forward_with_norm_groups(self):
        if not self._accepts_norm_num_groups(self.model_class):
            pytest.skip(f"Test not supported for {self.model_class.__name__}")
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["norm_num_groups"] = 16
        init_dict["block_out_channels"] = (16, 32)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.to_tuple()[0]

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_enable_disable_tiling(self):
        if not hasattr(self.model_class, "enable_tiling"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict.update({"return_dict": False})
        _ = inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        torch.manual_seed(0)
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_without_tiling = model(**inputs_dict)[0]
        # Mochi-1
        if isinstance(output_without_tiling, DecoderOutput):
            output_without_tiling = output_without_tiling.sample

        torch.manual_seed(0)
        model.enable_tiling()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_with_tiling = model(**inputs_dict)[0]
        if isinstance(output_with_tiling, DecoderOutput):
            output_with_tiling = output_with_tiling.sample

        assert (
            output_without_tiling.detach().cpu().numpy() - output_with_tiling.detach().cpu().numpy()
        ).max() < 0.5, "VAE tiling should not affect the inference results"

        torch.manual_seed(0)
        model.disable_tiling()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_without_tiling_2 = model(**inputs_dict)[0]
        if isinstance(output_without_tiling_2, DecoderOutput):
            output_without_tiling_2 = output_without_tiling_2.sample

        assert np.allclose(
            output_without_tiling.detach().cpu().numpy().all(),
            output_without_tiling_2.detach().cpu().numpy().all(),
        ), "Without tiling outputs should match with the outputs when tiling is manually disabled."

    def test_enable_disable_slicing(self):
        if not hasattr(self.model_class, "enable_slicing"):
            pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")

        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        torch.manual_seed(0)
        model = self.model_class(**init_dict).to(torch_device)

        inputs_dict.update({"return_dict": False})
        _ = inputs_dict.pop("generator", None)
        accepts_generator = self._accepts_generator(model)

        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)

        torch.manual_seed(0)
        output_without_slicing = model(**inputs_dict)[0]
        # Mochi-1
        if isinstance(output_without_slicing, DecoderOutput):
            output_without_slicing = output_without_slicing.sample

        torch.manual_seed(0)
        model.enable_slicing()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_with_slicing = model(**inputs_dict)[0]
        if isinstance(output_with_slicing, DecoderOutput):
            output_with_slicing = output_with_slicing.sample

        assert (
            output_without_slicing.detach().cpu().numpy() - output_with_slicing.detach().cpu().numpy()
        ).max() < 0.5, "VAE slicing should not affect the inference results"

        torch.manual_seed(0)
        model.disable_slicing()
        if accepts_generator:
            inputs_dict["generator"] = torch.manual_seed(0)
        output_without_slicing_2 = model(**inputs_dict)[0]
        if isinstance(output_without_slicing_2, DecoderOutput):
            output_without_slicing_2 = output_without_slicing_2.sample

        assert np.allclose(
            output_without_slicing.detach().cpu().numpy().all(),
            output_without_slicing_2.detach().cpu().numpy().all(),
        ), "Without slicing outputs should match with the outputs when slicing is manually disabled."
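For orientation, here is a minimal sketch of how a concrete test class picks up the new mixin; `AutoencoderKLExampleTests`, its tiny config, and the dummy input are hypothetical placeholders for this example (the real ModelTesterMixin expects a few more properties than shown here):

```python
import unittest

import torch

from diffusers import AutoencoderKL

from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin


class AutoencoderKLExampleTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
    model_class = AutoencoderKL
    main_input_name = "sample"
    base_precision = 1e-2

    @property
    def dummy_input(self):
        # Tiny random sample so the shared tiling/slicing/norm-group tests run quickly.
        return {"sample": torch.randn(2, 3, 32, 32)}

    def prepare_init_args_and_inputs_for_common(self):
        # Deliberately small config; the mixin mutates norm_num_groups/block_out_channels
        # in test_forward_with_norm_groups when the model constructor accepts them.
        init_dict = {
            "block_out_channels": (32,),
            "norm_num_groups": 8,
            "latent_channels": 4,
        }
        return init_dict, self.dummy_input
```

Models without tiling or slicing support, or without a `norm_num_groups` argument, are skipped automatically by the mixin's capability checks instead of needing per-file `@unittest.skip` overrides.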
@@ -450,7 +450,15 @@ class ModelUtilsTest(unittest.TestCase):


 class UNetTesterMixin:
+    @staticmethod
+    def _accepts_norm_num_groups(model_class):
+        model_sig = inspect.signature(model_class.__init__)
+        accepts_norm_groups = "norm_num_groups" in model_sig.parameters
+        return accepts_norm_groups
+
     def test_forward_with_norm_groups(self):
+        if not self._accepts_norm_num_groups(self.model_class):
+            pytest.skip(f"Test not supported for {self.model_class.__name__}")
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

         init_dict["norm_num_groups"] = 16
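The capability check added above is plain constructor-signature introspection. A self-contained sketch of the same idea follows; the toy classes are stand-ins invented for illustration, not diffusers models:

```python
import inspect


class ToyGroupNormModel:
    def __init__(self, block_out_channels=(32,), norm_num_groups=8):
        self.norm_num_groups = norm_num_groups


class ToyRMSNormModel:
    def __init__(self, block_out_channels=(32,)):
        pass  # no norm_num_groups argument at all


def accepts_norm_num_groups(model_class):
    # Same pattern as UNetTesterMixin._accepts_norm_num_groups: inspect the
    # __init__ signature instead of hard-coding per-model skip lists.
    return "norm_num_groups" in inspect.signature(model_class.__init__).parameters


print(accepts_norm_num_groups(ToyGroupNormModel))  # True  -> the test runs
print(accepts_norm_num_groups(ToyRMSNormModel))    # False -> pytest.skip(...)
```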