Compare commits

...

1 Commit

Author      SHA1         Message                           Date
sayakpaul   cc5eaf0429   enhance attention backend tests   2025-12-12 13:25:34 +05:30
2 changed files with 29 additions and 14 deletions

View File: tests/others/test_attention_backends.py

@@ -11,8 +11,7 @@ export RUN_ATTENTION_BACKEND_TESTS=yes
 pytest tests/others/test_attention_backends.py
 ```
-Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
-"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
+Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
 Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
 with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
@@ -24,6 +23,8 @@ import os
 import pytest
 import torch
+from ..testing_utils import numpy_cosine_similarity_distance

 pytestmark = pytest.mark.skipif(
     os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
@@ -36,23 +37,28 @@ from diffusers.utils import is_torch_version # noqa: E402
 FORWARD_CASES = [
     (
         "flash_hub",
-        torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
+        torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16),
+        1e-4
     ),
     (
         "_flash_3_hub",
         torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
+        1e-4
     ),
     (
         "native",
-        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
+        torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16),
+        1e-4
     ),
     (
         "_native_cudnn",
         torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
+        5e-4
     ),
     (
         "aiter",
         torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
+        1e-4
     )
 ]
@@ -60,27 +66,32 @@ COMPILE_CASES = [
     (
         "flash_hub",
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
-        True
+        True,
+        1e-4
     ),
     (
         "_flash_3_hub",
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
         True,
+        1e-4
     ),
     (
         "native",
         torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
         True,
+        1e-4
     ),
     (
         "_native_cudnn",
         torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
         True,
+        5e-4,
     ),
     (
         "aiter",
         torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
         True,
+        1e-4
     )
 ]
 # fmt: on
@@ -104,11 +115,11 @@ def _backend_is_probably_supported(pipe, name: str):
         return False


-def _check_if_slices_match(output, expected_slice):
+def _check_if_slices_match(output, expected_slice, expected_diff=1e-4):
     img = output.images.detach().cpu()
     generated_slice = img.flatten()
     generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
-    assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
+    assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff


 @pytest.fixture(scope="session")
@@ -126,23 +137,23 @@ def pipe(device):
     return pipe


-@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
-def test_forward(pipe, backend_name, expected_slice):
+@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
+def test_forward(pipe, backend_name, expected_slice, expected_diff):
     out = _backend_is_probably_supported(pipe, backend_name)
     if isinstance(out, bool):
         pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

     modified_pipe = out[0]
     out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
-    _check_if_slices_match(out, expected_slice)
+    _check_if_slices_match(out, expected_slice, expected_diff)


 @pytest.mark.parametrize(
-    "backend_name,expected_slice,error_on_recompile",
+    "backend_name,expected_slice,error_on_recompile,expected_diff",
     COMPILE_CASES,
     ids=[c[0] for c in COMPILE_CASES],
 )
-def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
+def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff):
     if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
         pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
@@ -160,4 +171,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
     ):
         out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
-        _check_if_slices_match(out, expected_slice)
+        _check_if_slices_match(out, expected_slice, expected_diff)
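The net effect of the test-file changes: each case tuple now carries a per-backend tolerance (1e-4 for most backends, 5e-4 for the cuDNN variant), and the slice check compares outputs by cosine-similarity distance instead of element-wise `torch.allclose` with a fixed `atol`. Below is a minimal, self-contained sketch of why that matters; the slices and the jitter are made up for illustration, and only the distance formula mirrors the diff.

```python
import numpy as np
import torch
from numpy.linalg import norm


def cosine_distance(a: torch.Tensor, b: torch.Tensor) -> float:
    # Same formula as numpy_cosine_similarity_distance, on float32 numpy arrays.
    a, b = a.float().numpy(), b.float().numpy()
    return 1.0 - float(np.dot(a, b) / (norm(a) * norm(b)))


expected = torch.tensor([0.0820, 0.0859, 0.0918, 0.1016])
# Hypothetical backend output: two elements jitter by 5e-4.
generated = expected + torch.tensor([0.0, 5e-4, -5e-4, 0.0])

print(torch.allclose(generated, expected, atol=1e-4))  # False: per-element deviation exceeds the old atol
print(cosine_distance(generated, expected) < 5e-4)     # True: overall agreement is within the cuDNN tolerance
```

A distance-based check like this tolerates small per-element jitter across backends while still flagging outputs whose overall pattern drifts.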

View File: tests/testing_utils.py

@@ -131,6 +131,10 @@ def torch_all_close(a, b, *args, **kwargs):
 def numpy_cosine_similarity_distance(a, b):
+    if isinstance(a, torch.Tensor):
+        a = a.detach().cpu().float().numpy()
+    if isinstance(b, torch.Tensor):
+        b = b.detach().cpu().float().numpy()
     similarity = np.dot(a, b) / (norm(a) * norm(b))
     distance = 1.0 - similarity.mean()
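For reference, the helper after this hunk would read roughly as below. This is a sketch, not the verbatim file: it assumes `norm` is `numpy.linalg.norm` and that the function ends by returning `distance`, neither of which is shown in the hunk.

```python
import numpy as np
import torch
from numpy.linalg import norm  # assumed import, consistent with the norm(...) calls above


def numpy_cosine_similarity_distance(a, b):
    # New in this commit: accept torch tensors (e.g. bfloat16 image slices) by
    # converting them to float32 numpy arrays before computing the dot products.
    if isinstance(a, torch.Tensor):
        a = a.detach().cpu().float().numpy()
    if isinstance(b, torch.Tensor):
        b = b.detach().cpu().float().numpy()
    similarity = np.dot(a, b) / (norm(a) * norm(b))
    distance = 1.0 - similarity.mean()
    return distance  # assumed; the diff does not show the function's final line
```

With these guards, the attention-backend tests can pass bfloat16 slices straight from the pipeline output without converting them first.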