mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-06 15:52:07 +08:00
Compare commits
4 Commits
autoencode
...
autoencode
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7db9646053 | ||
|
|
760918e8a5 | ||
|
|
119daffdf1 | ||
|
|
c346ad5eb8 |
@@ -178,13 +178,12 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["image_processor"] = [
|
||||
"IPAdapterMaskProcessor",
|
||||
"InpaintProcessor",
|
||||
"IPAdapterMaskProcessor",
|
||||
"PixArtImageProcessor",
|
||||
"VaeImageProcessor",
|
||||
"VaeImageProcessorLDM3D",
|
||||
]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
@@ -396,6 +395,7 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["training_utils"] = ["EMAModel"]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_scipy_available()):
|
||||
|
||||
@@ -13,24 +13,34 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers import AutoencoderDC
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
from ...testing_utils import IS_GITHUB_ACTIONS, enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderDC
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
class AutoencoderDCTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderDC
|
||||
|
||||
def get_autoencoder_dc_config(self):
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"latent_channels": 4,
|
||||
@@ -56,33 +66,29 @@ class AutoencoderDCTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.Test
|
||||
"scaling_factor": 0.41407,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
def get_dummy_inputs(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
class TestAutoencoderDC(AutoencoderDCTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_dc_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_inference(self):
|
||||
super().test_layerwise_casting_inference()
|
||||
class TestAutoencoderDCTraining(AutoencoderDCTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderDC."""
|
||||
|
||||
@unittest.skipIf(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
|
||||
class TestAutoencoderDCMemory(AutoencoderDCTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderDC."""
|
||||
|
||||
@pytest.mark.skipif(IS_GITHUB_ACTIONS, reason="Skipping test inside GitHub Actions environment")
|
||||
def test_layerwise_casting_memory(self):
|
||||
super().test_layerwise_casting_memory()
|
||||
|
||||
|
||||
class TestAutoencoderDCSlicingTiling(AutoencoderDCTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderDC."""
|
||||
|
||||
Reference in New Issue
Block a user