Compare commits

...

4 Commits

Author SHA1 Message Date
sayakpaul
771174ac68 confirm coverage 2026-03-30 15:41:14 +05:30
Sayak Paul
d37402866c Merge branch 'main' into autoencoderkl-tests-refactor 2026-03-30 15:19:52 +05:30
sayakpaul
d10190b1b7 fix tests 2026-03-30 14:58:08 +05:30
sayakpaul
a1c3e6ccbb refactor autoencoderkl tests 2026-03-30 13:33:26 +05:30

View File

@@ -14,8 +14,8 @@
# limitations under the License.
import gc
import unittest
import pytest
import torch
from parameterized import parameterized
@@ -25,7 +25,6 @@ from diffusers.utils.import_utils import is_xformers_available
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
@@ -35,22 +34,26 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKL
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
@property
def output_shape(self):
return (3, 32, 32)
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
block_out_channels = block_out_channels or [2, 4]
norm_num_groups = norm_num_groups or 2
init_dict = {
return {
"block_out_channels": block_out_channels,
"in_channels": 3,
"out_channels": 3,
@@ -59,42 +62,32 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
"latent_channels": 4,
"norm_num_groups": norm_num_groups,
}
return init_dict
@property
def dummy_input(self):
def get_dummy_inputs(self, seed=0):
torch.manual_seed(seed)
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
image = torch.randn(batch_size, num_channels, *sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
# Bridge for AutoencoderTesterMixin which still uses the old interface
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
return self.get_init_dict(), self.get_dummy_inputs()
class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
model.to(torch_device)
image = model(**self.dummy_input)
image = model(**self.get_dummy_inputs())
assert image is not None, "Make sure output is not None"
@@ -168,17 +161,24 @@ class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
]
)
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKL."""
class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, AutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKL."""
@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
class AutoencoderKLIntegrationTests:
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def tearDown(self):
def teardown_method(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@@ -341,10 +341,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -362,10 +359,7 @@ class AutoencoderKLIntegrationTests(unittest.TestCase):
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))