Compare commits

...

2 Commits

Author SHA1 Message Date
sayakpaul
119daffdf1 fix 2026-03-30 16:05:22 +05:30
sayakpaul
c346ad5eb8 refactor autoencoderdc tests 2026-03-30 15:30:18 +05:30

View File

@@ -13,24 +13,29 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
from diffusers import AutoencoderDC
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderDC
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderDCTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderDC
def get_autoencoder_dc_config(self):
@property
def output_shape(self):
return (3, 32, 32)
def get_init_dict(self):
return {
"in_channels": 3,
"latent_channels": 4,
@@ -56,33 +61,34 @@ class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
"scaling_factor": 0.41407,
}
@property
def dummy_input(self):
def get_dummy_inputs(self, seed=0):
torch.manual_seed(seed)
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
image = torch.randn(batch_size, num_channels, *sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
# Bridge for AutoencoderTesterMixin which still uses the old interface
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_dc_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
return self.get_init_dict(), self.get_dummy_inputs()
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_inference(self):
super().test_layerwise_casting_inference()
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
base_precision = 1e-2
class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderDC."""
class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderDC."""
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
def test_layerwise_casting_memory(self):
super().test_layerwise_casting_memory()
class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, AutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderDC."""