Compare commits

..

15 Commits

Author | SHA1 | Message | Date
Sayak Paul | cb7402e2ee | Merge branch 'main' into fix-torchao-groupoffloading | 2026-04-03 10:18:42 +05:30
Sayak Paul | baddc2846c | Merge branch 'main' into fix-torchao-groupoffloading | 2026-04-01 08:07:44 +05:30
sayakpaul | f60afe5cba | error out for the offload to disk option. | 2026-03-30 13:19:12 +05:30
Sayak Paul | 06509796dd | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-30 11:48:22 +05:30
Sayak Paul | 59c1b2534a | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-30 11:24:51 +05:30
sayakpaul | 7eaeb99fcd | address feedback. | 2026-03-30 11:24:40 +05:30
Sayak Paul | 867192364c | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-30 10:53:47 +05:30
Sayak Paul | a8cef0740a | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-27 21:16:15 +05:30
Sayak Paul | 70067734a2 | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-26 11:29:51 +05:30
Sayak Paul | 6125a4f540 | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-25 08:07:01 +05:30
Sayak Paul | d2666a9d0a | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-24 09:06:42 +05:30
sayakpaul | 9b9e2e17a6 | up | 2026-03-23 11:22:36 +05:30
sayakpaul | 1a959dc26f | switch to swap_tensors. | 2026-03-23 10:56:16 +05:30
Sayak Paul | 8797398d3b | Merge branch 'main' into fix-torchao-groupoffloading | 2026-03-23 09:05:37 +05:30
sayakpaul | 019a9deafb | fix group offloading when using torchao | 2026-03-17 10:40:03 +05:30
7 changed files with 195 additions and 241 deletions

View File

@@ -178,12 +178,13 @@ else:
]
)
_import_structure["image_processor"] = [
"InpaintProcessor",
"IPAdapterMaskProcessor",
"InpaintProcessor",
"PixArtImageProcessor",
"VaeImageProcessor",
"VaeImageProcessorLDM3D",
]
_import_structure["video_processor"] = ["VideoProcessor"]
_import_structure["models"].extend(
[
"AllegroTransformer3DModel",
@@ -395,7 +396,6 @@ else:
]
)
_import_structure["training_utils"] = ["EMAModel"]
_import_structure["video_processor"] = ["VideoProcessor"]
try:
if not (is_torch_available() and is_scipy_available()):

View File

@@ -22,7 +22,7 @@ from typing import Set
import safetensors.torch
import torch
from ..utils import get_logger, is_accelerate_available
from ..utils import get_logger, is_accelerate_available, is_torchao_available
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
from .hooks import HookRegistry, ModelHook
@@ -35,6 +35,54 @@ if is_accelerate_available():
logger = get_logger(__name__) # pylint: disable=invalid-name
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
    if not is_torchao_available():
        return False

    from torchao.utils import TorchAOBaseTensor

    return isinstance(tensor, TorchAOBaseTensor)


def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
    """Get names of all internal tensor data attributes from a TorchAO tensor."""
    cls = type(tensor)
    names = list(getattr(cls, "tensor_data_names", []))
    for attr_name in getattr(cls, "optional_tensor_data_names", []):
        if getattr(tensor, attr_name, None) is not None:
            names.append(attr_name)
    return names


def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Move a TorchAO parameter to the device of `source` via `swap_tensors`.

    `param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
    the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
    original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
    that any dict keyed by `id(param)` remains valid.

    Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
    """
    torch.utils.swap_tensors(param, source)


def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
    """Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.

    Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
    modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
    `cpu_param_dict`).
    """
    for attr_name in _get_torchao_inner_tensor_names(source):
        setattr(param, attr_name, getattr(source, attr_name))


def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
    """Record stream for all internal tensors of a TorchAO parameter."""
    for attr_name in _get_torchao_inner_tensor_names(param):
        getattr(param, attr_name).record_stream(stream)
# fmt: off
_GROUP_OFFLOADING = "group_offloading"
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
@@ -124,6 +172,13 @@ class ModuleGroup:
else torch.cuda
)
    @staticmethod
    def _to_cpu(tensor, low_cpu_mem_usage):
        # For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
        # (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
        t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
        return t if low_cpu_mem_usage else t.pin_memory()

    def _init_cpu_param_dict(self):
        cpu_param_dict = {}
        if self.stream is None:
@@ -131,17 +186,15 @@ class ModuleGroup:
        for module in self.modules:
            for param in module.parameters():
                cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
                cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
            for buffer in module.buffers():
                cpu_param_dict[buffer] = (
                    buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
                )
                cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)

        for param in self.parameters:
            cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
            cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)

        for buffer in self.buffers:
            cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
            cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)

        return cpu_param_dict
@@ -157,9 +210,16 @@ class ModuleGroup:
            pinned_dict = None

    def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
        tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
        moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
        if _is_torchao_tensor(tensor):
            _swap_torchao_tensor(tensor, moved)
        else:
            tensor.data = moved
        if self.record_stream:
            tensor.data.record_stream(default_stream)
            if _is_torchao_tensor(tensor):
                _record_stream_torchao_tensor(tensor, default_stream)
            else:
                tensor.data.record_stream(default_stream)
    def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
        for group_module in self.modules:
@@ -178,7 +238,19 @@ class ModuleGroup:
            source = pinned_memory[buffer] if pinned_memory else buffer.data
            self._transfer_tensor_to_device(buffer, source, default_stream)
    def _check_disk_offload_torchao(self):
        all_tensors = list(self.tensor_to_key.keys())
        has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
        if has_torchao:
            raise ValueError(
                "Disk offloading is not supported for TorchAO quantized tensors because safetensors "
                "cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
                "setting `offload_to_disk_path`."
            )

    def _onload_from_disk(self):
        self._check_disk_offload_torchao()
        if self.stream is not None:
            # Wait for previous Host->Device transfer to complete
            self.stream.synchronize()
@@ -221,6 +293,8 @@ class ModuleGroup:
        self._process_tensors_from_modules(None)

    def _offload_to_disk(self):
        self._check_disk_offload_torchao()

        # TODO: we can potentially optimize this code path by checking if _all_ the desired
        # safetensor files exist on the disk and if so, skip this step entirely, reducing IO
        # overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
@@ -245,18 +319,35 @@ class ModuleGroup:
            for group_module in self.modules:
                for param in group_module.parameters():
                    param.data = self.cpu_param_dict[param]
                    if _is_torchao_tensor(param):
                        _restore_torchao_tensor(param, self.cpu_param_dict[param])
                    else:
                        param.data = self.cpu_param_dict[param]
            for param in self.parameters:
                param.data = self.cpu_param_dict[param]
                if _is_torchao_tensor(param):
                    _restore_torchao_tensor(param, self.cpu_param_dict[param])
                else:
                    param.data = self.cpu_param_dict[param]
            for buffer in self.buffers:
                buffer.data = self.cpu_param_dict[buffer]
                if _is_torchao_tensor(buffer):
                    _restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
                else:
                    buffer.data = self.cpu_param_dict[buffer]
        else:
            for group_module in self.modules:
                group_module.to(self.offload_device, non_blocking=False)
            for param in self.parameters:
                param.data = param.data.to(self.offload_device, non_blocking=False)
                if _is_torchao_tensor(param):
                    moved = param.to(self.offload_device, non_blocking=False)
                    _swap_torchao_tensor(param, moved)
                else:
                    param.data = param.data.to(self.offload_device, non_blocking=False)
            for buffer in self.buffers:
                buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
                if _is_torchao_tensor(buffer):
                    moved = buffer.to(self.offload_device, non_blocking=False)
                    _swap_torchao_tensor(buffer, moved)
                else:
                    buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
    @torch.compiler.disable()
    def onload_(self):
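
The core of the hook changes above is replacing `.data` assignment with `torch.utils.swap_tensors` for TorchAO wrapper-subclass parameters. Below is a minimal, standalone sketch of why that matters for identity-keyed bookkeeping such as `cpu_param_dict`. It uses plain tensors as stand-ins for TorchAO quantized tensors, so it only demonstrates the identity and content behaviour of `swap_tensors`, not quantization itself.

import torch

# Plain tensors stand in for a quantized parameter and its freshly moved copy.
param = torch.zeros(4)
moved = torch.ones(4)

# Bookkeeping keyed by object identity, like the hook's cpu_param_dict.
cache = {id(param): "pinned CPU copy"}

# swap_tensors exchanges the full tensor contents in-place; both Python objects keep their ids.
torch.utils.swap_tensors(param, moved)

assert id(param) in cache                    # identity-based lookups still resolve
assert torch.equal(param, torch.ones(4))     # param now holds the moved values
assert torch.equal(moved, torch.zeros(4))    # the two tensors traded contents

For a `_make_wrapper_subclass` tensor, `param.data = moved` would also keep the identity, but it would only rebind the outer wrapper and leave inner attributes such as `.qdata` and `.scale` on the old device, which is exactly the failure mode described in the `_swap_torchao_tensor` docstring.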

View File

@@ -423,9 +423,7 @@ def dispatch_attention_fn(
        **attention_kwargs,
        "_parallel_config": parallel_config,
    }

    # Equivalent to `is_torch_version(">=", "2.5.0")` — use module-level constant to avoid
    # Dynamo tracing into the lru_cache-wrapped `is_torch_version` during torch.compile.
    if _CAN_USE_FLEX_ATTN:
    if is_torch_version(">=", "2.5.0"):
        kwargs["enable_gqa"] = enable_gqa

    if _AttentionBackendRegistry._checks_enabled:
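
The comment in this hunk contrasts a runtime `is_torch_version` call with a module-level constant evaluated once at import time. A minimal sketch of that pattern follows; the names `_TORCH_AT_LEAST_2_5` and `dispatch` are illustrative only and are not diffusers APIs.

import torch
from packaging import version

# Evaluated once at import time, so under torch.compile the branch condition is a plain
# Python constant and Dynamo never has to trace into a cached version-check helper.
_TORCH_AT_LEAST_2_5 = version.parse(torch.__version__).release >= (2, 5)


def dispatch(query, key, value, enable_gqa=False):
    kwargs = {}
    if _TORCH_AT_LEAST_2_5:  # stands in for is_torch_version(">=", "2.5.0")
        kwargs["enable_gqa"] = enable_gqa
    return torch.nn.functional.scaled_dot_product_attention(query, key, value, **kwargs)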

View File

@@ -347,17 +347,7 @@ def lru_cache_unless_export(maxsize=128, typed=False):
        @functools.wraps(fn)
        def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
            compiler = getattr(torch, "compiler", None)
            is_exporting = bool(compiler and hasattr(compiler, "is_exporting") and compiler.is_exporting())
            is_compiling = bool(compiler and hasattr(compiler, "is_compiling") and compiler.is_compiling())
            # Fallback for older builds where compiler.is_compiling is unavailable.
            if not is_compiling:
                dynamo = getattr(torch, "_dynamo", None)
                if dynamo is not None and hasattr(dynamo, "is_compiling"):
                    is_compiling = dynamo.is_compiling()
            if is_exporting or is_compiling:
            if torch.compiler.is_exporting():
                return fn(*args, **kwargs)
            return cached(*args, **kwargs)
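
Based only on the lines visible in this hunk, the simplified decorator boils down to: cache results with `functools.lru_cache`, but call the wrapped function directly while `torch.export` is tracing. A hedged reconstruction follows, assuming `torch.compiler.is_exporting()` is available (recent PyTorch builds; older ones would need a fallback like the removed branch above).

import functools

import torch


def lru_cache_unless_export(maxsize=128, typed=False):
    """Cache results normally, but bypass the cache while torch.export is tracing."""

    def decorator(fn):
        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)

        @functools.wraps(fn)
        def inner_wrapper(*args, **kwargs):
            if torch.compiler.is_exporting():
                return fn(*args, **kwargs)  # export tracing: skip the lru_cache wrapper
            return cached(*args, **kwargs)

        return inner_wrapper

    return decorator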

View File

@@ -14,18 +14,18 @@
# limitations under the License.
import gc
import unittest
import pytest
import torch
from parameterized import parameterized
from diffusers import AutoencoderKL
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
floats_tensor,
load_hf_numpy,
require_torch_accelerator,
require_torch_accelerator_with_fp16,
@@ -35,30 +35,22 @@ from ...testing_utils import (
torch_all_close,
torch_device,
)
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKL
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKL
main_input_name = "sample"
base_precision = 1e-2
@property
def output_shape(self):
return (3, 32, 32)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
block_out_channels = block_out_channels or [2, 4]
norm_num_groups = norm_num_groups or 2
return {
init_dict = {
"block_out_channels": block_out_channels,
"in_channels": 3,
"out_channels": 3,
@@ -67,27 +59,42 @@ class AutoencoderKLTesterConfig(BaseModelTesterConfig):
"latent_channels": 4,
"norm_num_groups": norm_num_groups,
}
return init_dict
def get_dummy_inputs(self):
@property
def dummy_input(self):
batch_size = 4
num_channels = 3
sizes = (32, 32)
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
def test_from_pretrained_hub(self):
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
assert model is not None
assert len(loading_info["missing_keys"]) == 0
self.assertIsNotNone(model)
self.assertEqual(len(loading_info["missing_keys"]), 0)
model.to(torch_device)
image = model(**self.get_dummy_inputs())
image = model(**self.dummy_input)
assert image is not None, "Make sure output is not None"
@@ -161,24 +168,17 @@ class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTes
]
)
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKL."""
class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKL."""
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
@slow
class AutoencoderKLIntegrationTests:
class AutoencoderKLIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def teardown_method(self):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
backend_empty_cache(torch_device)
@@ -341,7 +341,10 @@ class AutoencoderKLIntegrationTests:
@parameterized.expand([(13,), (16,), (27,)])
@require_torch_gpu
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
model = self.get_sd_vae_model(fp16=True)
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
@@ -359,7 +362,10 @@ class AutoencoderKLIntegrationTests:
@parameterized.expand([(13,), (16,), (37,)])
@require_torch_gpu
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
@unittest.skipIf(
not is_xformers_available(),
reason="xformers is not required when using PyTorch 2.0.",
)
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
model = self.get_sd_vae_model()
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))

View File

@@ -13,34 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
import unittest
from diffusers import AutoencoderKLWan
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLWan
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan
main_input_name = "sample"
base_precision = 1e-2
@property
def output_shape(self):
return (3, 9, 16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
def get_autoencoder_kl_wan_config(self):
return {
"base_dim": 3,
"z_dim": 16,
@@ -49,40 +39,54 @@ class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
"temperal_downsample": [False, True, True],
}
def get_dummy_inputs(self):
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
def dummy_input_tiling(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (128, 128)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
base_precision = 1e-2
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLWan."""
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_wan_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
def prepare_init_args_and_inputs_for_tiling(self):
init_dict = self.get_autoencoder_kl_wan_config()
inputs_dict = self.dummy_input_tiling
return init_dict, inputs_dict
@unittest.skip("Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLWan."""
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_memory(self):
@unittest.skip("Test not supported")
def test_forward_with_norm_groups(self):
pass
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_inference(self):
pass
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_training(self):
pass
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLWan."""

View File

@@ -145,138 +145,3 @@ class AutoencoderTesterMixin:
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
), "Without slicing outputs should match with the outputs when slicing is manually disabled."
class NewAutoencoderTesterMixin:
@staticmethod
def _accepts_generator(model):
model_sig = inspect.signature(model.forward)
accepts_generator = "generator" in model_sig.parameters
return accepts_generator
@staticmethod
def _accepts_norm_num_groups(model_class):
model_sig = inspect.signature(model_class.__init__)
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
return accepts_norm_groups
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
if not hasattr(model, "use_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
with torch.no_grad():
torch.manual_seed(0)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling = model(**inputs_dict)[0]
if isinstance(output_without_tiling, DecoderOutput):
output_without_tiling = output_without_tiling.sample
torch.manual_seed(0)
model.enable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_tiling = model(**inputs_dict)[0]
if isinstance(output_with_tiling, DecoderOutput):
output_with_tiling = output_with_tiling.sample
assert (output_without_tiling.cpu() - output_with_tiling.cpu()).max() < 0.5, (
"VAE tiling should not affect the inference results"
)
torch.manual_seed(0)
model.disable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling_2 = model(**inputs_dict)[0]
if isinstance(output_without_tiling_2, DecoderOutput):
output_without_tiling_2 = output_without_tiling_2.sample
assert torch.allclose(output_without_tiling.cpu(), output_without_tiling_2.cpu()), (
"Without tiling outputs should match with the outputs when tiling is manually disabled."
)
def test_enable_disable_slicing(self):
if not hasattr(self.model_class, "enable_slicing"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
if not hasattr(model, "use_slicing"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
with torch.no_grad():
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict)[0]
if isinstance(output_without_slicing, DecoderOutput):
output_without_slicing = output_without_slicing.sample
torch.manual_seed(0)
model.enable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_slicing = model(**inputs_dict)[0]
if isinstance(output_with_slicing, DecoderOutput):
output_with_slicing = output_with_slicing.sample
assert (output_without_slicing.cpu() - output_with_slicing.cpu()).max() < 0.5, (
"VAE slicing should not affect the inference results"
)
torch.manual_seed(0)
model.disable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_slicing_2 = model(**inputs_dict)[0]
if isinstance(output_without_slicing_2, DecoderOutput):
output_without_slicing_2 = output_without_slicing_2.sample
assert torch.allclose(output_without_slicing.cpu(), output_without_slicing_2.cpu()), (
"Without slicing outputs should match with the outputs when slicing is manually disabled."
)