mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-03 22:31:46 +08:00
Compare commits
15 Commits
autoencode
...
fix-torcha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb7402e2ee | ||
|
|
baddc2846c | ||
|
|
f60afe5cba | ||
|
|
06509796dd | ||
|
|
59c1b2534a | ||
|
|
7eaeb99fcd | ||
|
|
867192364c | ||
|
|
a8cef0740a | ||
|
|
70067734a2 | ||
|
|
6125a4f540 | ||
|
|
d2666a9d0a | ||
|
|
9b9e2e17a6 | ||
|
|
1a959dc26f | ||
|
|
8797398d3b | ||
|
|
019a9deafb |
@@ -178,12 +178,13 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["image_processor"] = [
|
||||
"InpaintProcessor",
|
||||
"IPAdapterMaskProcessor",
|
||||
"InpaintProcessor",
|
||||
"PixArtImageProcessor",
|
||||
"VaeImageProcessor",
|
||||
"VaeImageProcessorLDM3D",
|
||||
]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
_import_structure["models"].extend(
|
||||
[
|
||||
"AllegroTransformer3DModel",
|
||||
@@ -395,7 +396,6 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["training_utils"] = ["EMAModel"]
|
||||
_import_structure["video_processor"] = ["VideoProcessor"]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_scipy_available()):
|
||||
|
||||
@@ -22,7 +22,7 @@ from typing import Set
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger, is_accelerate_available
|
||||
from ..utils import get_logger, is_accelerate_available, is_torchao_available
|
||||
from ._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
@@ -35,6 +35,54 @@ if is_accelerate_available():
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _is_torchao_tensor(tensor: torch.Tensor) -> bool:
|
||||
if not is_torchao_available():
|
||||
return False
|
||||
from torchao.utils import TorchAOBaseTensor
|
||||
|
||||
return isinstance(tensor, TorchAOBaseTensor)
|
||||
|
||||
|
||||
def _get_torchao_inner_tensor_names(tensor: torch.Tensor) -> list[str]:
|
||||
"""Get names of all internal tensor data attributes from a TorchAO tensor."""
|
||||
cls = type(tensor)
|
||||
names = list(getattr(cls, "tensor_data_names", []))
|
||||
for attr_name in getattr(cls, "optional_tensor_data_names", []):
|
||||
if getattr(tensor, attr_name, None) is not None:
|
||||
names.append(attr_name)
|
||||
return names
|
||||
|
||||
|
||||
def _swap_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Move a TorchAO parameter to the device of `source` via `swap_tensors`.
|
||||
|
||||
`param.data = source` does not work for `_make_wrapper_subclass` tensors because the `.data` setter only replaces
|
||||
the outer wrapper storage while leaving the subclass's internal attributes (e.g. `.qdata`, `.scale`) on the
|
||||
original device. `swap_tensors` swaps the full tensor contents in-place, preserving the parameter's identity so
|
||||
that any dict keyed by `id(param)` remains valid.
|
||||
|
||||
Refer to https://github.com/huggingface/diffusers/pull/13276#discussion_r2944471548 for the full discussion.
|
||||
"""
|
||||
torch.utils.swap_tensors(param, source)
|
||||
|
||||
|
||||
def _restore_torchao_tensor(param: torch.Tensor, source: torch.Tensor) -> None:
|
||||
"""Restore internal tensor data of a TorchAO parameter from `source` without mutating `source`.
|
||||
|
||||
Unlike `_swap_torchao_tensor` this copies attribute references one-by-one via `setattr` so that `source` is **not**
|
||||
modified. Use this when `source` is a cached tensor that must remain unchanged (e.g. a pinned CPU copy in
|
||||
`cpu_param_dict`).
|
||||
"""
|
||||
for attr_name in _get_torchao_inner_tensor_names(source):
|
||||
setattr(param, attr_name, getattr(source, attr_name))
|
||||
|
||||
|
||||
def _record_stream_torchao_tensor(param: torch.Tensor, stream) -> None:
|
||||
"""Record stream for all internal tensors of a TorchAO parameter."""
|
||||
for attr_name in _get_torchao_inner_tensor_names(param):
|
||||
getattr(param, attr_name).record_stream(stream)
|
||||
|
||||
|
||||
# fmt: off
|
||||
_GROUP_OFFLOADING = "group_offloading"
|
||||
_LAYER_EXECUTION_TRACKER = "layer_execution_tracker"
|
||||
@@ -124,6 +172,13 @@ class ModuleGroup:
|
||||
else torch.cuda
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _to_cpu(tensor, low_cpu_mem_usage):
|
||||
# For TorchAO tensors, `.data` returns an incomplete wrapper without internal attributes
|
||||
# (e.g. `.qdata`, `.scale`), so we must call `.cpu()` on the tensor directly.
|
||||
t = tensor.cpu() if _is_torchao_tensor(tensor) else tensor.data.cpu()
|
||||
return t if low_cpu_mem_usage else t.pin_memory()
|
||||
|
||||
def _init_cpu_param_dict(self):
|
||||
cpu_param_dict = {}
|
||||
if self.stream is None:
|
||||
@@ -131,17 +186,15 @@ class ModuleGroup:
|
||||
|
||||
for module in self.modules:
|
||||
for param in module.parameters():
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
|
||||
for buffer in module.buffers():
|
||||
cpu_param_dict[buffer] = (
|
||||
buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
)
|
||||
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
|
||||
|
||||
for param in self.parameters:
|
||||
cpu_param_dict[param] = param.data.cpu() if self.low_cpu_mem_usage else param.data.cpu().pin_memory()
|
||||
cpu_param_dict[param] = self._to_cpu(param, self.low_cpu_mem_usage)
|
||||
|
||||
for buffer in self.buffers:
|
||||
cpu_param_dict[buffer] = buffer.data.cpu() if self.low_cpu_mem_usage else buffer.data.cpu().pin_memory()
|
||||
cpu_param_dict[buffer] = self._to_cpu(buffer, self.low_cpu_mem_usage)
|
||||
|
||||
return cpu_param_dict
|
||||
|
||||
@@ -157,9 +210,16 @@ class ModuleGroup:
|
||||
pinned_dict = None
|
||||
|
||||
def _transfer_tensor_to_device(self, tensor, source_tensor, default_stream):
|
||||
tensor.data = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
moved = source_tensor.to(self.onload_device, non_blocking=self.non_blocking)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_swap_torchao_tensor(tensor, moved)
|
||||
else:
|
||||
tensor.data = moved
|
||||
if self.record_stream:
|
||||
tensor.data.record_stream(default_stream)
|
||||
if _is_torchao_tensor(tensor):
|
||||
_record_stream_torchao_tensor(tensor, default_stream)
|
||||
else:
|
||||
tensor.data.record_stream(default_stream)
|
||||
|
||||
def _process_tensors_from_modules(self, pinned_memory=None, default_stream=None):
|
||||
for group_module in self.modules:
|
||||
@@ -178,7 +238,19 @@ class ModuleGroup:
|
||||
source = pinned_memory[buffer] if pinned_memory else buffer.data
|
||||
self._transfer_tensor_to_device(buffer, source, default_stream)
|
||||
|
||||
def _check_disk_offload_torchao(self):
|
||||
all_tensors = list(self.tensor_to_key.keys())
|
||||
has_torchao = any(_is_torchao_tensor(t) for t in all_tensors)
|
||||
if has_torchao:
|
||||
raise ValueError(
|
||||
"Disk offloading is not supported for TorchAO quantized tensors because safetensors "
|
||||
"cannot serialize TorchAO subclass tensors. Use memory offloading instead by not "
|
||||
"setting `offload_to_disk_path`."
|
||||
)
|
||||
|
||||
def _onload_from_disk(self):
|
||||
self._check_disk_offload_torchao()
|
||||
|
||||
if self.stream is not None:
|
||||
# Wait for previous Host->Device transfer to complete
|
||||
self.stream.synchronize()
|
||||
@@ -221,6 +293,8 @@ class ModuleGroup:
|
||||
self._process_tensors_from_modules(None)
|
||||
|
||||
def _offload_to_disk(self):
|
||||
self._check_disk_offload_torchao()
|
||||
|
||||
# TODO: we can potentially optimize this code path by checking if the _all_ the desired
|
||||
# safetensor files exist on the disk and if so, skip this step entirely, reducing IO
|
||||
# overhead. Currently, we just check if the given `safetensors_file_path` exists and if not
|
||||
@@ -245,18 +319,35 @@ class ModuleGroup:
|
||||
|
||||
for group_module in self.modules:
|
||||
for param in group_module.parameters():
|
||||
param.data = self.cpu_param_dict[param]
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for param in self.parameters:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
if _is_torchao_tensor(param):
|
||||
_restore_torchao_tensor(param, self.cpu_param_dict[param])
|
||||
else:
|
||||
param.data = self.cpu_param_dict[param]
|
||||
for buffer in self.buffers:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
if _is_torchao_tensor(buffer):
|
||||
_restore_torchao_tensor(buffer, self.cpu_param_dict[buffer])
|
||||
else:
|
||||
buffer.data = self.cpu_param_dict[buffer]
|
||||
else:
|
||||
for group_module in self.modules:
|
||||
group_module.to(self.offload_device, non_blocking=False)
|
||||
for param in self.parameters:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
if _is_torchao_tensor(param):
|
||||
moved = param.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(param, moved)
|
||||
else:
|
||||
param.data = param.data.to(self.offload_device, non_blocking=False)
|
||||
for buffer in self.buffers:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
if _is_torchao_tensor(buffer):
|
||||
moved = buffer.to(self.offload_device, non_blocking=False)
|
||||
_swap_torchao_tensor(buffer, moved)
|
||||
else:
|
||||
buffer.data = buffer.data.to(self.offload_device, non_blocking=False)
|
||||
|
||||
@torch.compiler.disable()
|
||||
def onload_(self):
|
||||
|
||||
@@ -423,9 +423,7 @@ def dispatch_attention_fn(
|
||||
**attention_kwargs,
|
||||
"_parallel_config": parallel_config,
|
||||
}
|
||||
# Equivalent to `is_torch_version(">=", "2.5.0")` — use module-level constant to avoid
|
||||
# Dynamo tracing into the lru_cache-wrapped `is_torch_version` during torch.compile.
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
if is_torch_version(">=", "2.5.0"):
|
||||
kwargs["enable_gqa"] = enable_gqa
|
||||
|
||||
if _AttentionBackendRegistry._checks_enabled:
|
||||
|
||||
@@ -347,17 +347,7 @@ def lru_cache_unless_export(maxsize=128, typed=False):
|
||||
|
||||
@functools.wraps(fn)
|
||||
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
|
||||
compiler = getattr(torch, "compiler", None)
|
||||
is_exporting = bool(compiler and hasattr(compiler, "is_exporting") and compiler.is_exporting())
|
||||
is_compiling = bool(compiler and hasattr(compiler, "is_compiling") and compiler.is_compiling())
|
||||
|
||||
# Fallback for older builds where compiler.is_compiling is unavailable.
|
||||
if not is_compiling:
|
||||
dynamo = getattr(torch, "_dynamo", None)
|
||||
if dynamo is not None and hasattr(dynamo, "is_compiling"):
|
||||
is_compiling = dynamo.is_compiling()
|
||||
|
||||
if is_exporting or is_compiling:
|
||||
if torch.compiler.is_exporting():
|
||||
return fn(*args, **kwargs)
|
||||
return cached(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -14,18 +14,18 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import AutoencoderKL
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import (
|
||||
backend_empty_cache,
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
require_torch_accelerator,
|
||||
require_torch_accelerator_with_fp16,
|
||||
@@ -35,30 +35,22 @@ from ...testing_utils import (
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKL
|
||||
class AutoencoderKLTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKL
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self, block_out_channels=None, norm_num_groups=None):
|
||||
def get_autoencoder_kl_config(self, block_out_channels=None, norm_num_groups=None):
|
||||
block_out_channels = block_out_channels or [2, 4]
|
||||
norm_num_groups = norm_num_groups or 2
|
||||
return {
|
||||
init_dict = {
|
||||
"block_out_channels": block_out_channels,
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -67,27 +59,42 @@ class AutoencoderKLTesterConfig(BaseModelTesterConfig):
|
||||
"latent_channels": 4,
|
||||
"norm_num_groups": norm_num_groups,
|
||||
}
|
||||
return init_dict
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Decoder", "Encoder", "UNetMidBlock2D"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
assert model is not None
|
||||
assert len(loading_info["missing_keys"]) == 0
|
||||
self.assertIsNotNone(model)
|
||||
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||
|
||||
model.to(torch_device)
|
||||
image = model(**self.get_dummy_inputs())
|
||||
image = model(**self.dummy_input)
|
||||
|
||||
assert image is not None, "Make sure output is not None"
|
||||
|
||||
@@ -161,24 +168,17 @@ class TestAutoencoderKL(AutoencoderKLTesterConfig, ModelTesterMixin, TrainingTes
|
||||
]
|
||||
)
|
||||
|
||||
assert torch_all_close(output_slice, expected_output_slice, rtol=1e-2)
|
||||
|
||||
|
||||
class TestAutoencoderKLMemory(AutoencoderKLTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKL."""
|
||||
|
||||
|
||||
class TestAutoencoderKLSlicingTiling(AutoencoderKLTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKL."""
|
||||
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
|
||||
|
||||
|
||||
@slow
|
||||
class AutoencoderKLIntegrationTests:
|
||||
class AutoencoderKLIntegrationTests(unittest.TestCase):
|
||||
def get_file_format(self, seed, shape):
|
||||
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
|
||||
|
||||
def teardown_method(self):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
backend_empty_cache(torch_device)
|
||||
|
||||
@@ -341,7 +341,10 @@ class AutoencoderKLIntegrationTests:
|
||||
|
||||
@parameterized.expand([(13,), (16,), (27,)])
|
||||
@require_torch_gpu
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0_fp16(self, seed):
|
||||
model = self.get_sd_vae_model(fp16=True)
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64), fp16=True)
|
||||
@@ -359,7 +362,10 @@ class AutoencoderKLIntegrationTests:
|
||||
|
||||
@parameterized.expand([(13,), (16,), (37,)])
|
||||
@require_torch_gpu
|
||||
@pytest.mark.skipif(not is_xformers_available(), reason="xformers is not required when using PyTorch 2.0.")
|
||||
@unittest.skipIf(
|
||||
not is_xformers_available(),
|
||||
reason="xformers is not required when using PyTorch 2.0.",
|
||||
)
|
||||
def test_stable_diffusion_decode_xformers_vs_2_0(self, seed):
|
||||
model = self.get_sd_vae_model()
|
||||
encoding = self.get_sd_image(seed, shape=(3, 4, 64, 64))
|
||||
|
||||
@@ -13,34 +13,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLWan
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
|
||||
from .testing_utils import NewAutoencoderTesterMixin
|
||||
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return AutoencoderKLWan
|
||||
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLWan
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self):
|
||||
def get_autoencoder_kl_wan_config(self):
|
||||
return {
|
||||
"base_dim": 3,
|
||||
"z_dim": 16,
|
||||
@@ -49,40 +39,54 @@ class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
image = randn_tensor(
|
||||
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
|
||||
)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
@property
|
||||
def dummy_input_tiling(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (128, 128)
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
return {"sample": image}
|
||||
|
||||
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
|
||||
base_precision = 1e-2
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for AutoencoderKLWan."""
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
|
||||
def prepare_init_args_and_inputs_for_tiling(self):
|
||||
init_dict = self.get_autoencoder_kl_wan_config()
|
||||
inputs_dict = self.dummy_input_tiling
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skip("Gradient checkpointing has not been implemented yet")
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for AutoencoderKLWan."""
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_memory(self):
|
||||
@unittest.skip("Test not supported")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
|
||||
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_inference(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
|
||||
def test_layerwise_casting_training(self):
|
||||
pass
|
||||
|
||||
|
||||
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, NewAutoencoderTesterMixin):
|
||||
"""Slicing and tiling tests for AutoencoderKLWan."""
|
||||
|
||||
@@ -145,138 +145,3 @@ class AutoencoderTesterMixin:
|
||||
output_without_slicing.detach().cpu().numpy().all(),
|
||||
output_without_slicing_2.detach().cpu().numpy().all(),
|
||||
), "Without slicing outputs should match with the outputs when slicing is manually disabled."
|
||||
|
||||
|
||||
class NewAutoencoderTesterMixin:
|
||||
@staticmethod
|
||||
def _accepts_generator(model):
|
||||
model_sig = inspect.signature(model.forward)
|
||||
accepts_generator = "generator" in model_sig.parameters
|
||||
return accepts_generator
|
||||
|
||||
@staticmethod
|
||||
def _accepts_norm_num_groups(model_class):
|
||||
model_sig = inspect.signature(model_class.__init__)
|
||||
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
|
||||
return accepts_norm_groups
|
||||
|
||||
def test_forward_with_norm_groups(self):
|
||||
if not self._accepts_norm_num_groups(self.model_class):
|
||||
pytest.skip(f"Test not supported for {self.model_class.__name__}")
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict["norm_num_groups"] = 16
|
||||
init_dict["block_out_channels"] = (16, 32)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.to_tuple()[0]
|
||||
|
||||
assert output is not None
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
assert output.shape == expected_shape, "Input and output shapes do not match"
|
||||
|
||||
def test_enable_disable_tiling(self):
|
||||
if not hasattr(self.model_class, "enable_tiling"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
if not hasattr(model, "use_tiling"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
with torch.no_grad():
|
||||
torch.manual_seed(0)
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_tiling, DecoderOutput):
|
||||
output_without_tiling = output_without_tiling.sample
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_tiling()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_tiling = model(**inputs_dict)[0]
|
||||
if isinstance(output_with_tiling, DecoderOutput):
|
||||
output_with_tiling = output_with_tiling.sample
|
||||
|
||||
assert (output_without_tiling.cpu() - output_with_tiling.cpu()).max() < 0.5, (
|
||||
"VAE tiling should not affect the inference results"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_tiling()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_tiling_2 = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_tiling_2, DecoderOutput):
|
||||
output_without_tiling_2 = output_without_tiling_2.sample
|
||||
|
||||
assert torch.allclose(output_without_tiling.cpu(), output_without_tiling_2.cpu()), (
|
||||
"Without tiling outputs should match with the outputs when tiling is manually disabled."
|
||||
)
|
||||
|
||||
def test_enable_disable_slicing(self):
|
||||
if not hasattr(self.model_class, "enable_slicing"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
|
||||
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
if not hasattr(model, "use_slicing"):
|
||||
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
|
||||
|
||||
inputs_dict.update({"return_dict": False})
|
||||
_ = inputs_dict.pop("generator", None)
|
||||
accepts_generator = self._accepts_generator(model)
|
||||
|
||||
with torch.no_grad():
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
|
||||
torch.manual_seed(0)
|
||||
output_without_slicing = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_slicing, DecoderOutput):
|
||||
output_without_slicing = output_without_slicing.sample
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.enable_slicing()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_with_slicing = model(**inputs_dict)[0]
|
||||
if isinstance(output_with_slicing, DecoderOutput):
|
||||
output_with_slicing = output_with_slicing.sample
|
||||
|
||||
assert (output_without_slicing.cpu() - output_with_slicing.cpu()).max() < 0.5, (
|
||||
"VAE slicing should not affect the inference results"
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model.disable_slicing()
|
||||
if accepts_generator:
|
||||
inputs_dict["generator"] = torch.manual_seed(0)
|
||||
output_without_slicing_2 = model(**inputs_dict)[0]
|
||||
if isinstance(output_without_slicing_2, DecoderOutput):
|
||||
output_without_slicing_2 = output_without_slicing_2.sample
|
||||
|
||||
assert torch.allclose(output_without_slicing.cpu(), output_without_slicing_2.cpu()), (
|
||||
"Without slicing outputs should match with the outputs when slicing is manually disabled."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user