Compare commits

..

1 commit

Author: sayakpaul | SHA1: 5e87c38b29 | Message: remove non-hub attention backends. | Date: 2026-02-22 12:25:59 +05:30
5 changed files with 199 additions and 1003 deletions

View File

@@ -34,12 +34,7 @@ from ..utils import (
get_logger,
is_aiter_available,
is_aiter_version,
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
is_torch_version,
is_torch_xla_available,
@@ -55,62 +50,23 @@ from ._modeling_parallel import gather_size_by_comm
if TYPE_CHECKING:
from ._modeling_parallel import ParallelConfig
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
if _CAN_USE_FLASH_ATTN:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
flash_attn_func = None
flash_attn_varlen_func = None
_wrapped_flash_attn_backward = None
_wrapped_flash_attn_forward = None
if _CAN_USE_FLASH_ATTN_3:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
else:
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_varlen = None
if _CAN_USE_FLEX_ATTN:
# We cannot import the flex_attention function from the package directly because it is expected (from the
@@ -136,27 +92,6 @@ if _CAN_USE_XFORMERS_ATTN:
else:
xops = None
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if is_torch_version(">=", "2.4"):
_custom_op = torch.library.custom_op
_register_fake = torch.library.register_fake
else:
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
def wrap(func):
return func
return wrap if fn is None else fn
def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
def wrap(func):
return func
return wrap if fn is None else fn
_custom_op = custom_op_no_op
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
@@ -304,11 +239,11 @@ def attention_backend(backend: str | AttentionBackendName = AttentionBackendName
"""
Context manager to set the active attention backend.
"""
if backend not in _AttentionBackendRegistry._backends:
raise ValueError(f"Backend {backend} is not registered.")
backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
if backend not in _AttentionBackendRegistry._backends:
raise ValueError(f"Backend {backend} is not registered.")
_maybe_download_kernel_for_backend(backend)
old_backend = _AttentionBackendRegistry._active_backend
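For reference, a minimal usage sketch of the two entry points exercised by the tests later in this diff: the scoped `attention_backend()` context manager and the per-model `set_attention_backend()` setter. The checkpoint and prompt are illustrative, and the hub backends additionally require `pip install kernels`:

```python
import torch
from diffusers import FluxPipeline
from diffusers.models.attention_dispatch import attention_backend

# Illustrative checkpoint; any model whose attention modules use the dispatcher works.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Scoped override: the previously active backend is restored when the context exits.
with attention_backend("flash_hub"):
    image = pipe("dance doggo dance", height=256, width=256, num_inference_steps=2).images[0]

# Persistent override on a single model; undo it with reset_attention_backend().
pipe.transformer.set_attention_backend("flash_hub")
pipe.transformer.reset_attention_backend()
```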
@@ -442,16 +377,32 @@ def _check_shape(
def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
if not _CAN_USE_FLASH_ATTN:
raise RuntimeError(
f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
)
raise RuntimeError(
f"The '{backend.value}' attention backend has been removed. "
f"Please use 'flash_hub' or 'flash_varlen_hub' instead, which load the flash-attn kernel from the Hub. "
f"Install the required package with `pip install kernels`."
)
elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
if not _CAN_USE_FLASH_ATTN_3:
raise RuntimeError(
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)
raise RuntimeError(
f"The '{backend.value}' attention backend has been removed. "
f"Please use '_flash_3_hub' or '_flash_3_varlen_hub' instead, which load the flash-attn-3 kernel from the Hub. "
f"Install the required package with `pip install kernels`."
)
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
]:
raise RuntimeError(
f"The '{backend.value}' attention backend has been removed. "
f"Please use 'sage_hub' instead, which loads the SageAttention kernel from the Hub. "
f"Install the required package with `pip install kernels`."
)
elif backend in [
AttentionBackendName.FLASH_HUB,
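The removal errors above point users at the Hub-kernel backends. As a hedged migration sketch (reusing `pipe` from the earlier example, and assuming the `kernels` package is installed):

```python
# Removed local backends             -> Hub-kernel replacements (loaded via `kernels`)
#   "flash", "flash_varlen"          -> "flash_hub", "flash_varlen_hub"
#   "_flash_3", "_flash_varlen_3"    -> "_flash_3_hub", "_flash_3_varlen_hub"
#   "sage" and its int8/fp8 variants -> "sage_hub"
pipe.transformer.set_attention_backend("flash_hub")
```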
@@ -471,19 +422,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
)
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
]:
if not _CAN_USE_SAGE_ATTN:
raise RuntimeError(
f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
)
elif backend == AttentionBackendName.FLEX:
if not _CAN_USE_FLEX_ATTN:
raise RuntimeError(
@@ -652,78 +590,6 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
raise
# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: float | None = None,
causal: bool = False,
qv: torch.Tensor | None = None,
q_descale: torch.Tensor | None = None,
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
attention_chunk: int = 0,
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
out, lse, *_ = flash_attn_3_func(
q=q,
k=k,
v=v,
softmax_scale=softmax_scale,
causal=causal,
qv=qv,
q_descale=q_descale,
k_descale=k_descale,
v_descale=v_descale,
window_size=window_size,
attention_chunk=attention_chunk,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
)
lse = lse.permute(0, 2, 1)
return out, lse
@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
def _(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
softmax_scale: float | None = None,
causal: bool = False,
qv: torch.Tensor | None = None,
q_descale: torch.Tensor | None = None,
k_descale: torch.Tensor | None = None,
v_descale: torch.Tensor | None = None,
attention_chunk: int = 0,
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
window_size = (-1, -1) # noqa: F841
# A lot of the parameters here are not yet used in any way within diffusers.
# We can safely ignore for now and keep the fake op shape propagation simple.
batch_size, seq_len, num_heads, head_dim = q.shape
lse_shape = (batch_size, seq_len, num_heads)
return torch.empty_like(q), q.new_empty(lse_shape)
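The comment above notes that these registrations exist so `torch.compile(fullgraph=True)` can trace through the op; the fake implementation only propagates shapes and dtypes. A minimal, self-contained sketch of the same `torch.library.custom_op` / `register_fake` pattern on PyTorch >= 2.4 (the `demo::scale` op is illustrative, not part of diffusers):

```python
import torch

@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Real kernel: only ever executed on real tensors.
    return x * factor

@torch.library.register_fake("demo::scale")
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    # Fake kernel: runs under FakeTensor during tracing, so only shape/dtype matter.
    return torch.empty_like(x)

@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    return scale(x, 2.0)

print(fn(torch.randn(4)))
```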
# ===== Helper functions to use attention backends with templated CP autograd functions =====
@@ -995,107 +861,6 @@ def _native_flash_attention_backward_op(
return grad_query, grad_key, grad_value
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
def _flash_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
# Hardcoded for now
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if scale is None:
scale = query.shape[-1] ** (-0.5)
# flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1)
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
lse_d = _wrapped_flash_attn_backward( # noqa: F841
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
# Head dimension may have been padded
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1327,44 +1092,6 @@ def _flash_attention_3_hub_backward_op(
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
out = sageattn(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1)
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -2205,59 +1932,6 @@ def _templated_context_parallel_attention(
# ===== Attention backends =====
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
if _parallel_config is None:
out = flash_attn_func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_forward_op,
backward_op=_flash_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -2369,88 +2043,6 @@ def _flash_varlen_attention_hub(
return out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_VARLEN,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out = flash_attn_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
out = out.unflatten(0, (batch_size, -1))
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
)
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -2587,58 +2179,6 @@ def _flash_attention_3_varlen_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out, lse, *_ = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
)
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.AITER,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -3108,57 +2648,6 @@ def _native_xla_attention(
return out
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _sage_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
scale: float | None = None,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
if _parallel_config is None:
out = sageattn(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -3211,169 +2700,6 @@ def _sage_attention_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_VARLEN,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
scale: float | None = None,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if return_lse:
raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
batch_size, seq_len_q, _, _ = query.shape
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
_prepare_for_flash_attn_or_sage_varlen(
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
)
)
key_valid, value_valid = [], []
for b in range(batch_size):
valid_len = seqlens_k[b]
key_valid.append(key[b, :valid_len])
value_valid.append(value[b, :valid_len])
query_packed = query.flatten(0, 1)
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
out = sageattn_varlen(
q=query_packed,
k=key_packed,
v=value_packed,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
is_causal=is_causal,
sm_scale=scale,
)
out = out.unflatten(0, (batch_size, -1))
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
scale: float | None = None,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
scale: float | None = None,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
scale: float | None = None,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
is_causal: bool = False,
scale: float | None = None,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
@_AttentionBackendRegistry.register(
AttentionBackendName.XFORMERS,
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],

View File

@@ -1,4 +1,4 @@
from .attention import AttentionBackendTesterMixin, AttentionTesterMixin
from .attention import AttentionTesterMixin
from .cache import (
CacheTesterMixin,
FasterCacheConfigMixin,
@@ -38,7 +38,6 @@ from .training import TrainingTesterMixin
__all__ = [
"AttentionBackendTesterMixin",
"AttentionTesterMixin",
"BaseModelTesterConfig",
"BitsAndBytesCompileTesterMixin",

View File

@@ -14,105 +14,22 @@
# limitations under the License.
import gc
import logging
import pytest
import torch
from diffusers.models.attention import AttentionModuleMixin
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils import is_kernels_available, is_torch_version
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Module-level backend parameter sets for AttentionBackendTesterMixin
# ---------------------------------------------------------------------------
_CUDA_AVAILABLE = torch.cuda.is_available()
_KERNELS_AVAILABLE = is_kernels_available()
_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")
_PARAM_NATIVE_CUDNN = pytest.param(
AttentionBackendName._NATIVE_CUDNN,
id="native_cudnn",
marks=pytest.mark.skipif(
not _CUDA_AVAILABLE,
reason="CUDA is required for _native_cudnn backend.",
),
from diffusers.models.attention_processor import (
AttnProcessor,
)
_PARAM_FLASH_HUB = pytest.param(
AttentionBackendName.FLASH_HUB,
id="flash_hub",
marks=[
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
pytest.mark.skipif(
not _KERNELS_AVAILABLE,
reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
),
],
from ...testing_utils import (
assert_tensors_close,
backend_empty_cache,
is_attention,
torch_device,
)
_PARAM_FLASH_3_HUB = pytest.param(
AttentionBackendName._FLASH_3_HUB,
id="flash_3_hub",
marks=[
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
pytest.mark.skipif(
not _KERNELS_AVAILABLE,
reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
),
],
)
# All backends under test.
_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB]
# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them.
_BF16_REQUIRED_BACKENDS = {
AttentionBackendName._NATIVE_CUDNN,
AttentionBackendName.FLASH_HUB,
AttentionBackendName._FLASH_3_HUB,
}
# Backends that perform non-deterministic operations and therefore cannot run when
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN}
def _maybe_cast_to_bf16(backend, model, inputs_dict):
"""Cast model and floating-point inputs to bfloat16 when the backend requires it."""
if backend not in _BF16_REQUIRED_BACKENDS:
return model, inputs_dict
model = model.to(dtype=torch.bfloat16)
inputs_dict = {
k: v.to(dtype=torch.bfloat16) if isinstance(v, torch.Tensor) and v.is_floating_point() else v
for k, v in inputs_dict.items()
}
return model, inputs_dict
def _skip_if_backend_requires_nondeterminism(backend):
"""Skip at runtime when torch.use_deterministic_algorithms(True) blocks the backend.
This check is intentionally deferred to test execution time because
enable_full_determinism() is typically called at module level in test files *after*
the module-level pytest.param() objects in this file have already been evaluated,
making it impossible to catch via a collection-time skipif condition.
"""
if backend in _NON_DETERMINISTIC_BACKENDS and torch.are_deterministic_algorithms_enabled():
pytest.skip(
f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
f"while `torch.use_deterministic_algorithms(True)` is active."
)
@is_attention
class AttentionTesterMixin:
@@ -122,6 +39,7 @@ class AttentionTesterMixin:
Tests functionality from AttentionModuleMixin including:
- Attention processor management (set/get)
- QKV projection fusion/unfusion
- Attention backends (XFormers, NPU, etc.)
Expected from config mixin:
- model_class: The model class to test
@@ -261,208 +179,3 @@ class AttentionTesterMixin:
model.set_attn_processor(wrong_processors)
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
@is_attention
class AttentionBackendTesterMixin:
"""
Mixin class for testing attention backends on models. The following things are tested:
1. Backends can be set with the `attention_backend` context manager and with
`set_attention_backend()` method.
2. SDPA outputs don't deviate too much from backend outputs.
3. Backend works with (regional) compilation.
4. Backends can be restored.
Tests the backends using the model provided by the host test class. The backends to test
are defined in `_ALL_BACKEND_PARAMS`.
Expected from the host test class:
- model_class: The model class to instantiate.
Expected methods from the host test class:
- get_init_dict(): Returns dict of kwargs to construct the model.
- get_dummy_inputs(): Returns dict of inputs for the model's forward pass.
Pytest mark: attention
Use `pytest -m "not attention"` to skip these tests.
"""
# -----------------------------------------------------------------------
# Tolerance attributes — override in host class to loosen/tighten checks.
# -----------------------------------------------------------------------
# test_output_close_to_native: alternate backends (flash, cuDNN) may
# accumulate small numerical errors vs the reference PyTorch SDPA kernel.
backend_vs_native_atol: float = 1e-2
backend_vs_native_rtol: float = 1e-2
# test_compile: regional compilation introduces the same kind of numerical
# error as the non-compiled backend path, so the same loose tolerance applies.
compile_vs_native_atol: float = 1e-2
compile_vs_native_rtol: float = 1e-2
def setup_method(self):
gc.collect()
backend_empty_cache(torch_device)
def teardown_method(self):
gc.collect()
backend_empty_cache(torch_device)
@torch.no_grad()
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_set_attention_backend_matches_context_manager(self, backend):
"""set_attention_backend() and the attention_backend() context manager must yield identical outputs."""
_skip_if_backend_requires_nondeterminism(backend)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)
with attention_backend(backend):
ctx_output = model(**inputs_dict, return_dict=False)[0]
initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()
try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))
try:
set_output = model(**inputs_dict, return_dict=False)[0]
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)
assert_tensors_close(
set_output,
ctx_output,
atol=0,
rtol=0,
msg=(
f"Output from model.set_attention_backend('{backend.value}') should be identical "
f"to the output from `with attention_backend('{backend.value}'):`."
),
)
@torch.no_grad()
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_output_close_to_native(self, backend):
"""All backends should produce model output numerically close to the native SDPA reference."""
_skip_if_backend_requires_nondeterminism(backend)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)
with attention_backend(AttentionBackendName.NATIVE):
native_output = model(**inputs_dict, return_dict=False)[0]
initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()
try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))
try:
backend_output = model(**inputs_dict, return_dict=False)[0]
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)
assert_tensors_close(
backend_output,
native_output,
atol=self.backend_vs_native_atol,
rtol=self.backend_vs_native_rtol,
msg=f"Output from {backend} should be numerically close to native SDPA.",
)
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_context_manager_switches_and_restores_backend(self, backend):
"""attention_backend() should activate the requested backend and restore the previous one on exit."""
initial_backend, _ = _AttentionBackendRegistry.get_active_backend()
with attention_backend(backend):
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
assert active_backend == backend, (
f"Backend should be {backend} inside the context manager, got {active_backend}."
)
restored_backend, _ = _AttentionBackendRegistry.get_active_backend()
assert restored_backend == initial_backend, (
f"Backend should be restored to {initial_backend} after exiting the context manager, "
f"got {restored_backend}."
)
@pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
def test_compile(self, backend):
"""
`torch.compile` tests that check for recompilation, graph breaks, that the forward pass runs, etc.
For speed, we use regional compilation here (`model.compile_repeated_blocks()`
as opposed to `model.compile`).
"""
_skip_if_backend_requires_nondeterminism(backend)
if getattr(self.model_class, "_repeated_blocks", None) is None:
pytest.skip("Skipping tests as regional compilation is not supported.")
if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"):
pytest.xfail(
"test_compile with the native backend requires torch >= 2.9.0 for stable "
"fullgraph compilation with error_on_recompile=True."
)
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model, inputs_dict = _maybe_cast_to_bf16(backend, model, inputs_dict)
with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE):
native_output = model(**inputs_dict, return_dict=False)[0]
initial_registry_backend, _ = _AttentionBackendRegistry.get_active_backend()
try:
model.set_attention_backend(backend.value)
except Exception as e:
logger.warning("Skipping test for backend '%s': %s", backend.value, e)
pytest.skip(str(e))
try:
model.compile_repeated_blocks(fullgraph=True)
torch.compiler.reset()
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=True),
):
with torch.no_grad():
compile_output = model(**inputs_dict, return_dict=False)[0]
model(**inputs_dict, return_dict=False)
finally:
model.reset_attention_backend()
_AttentionBackendRegistry.set_active_backend(initial_registry_backend)
assert_tensors_close(
compile_output,
native_output,
atol=self.compile_vs_native_atol,
rtol=self.compile_vs_native_rtol,
msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
)

View File

@@ -25,7 +25,6 @@ from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionBackendTesterMixin,
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesCompileTesterMixin,
@@ -225,10 +224,6 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
"""Attention processor tests for Flux Transformer."""
class TestFluxTransformerAttentionBackend(FluxTransformerTesterConfig, AttentionBackendTesterMixin):
"""Attention backend tests for Flux Transformer."""
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux Transformer"""

View File

@@ -0,0 +1,163 @@
"""
This test suite currently exists for the maintainers; it is not run in our CI at the moment.
Once attention backends become more mature, we can consider including this in our CI.
To run this test suite:
```bash
export RUN_ATTENTION_BACKEND_TESTS=yes
pytest tests/others/test_attention_backends.py
```
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
Slices for the aiter backend tests were collected on an MI355X with a torch nightly from 2025-09-25
(commit ad2f7315ca66b42497047bb7951f696b50f1e81b) and aiter 0.1.5.post4.dev20+ga25e55e79.
"""
import os
import pytest
import torch
pytestmark = pytest.mark.skipif(
os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false") == "false", reason="Feature not mature enough."
)
from diffusers import FluxPipeline # noqa: E402
from diffusers.utils import is_torch_version # noqa: E402
# fmt: off
FORWARD_CASES = [
(
"flash_hub",
torch.tensor([0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695], dtype=torch.bfloat16)
),
(
"_flash_3_hub",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715], dtype=torch.bfloat16),
),
(
"native",
torch.tensor([0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734], dtype=torch.bfloat16)
),
(
"_native_cudnn",
torch.tensor([0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695], dtype=torch.bfloat16),
),
(
"aiter",
torch.tensor([0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891], dtype=torch.bfloat16),
)
]
COMPILE_CASES = [
(
"flash_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True
),
(
"_flash_3_hub",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047], dtype=torch.bfloat16),
True,
),
(
"native",
torch.tensor([0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066], dtype=torch.bfloat16),
True,
),
(
"_native_cudnn",
torch.tensor([0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086], dtype=torch.bfloat16),
True,
),
(
"aiter",
torch.tensor([0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164], dtype=torch.bfloat16),
True,
)
]
# fmt: on
INFER_KW = {
"prompt": "dance doggo dance",
"height": 256,
"width": 256,
"num_inference_steps": 2,
"guidance_scale": 3.5,
"max_sequence_length": 128,
"output_type": "pt",
}
def _backend_is_probably_supported(pipe, name: str):
try:
pipe.transformer.set_attention_backend(name)
return pipe, True
except Exception:
return False
def _check_if_slices_match(output, expected_slice):
img = output.images.detach().cpu()
generated_slice = img.flatten()
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
assert torch.allclose(generated_slice, expected_slice, atol=1e-4)
@pytest.fixture(scope="session")
def device():
if not torch.cuda.is_available():
pytest.skip("CUDA is required for these tests.")
return torch.device("cuda:0")
@pytest.fixture(scope="session")
def pipe(device):
repo_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16).to(device)
pipe.set_progress_bar_config(disable=True)
return pipe
@pytest.mark.parametrize("backend_name,expected_slice", FORWARD_CASES, ids=[c[0] for c in FORWARD_CASES])
def test_forward(pipe, backend_name, expected_slice):
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
modified_pipe = out[0]
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)
@pytest.mark.parametrize(
"backend_name,expected_slice,error_on_recompile",
COMPILE_CASES,
ids=[c[0] for c in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
if "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0"):
pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")
out = _backend_is_probably_supported(pipe, backend_name)
if isinstance(out, bool):
pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")
modified_pipe = out[0]
modified_pipe.transformer.compile(fullgraph=True)
torch.compiler.reset()
with (
torch._inductor.utils.fresh_inductor_cache(),
torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
):
out = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
_check_if_slices_match(out, expected_slice)