Compare commits

..

2 Commits

Author SHA1 Message Date
Sayak Paul
022ac4ddf6 Fix torchrun command argument order in docs 2026-02-24 16:10:34 +05:30
SYM.BOT
1f6ac1c3d1 fix: graceful fallback when attention backends fail to import (#13060)
* fix: graceful fallback when attention backends fail to import

## Problem

External attention backends (flash_attn, xformers, sageattention, etc.) may be
installed but fail to import at runtime due to ABI mismatches. For example,
when `flash_attn` is compiled against PyTorch 2.4 but used with PyTorch 2.8,
the import fails with:

```
OSError: .../flash_attn_2_cuda.cpython-311-x86_64-linux-gnu.so: undefined symbol: _ZN3c104cuda9SetDeviceEab
```

The current code uses `importlib.util.find_spec()` to check if packages exist,
but this only verifies the package is installed—not that it can actually be
imported. When the import fails, diffusers crashes instead of falling back to
native PyTorch attention.

## Solution

Wrap all external attention backend imports in try-except blocks that catch
`ImportError` and `OSError`. On failure:
1. Log a warning message explaining the issue
2. Set the corresponding `_CAN_USE_*` flag to `False`
3. Set the imported functions to `None`

This allows diffusers to gracefully degrade to PyTorch's native SDPA
(scaled_dot_product_attention) instead of crashing.

## Affected backends

- flash_attn (Flash Attention)
- flash_attn_3 (Flash Attention 3)
- aiter (AMD Instinct)
- sageattention (SageAttention)
- flex_attention (PyTorch Flex Attention)
- torch_npu (Huawei NPU)
- torch_xla (TPU/XLA)
- xformers (Meta xFormers)

## Testing

Tested with PyTorch 2.8.0 and flash_attn 2.7.4.post1 (compiled for PyTorch 2.4).
Before: crashes on import. After: logs warning and uses native attention.

* address review: use single logger and catch RuntimeError

- Move logger to module level instead of creating per-backend loggers
- Add RuntimeError to exception list alongside ImportError and OSError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* Apply style fixes

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-02-24 13:37:39 +05:30
7 changed files with 248 additions and 327 deletions

View File

@@ -111,7 +111,7 @@ if __name__ == "__main__":
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
```bash
torchrun run_distributed.py --nproc_per_node=2
torchrun --nproc_per_node=2 run_distributed.py
```
## device_map

View File

@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"
logger = get_logger(__name__) # pylint: disable=invalid-name
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError, RuntimeError) as e:
# Handle ABI mismatch or other import failures gracefully.
# This can happen when flash_attn was compiled against a different PyTorch version.
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
_CAN_USE_FLASH_ATTN = False
flash_attn_func = None
flash_attn_varlen_func = None
_wrapped_flash_attn_backward = None
_wrapped_flash_attn_forward = None
else:
flash_attn_func = None
flash_attn_varlen_func = None
@@ -83,26 +95,47 @@ else:
if _CAN_USE_FLASH_ATTN_3:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLASH_ATTN_3 = False
flash_attn_3_func = None
flash_attn_3_varlen_func = None
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
try:
from aiter import flash_attn_func as aiter_flash_attn_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
_CAN_USE_AITER_ATTN = False
aiter_flash_attn_func = None
else:
aiter_flash_attn_func = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
try:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
_CAN_USE_SAGE_ATTN = False
sageattn = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_varlen = None
else:
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
@@ -113,26 +146,48 @@ else:
if _CAN_USE_FLEX_ATTN:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
try:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLEX_ATTN = False
flex_attention = None
else:
flex_attention = None
if _CAN_USE_NPU_ATTN:
from torch_npu import npu_fusion_attention
try:
from torch_npu import npu_fusion_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
_CAN_USE_NPU_ATTN = False
npu_fusion_attention = None
else:
npu_fusion_attention = None
if _CAN_USE_XLA_ATTN:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
try:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
_CAN_USE_XLA_ATTN = False
xla_flash_attention = None
else:
xla_flash_attention = None
if _CAN_USE_XFORMERS_ATTN:
import xformers.ops as xops
try:
import xformers.ops as xops
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
_CAN_USE_XFORMERS_ATTN = False
xops = None
else:
xops = None
@@ -158,8 +213,6 @@ else:
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods

View File

