mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-28 07:19:22 +08:00
Compare commits
1 Commits
attention-
...
remove-unn
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10dfa9b722 |
@@ -27,7 +27,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -410,7 +410,7 @@ class HunyuanImageDecoder2D(nn.Module):
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model for 2D images with spatial tiling support.
|
||||
|
||||
@@ -486,27 +486,6 @@ class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin)
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor):
|
||||
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -584,7 +584,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanImage-2.1 Refiner.
|
||||
@@ -685,27 +685,6 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -625,7 +625,7 @@ class HunyuanVideo15Decoder3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanVideo-1.5.
|
||||
@@ -723,27 +723,6 @@ class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
||||
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
|
||||
@@ -11,7 +11,8 @@ export RUN_ATTENTION_BACKEND_TESTS=yes
|
||||
pytest tests/others/test_attention_backends.py
|
||||
```
|
||||
|
||||
Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
|
||||
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
|
||||
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
|
||||
|
||||
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
|
||||
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
|
||||
@@ -23,8 +24,6 @@ import os
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from ..testing_utils import numpy_cosine_similarity_distance
|
||||
|
||||
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
|
||||
@@ -37,28 +36,23 @@ from diffusers.utils import is_torch_version # noqa: E402
|
||||
FORWARD_CASES = [
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16),
|
||||
1e-4
|
||||
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
||||
1e-4
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16),
|
||||
1e-4
|
||||
),
|
||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
|
||||
5e-4
|
||||
),
|
||||
(
|
||||
"aiter",
|
||||
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
|
||||
1e-4
|
||||
)
|
||||
]
|
||||
|
||||
@@ -66,32 +60,27 @@ COMPILE_CASES = [
|
||||
(
|
||||
"flash_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True,
|
||||
1e-4
|
||||
True
|
||||
),
|
||||
(
|
||||
"_flash_3_hub",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||
True,
|
||||
1e-4
|
||||
),
|
||||
(
|
||||
"native",
|
||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
|
||||
True,
|
||||
1e-4
|
||||
),
|
||||
(
|
||||
"_native_cudnn",
|
||||
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
|
||||
True,
|
||||
5e-4,
|
||||
),
|
||||
(
|
||||
"aiter",
|
||||
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
|
||||
True,
|
||||
1e-4
|
||||
)
|
||||
]
|
||||
# fmt: on
|
||||
@@ -115,11 +104,11 @@ def _backend_is_probably_supported(pipe, name: str):
|
||||
return False
|
||||
|
||||
|
||||
def _check_if_slices_match(output, expected_slice, expected_diff=1e-4):
|
||||
def _check_if_slices_match(output, expected_slice):
|
||||
img = output.images.detach().cpu()
|
||||
generated_slice = img.flatten()
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff
|
||||
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
@@ -137,23 +126,23 @@ def pipe(device):
|
||||
return pipe
|
||||
|
||||
|
||||
@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
|
||||
def test_forward(pipe, backend_name, expected_slice, expected_diff):
|
||||
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
|
||||
def test_forward(pipe, backend_name, expected_slice):
|
||||
out = _backend_is_probably_supported(pipe, backend_name)
|
||||
if isinstance(out, bool):
|
||||
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
|
||||
|
||||
modified_pipe = out[0]
|
||||
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
||||
_check_if_slices_match(out, expected_slice, expected_diff)
|
||||
_check_if_slices_match(out, expected_slice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"backend_name,expected_slice,error_on_recompile,expected_diff",
|
||||
"backend_name,expected_slice,error_on_recompile",
|
||||
COMPILE_CASES,
|
||||
ids=[c[0] for c in COMPILE_CASES],
|
||||
)
|
||||
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff):
|
||||
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
|
||||
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
|
||||
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
|
||||
|
||||
@@ -171,4 +160,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
|
||||
):
|
||||
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
||||
|
||||
_check_if_slices_match(out, expected_slice, expected_diff)
|
||||
_check_if_slices_match(out, expected_slice)
|
||||
|
||||
@@ -131,10 +131,6 @@ def torch_all_close(a, b, *args, **kwargs):
|
||||
|
||||
|
||||
def numpy_cosine_similarity_distance(a, b):
|
||||
if isinstance(a, torch.Tensor):
|
||||
a = a.detach().cpu().float().numpy()
|
||||
if isinstance(b, torch.Tensor):
|
||||
b = b.detach().cpu().float().numpy()
|
||||
similarity = np.dot(a, b) / (norm(a) * norm(b))
|
||||
distance = 1.0 - similarity.mean()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user