mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-21 03:44:49 +08:00
Compare commits
1 Commits
qwenimage-
...
attention-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cc5eaf0429 |
@@ -11,8 +11,7 @@ export RUN_ATTENTION_BACKEND_TESTS=yes
|
|||||||
pytest tests/others/test_attention_backends.py
|
pytest tests/others/test_attention_backends.py
|
||||||
```
|
```
|
||||||
|
|
||||||
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
|
Tests were conducted on an H100 with PyTorch 2.9.1 (CUDA 12.9).
|
||||||
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
|
|
||||||
|
|
||||||
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
|
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
|
||||||
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
|
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
|
||||||
@@ -24,6 +23,8 @@ import os
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from ..testing_utils import numpy_cosine_similarity_distance
|
||||||
|
|
||||||
|
|
||||||
pytestmark = pytest.mark.skipif(
|
pytestmark = pytest.mark.skipif(
|
||||||
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
|
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
|
||||||
@@ -36,23 +37,28 @@ from diffusers.utils import is_torch_version # noqa: E402
|
|||||||
FORWARD_CASES = [
|
FORWARD_CASES = [
|
||||||
(
|
(
|
||||||
"flash_hub",
|
"flash_hub",
|
||||||
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
|
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16),
|
||||||
|
1e-4
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"_flash_3_hub",
|
"_flash_3_hub",
|
||||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
|
||||||
|
1e-4
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"native",
|
"native",
|
||||||
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
|
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16),
|
||||||
),
|
1e-4
|
||||||
|
),
|
||||||
(
|
(
|
||||||
"_native_cudnn",
|
"_native_cudnn",
|
||||||
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
|
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
|
||||||
|
5e-4
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"aiter",
|
"aiter",
|
||||||
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
|
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
|
||||||
|
1e-4
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -60,27 +66,32 @@ COMPILE_CASES = [
|
|||||||
(
|
(
|
||||||
"flash_hub",
|
"flash_hub",
|
||||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||||
True
|
True,
|
||||||
|
1e-4
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"_flash_3_hub",
|
"_flash_3_hub",
|
||||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
|
||||||
True,
|
True,
|
||||||
|
1e-4
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"native",
|
"native",
|
||||||
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
|
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
|
||||||
True,
|
True,
|
||||||
|
1e-4
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"_native_cudnn",
|
"_native_cudnn",
|
||||||
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
|
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
|
||||||
True,
|
True,
|
||||||
|
5e-4,
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
"aiter",
|
"aiter",
|
||||||
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
|
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
|
||||||
True,
|
True,
|
||||||
|
1e-4
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
@@ -104,11 +115,11 @@ def _backend_is_probably_supported(pipe, name: str):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _check_if_slices_match(output, expected_slice):
|
def _check_if_slices_match(output, expected_slice, expected_diff=1e-4):
|
||||||
img = output.images.detach().cpu()
|
img = output.images.detach().cpu()
|
||||||
generated_slice = img.flatten()
|
generated_slice = img.flatten()
|
||||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||||
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
|
assert numpy_cosine_similarity_distance(generated_slice, expected_slice) < expected_diff
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
@@ -126,23 +137,23 @@ def pipe(device):
|
|||||||
return pipe
|
return pipe
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
|
@pytest.mark.parametrize("backend_name,expected_slice,expected_diff", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
|
||||||
def test_forward(pipe, backend_name, expected_slice):
|
def test_forward(pipe, backend_name, expected_slice, expected_diff):
|
||||||
out = _backend_is_probably_supported(pipe, backend_name)
|
out = _backend_is_probably_supported(pipe, backend_name)
|
||||||
if isinstance(out, bool):
|
if isinstance(out, bool):
|
||||||
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
|
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
|
||||||
|
|
||||||
modified_pipe = out[0]
|
modified_pipe = out[0]
|
||||||
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
||||||
_check_if_slices_match(out, expected_slice)
|
_check_if_slices_match(out, expected_slice, expected_diff)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"backend_name,expected_slice,error_on_recompile",
|
"backend_name,expected_slice,error_on_recompile,expected_diff",
|
||||||
COMPILE_CASES,
|
COMPILE_CASES,
|
||||||
ids=[c[0] for c in COMPILE_CASES],
|
ids=[c[0] for c in COMPILE_CASES],
|
||||||
)
|
)
|
||||||
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
|
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile, expected_diff):
|
||||||
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
|
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
|
||||||
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
|
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
|
||||||
|
|
||||||
@@ -160,4 +171,4 @@ def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recom
|
|||||||
):
|
):
|
||||||
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
|
||||||
|
|
||||||
_check_if_slices_match(out, expected_slice)
|
_check_if_slices_match(out, expected_slice, expected_diff)
|
||||||
|
|||||||
@@ -131,6 +131,10 @@ def torch_all_close(a, b, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def numpy_cosine_similarity_distance(a, b):
|
def numpy_cosine_similarity_distance(a, b):
|
||||||
|
if isinstance(a, torch.Tensor):
|
||||||
|
a = a.detach().cpu().float().numpy()
|
||||||
|
if isinstance(b, torch.Tensor):
|
||||||
|
b = b.detach().cpu().float().numpy()
|
||||||
similarity = np.dot(a, b) / (norm(a) * norm(b))
|
similarity = np.dot(a, b) / (norm(a) * norm(b))
|
||||||
distance = 1.0 - similarity.mean()
|
distance = 1.0 - similarity.mean()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user