@@ -1,4 +1,4 @@
from .attention import AttentionBackendTesterMixin, AttentionTesterMixin
from .attention import AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
@@ -38,7 +38,6 @@ from .training import TrainingTesterMixin
__all__ = [
"AttentionBackendTesterMixin",
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",

View File

@@ -14,105 +14,22 @@
# limitations under the License.
import gc
import logging
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import is_kernels_available, is_torch_version
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Module-level backend parameter sets for AttentionBackendTesterMixin
# ---------------------------------------------------------------------------
_CUDA_AVAILABLE = torch.cuda.is_available()
_KERNELS_AVAILABLE = is_kernels_available()
_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")
_PARAM_NATIVE_CUDNN = pytest.param(
AttentionBackendName._NATIVE_CUDNN,
id="native_cudnn",
marks=pytest.mark.skipif(
not _CUDA_AVAILABLE,
reason="CUDA is required for _native_cudnn backend.",
),
from diffusers.models.attention_processor import (
AttnProcessor,
)
_PARAM_FLASH_HUB = pytest.param(
AttentionBackendName.FLASH_HUB,
id="flash_hub",
marks=[
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
pytest.mark.skipif(
not _KERNELS_AVAILABLE,
reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
),
],
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_attention,
torch_device,
)
_PARAM_FLASH_3_HUB = pytest.param(
AttentionBackendName._FLASH_3_HUB,
id="flash_3_hub",
marks=[
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
pytest.mark.skipif(
not _KERNELS_AVAILABLE,
reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
),
],
)
# All backends under test.
_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB]
# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them.
_BF16_REQUIRED_BACKENDS = {
AttentionBackendName._NATIVE_CUDNN,
AttentionBackendName.FLASH_HUB,
AttentionBackendName._FLASH_3_HUB,
}
# Backends that perform non-deterministic operations and therefore cannot run when
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN}
def _maybe_cast_to_bf16(backend, model, inputs_dict):
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
if backend not in _BF16_REQUIRED_BACKENDS:
return model, inputs_dict
model = model.to(dtype=torch.bfloat16)
inputs_dict = {
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs_dict.items()
}
return model, inputs_dict
def _skip_if_backend_requires_nondeterminism(backend):
"""Skip at runtime when torch.use_deterministic_algorithms(True) blocks the backend.
This check is intentionally deferred to test execution time because
enable_full_determinism() is typically called at module level in test files *after*
the module-level pytest.param() objects in this file have already been evaluated,
making it impossible to catch via a collection-time skipif condition.
"""
if backend in _NON_DETERMINISTIC_BACKENDS and torch.are_deterministic_algorithms_enabled():
pytest.skip(
f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
f"while `torch.use_deterministic_algorithms(True)` is active."
)
@is_attention
class AttentionTesterMixin:
@@ -122,6 +39,7 @@ class AttentionTesterMixin:
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected from config mixin:
- model_class: The model class to test
@@ -261,208 +179,3 @@ class AttentionTesterMixin:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
@is_attention
class AttentionBackendTesterMixin:
"""
Mixin class for testing attention backends on models. Following things are tested:
1. Backends can be set with the `attention_backend` context manager and with
`set_attention_backend()` method.
2. SDPA outputs don't deviate too much from backend outputs.
3. Backend works with (regional) compilation.
4. Backends can be restored.
Tests the backends using the model provided by the host test class. The backends to test
are defined in `_ALL_BACKEND_PARAMS`.
Expected from the host test class:
- model_class: The model class to instantiate.
Expected methods from the host test class:
- get_init_dict(): Returns dict of kwargs to construct the model.
- get_dummy_inputs(): Returns dict of inputs for the model's forward pass.
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests.
"""
# -----------------------------------------------------------------------
# Tolerance attributes — override in host class to loosen/tighten checks.
# -----------------------------------------------------------------------
# test_output_close_to_native: alternate backends (flash, cuDNN) may
# accumulate small numerical errors vs the reference PyTorch SDPA kernel.
backend_vs_native_atol: float = 1e-2
backend_vs_native_rtol: float = 1e-2
# test_compile: regional compilation introduces the same kind of numerical
# error as the non-compiled backend path, so the same loose tolerance applies.
compile_vs_native_atol: float = 1e-2
compile_vs_native_rtol: float = 1e-2
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_set_attention_backend_matches_context_manager(self, backend):
"""set_attention_backend() and the attention_backend() context manager must yield identical outputs."""
_skip_if_backend_requires_nondeterminism(backend)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)
with attention_backend(backend):
ctx_output = model(**inputs_dict, return_dict=False)[0]
initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()
try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))
try:
set_output = model(**inputs_dict, return_dict=False)[0]
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)
assert_tensors_close(
set_output,
ctx_output,
atol=0,
rtol=0,
msg=(
f"Output from model.set_attention_backend('{backend.value}') should be identical "
f"to the output from `with attention_backend('{backend.value}'):`."
),
)
@torch.no_grad()
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_output_close_to_native(self, backend):
"""All backends should produce model output numerically close to the native SDPA reference."""
_skip_if_backend_requires_nondeterminism(backend)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)
with attention_backend(AttentionBackendName.NATIVE):
native_output = model(**inputs_dict, return_dict=False)[0]
initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()
try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))
try:
backend_output = model(**inputs_dict, return_dict=False)[0]
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)
assert_tensors_close(
backend_output,
native_output,
atol=self.backend_vs_native_atol,
rtol=self.backend_vs_native_rtol,
msg=f"Output from {backend} should be numerically close to native SDPA.",
)
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_context_manager_switches_and_restores_backend(self, backend):
"""attention_backend() should activate the requested backend and restore the previous one on exit."""
initial_backend, _ = _AttentionBackendRegistry.get_active_backend()
with attention_backend(backend):
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
assert active_backend == backend, (
f"Backend should be {backend} inside the context manager, got {active_backend}."
)
restored_backend, _ = _AttentionBackendRegistry.get_active_backend()
assert restored_backend == initial_backend, (
f"Backend should be restored to {initial_backend} after exiting the context manager, "
f"got {restored_backend}."
)
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_compile(self, backend):
"""
`torch.compile` tests checking for recompilation, graph breaks, forward can run, etc.
For speed, we use regional compilation here (`model.compile_repeated_blocks()`
as opposed to `model.compile`).
"""
_skip_if_backend_requires_nondeterminism(backend)
if getattr(self.model_class, "_repeated_blocks", None) is None:
pytest.skip("Skipping tests as regional compilation is not supported.")
if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"):
pytest.xfail(
"test_compile with the native backend requires torch >= 2.9.0 for stable "
"fullgraph compilation with error_on_recompile=True."
)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)
with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE):
native_output = model(**inputs_dict, return_dict=False)[0]
initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()
try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))
try:
model.compile_repeated_blocks(fullgraph=True)
torch.compiler.reset()
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
):
with torch.no_grad():
compile_output = model(**inputs_dict, return_dict=False)[0]
model(**inputs_dict, return_dict=False)
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)
assert_tensors_close(
compile_output,
native_output,
atol=self.compile_vs_native_atol,
rtol=self.compile_vs_native_rtol,
msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
)

View File

@@ -25,7 +25,6 @@ from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionBackendTesterMixin,
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
@@ -225,10 +224,6 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
"""Attention processor tests for Flux Transformer."""
class TestFluxTransformerAttentionBackend(FluxTransformerTesterConfig, AttentionBackendTesterMixin):
"""Attention backend tests for Flux Transformer."""
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""

View File

@@ -0,0 +1,163 @@
"""
This test suite exists for the maintainers currently. It's not run in our CI at the moment.
Once attention backends become more mature, we can consider including this in our CI.
To run this test suite:
```bash
export RUN_ATTENTION_BACKEND_TESTS=yes
pytest tests/others/test_attention_backends.py
```
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
aiter 0.1.5.post4.dev20+ga25e55e79.
"""
import os
import pytest
import torch
pytestmark = pytest.mark.skipif(
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
)
from diffusers import FluxPipeline # noqa: E402
from diffusers.utils import is_torch_version # noqa: E402
# fmt: off
FORWARD_CASES = [
(
"flash_hub",
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
),
(
"_flash_3_hub",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
),
(
"native",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
),
(
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
)
]
COMPILE_CASES = [
(
"flash_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True
),
(
"_flash_3_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True,
),
(
"native",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
True,
),
(
"_native_cudnn",
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
True,
),
(
"aiter",
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
True,
)
]
# fmt: on
INFER_KW = {
"prompt": "dance doggo dance",
"height": 256,
"width": 256,
"num_inference_steps": 2,
"guidance_scale": 3.5,
"max_sequence_length": 128,
"output_type": "pt",
}
def _backend_is_probably_supported(pipe, name: str):
try:
pipe.transformer.set_attention_backend(name)
return pipe, True
except Exception:
return False
def _check_if_slices_match(output, expected_slice):
img = output.images.detach().cpu()
generated_slice = img.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
@pytest.fixture(scope="session")
def device():
if not torch.cuda.is_available():
pytest.skip("CUDA is required for these tests.")
return torch.device("cuda:0")
@pytest.fixture(scope="session")
def pipe(device):
repo_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
pipe.set_progress_bar_config(disable=True)
return pipe
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
modified_pipe = out[0]
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)
@pytest.mark.parametrize(
"backend_name,expected_slice,error_on_recompile",
COMPILE_CASES,
ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
modified_pipe = out[0]
modified_pipe.transformer.compile(fullgraph=True)
torch.compiler.reset()
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
):
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)

View File

@@ -72,7 +72,6 @@ OPTIONAL_TESTERS = [
# Other testers
("SingleFileTesterMixin", "single_file"),
("IPAdapterTesterMixin", "ip_adapter"),
("AttentionBackendTesterMixin", "attention_backends"),
]
@@ -531,7 +530,6 @@ def main():
"faster_cache",
"single_file",
"ip_adapter",
"attention_backends",
"all",
],
help="Optional testers to include",