mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-23 19:30:38 +08:00
Compare commits
1 Commits
attn-backe
...
remove-non
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e87c38b29 |
@@ -34,12 +34,7 @@ from ..utils import (
|
||||
get_logger,
|
||||
is_aiter_available,
|
||||
is_aiter_version,
|
||||
is_flash_attn_3_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_version,
|
||||
is_kernels_available,
|
||||
is_sageattention_available,
|
||||
is_sageattention_version,
|
||||
is_torch_npu_available,
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
@@ -55,62 +50,23 @@ from ._modeling_parallel import gather_size_by_comm
|
||||
if TYPE_CHECKING:
|
||||
from ._modeling_parallel import ParallelConfig
|
||||
|
||||
# Minimum versions of the optional attention packages. The `_CAN_USE_*` flags
# below and the error messages raised for unusable backends refer to these.
_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"

# Availability flags, evaluated once at import time.
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)


# Each optional kernel is imported only when its flag is set; otherwise the
# corresponding names are bound to None so that they always exist at module level.

# flash-attn 2
if _CAN_USE_FLASH_ATTN:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
    flash_attn_func = flash_attn_varlen_func = None
    _wrapped_flash_attn_backward = _wrapped_flash_attn_forward = None


# flash-attn 3
if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
    flash_attn_3_func = flash_attn_3_varlen_func = None

# aiter
if _CAN_USE_AITER_ATTN:
    from aiter import flash_attn_func as aiter_flash_attn_func
else:
    aiter_flash_attn_func = None

# sageattention
if _CAN_USE_SAGE_ATTN:
    from sageattention import (
        sageattn,
        sageattn_qk_int8_pv_fp8_cuda,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_varlen,
    )
else:
    sageattn = sageattn_varlen = None
    sageattn_qk_int8_pv_fp16_cuda = sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda = sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||
|
||||
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
@@ -136,27 +92,6 @@ if _CAN_USE_XFORMERS_ATTN:
|
||||
else:
|
||||
xops = None
|
||||
|
||||
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
|
||||
if torch.__version__ >= "2.4.0":
|
||||
_custom_op = torch.library.custom_op
|
||||
_register_fake = torch.library.register_fake
|
||||
else:
|
||||
|
||||
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
|
||||
def wrap(func):
|
||||
return func
|
||||
|
||||
return wrap if fn is None else fn
|
||||
|
||||
def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
|
||||
def wrap(func):
|
||||
return func
|
||||
|
||||
return wrap if fn is None else fn
|
||||
|
||||
_custom_op = custom_op_no_op
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -304,11 +239,11 @@ def attention_backend(backend: str | AttentionBackendName = AttentionBackendName
|
||||
"""
|
||||
Context manager to set the active attention backend.
|
||||
"""
|
||||
if backend not in _AttentionBackendRegistry._backends:
|
||||
raise ValueError(f"Backend {backend} is not registered.")
|
||||
|
||||
backend = AttentionBackendName(backend)
|
||||
_check_attention_backend_requirements(backend)
|
||||
|
||||
if backend not in _AttentionBackendRegistry._backends:
|
||||
raise ValueError(f"Backend {backend} is not registered.")
|
||||
_maybe_download_kernel_for_backend(backend)
|
||||
|
||||
old_backend = _AttentionBackendRegistry._active_backend
|
||||
@@ -442,16 +377,32 @@ def _check_shape(
|
||||
|
||||
def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
|
||||
if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
|
||||
if not _CAN_USE_FLASH_ATTN:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"The '{backend.value}' attention backend has been removed. "
|
||||
f"Please use 'flash_hub' or 'flash_varlen_hub' instead, which load the flash-attn kernel from the Hub. "
|
||||
f"Install the required package with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
|
||||
if not _CAN_USE_FLASH_ATTN_3:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"The '{backend.value}' attention backend has been removed. "
|
||||
f"Please use '_flash_3_hub' or '_flash_3_varlen_hub' instead, which load the flash-attn-3 kernel from the Hub. "
|
||||
f"Install the required package with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
AttentionBackendName.SAGE,
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
|
||||
]:
|
||||
raise RuntimeError(
|
||||
f"The '{backend.value}' attention backend has been removed. "
|
||||
f"Please use 'sage_hub' instead, which loads the SageAttention kernel from the Hub. "
|
||||
f"Install the required package with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
@@ -471,19 +422,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
AttentionBackendName.SAGE,
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
|
||||
]:
|
||||
if not _CAN_USE_SAGE_ATTN:
|
||||
raise RuntimeError(
|
||||
f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.FLEX:
|
||||
if not _CAN_USE_FLEX_ATTN:
|
||||
raise RuntimeError(
|
||||
@@ -652,78 +590,6 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
raise
|
||||
|
||||
|
||||
# ===== torch op registrations =====
|
||||
# Registrations are required for fullgraph tracing compatibility
|
||||
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
|
||||
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
|
||||
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    qv: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Call `flash_attn_3_func` through a registered custom op.

    Every argument is forwarded by keyword to `flash_attn_3_func`. The sliding
    window is not part of the signature and is always `(-1, -1)`.

    Returns:
        `(out, lse)`: the first two values returned by `flash_attn_3_func`,
        with the last two dimensions of `lse` swapped.
    """
    # The window is fixed here because pytorch does not support tuple/int type
    # hints in the op signature.
    attn_out, softmax_lse, *_ = flash_attn_3_func(
        q=q,
        k=k,
        v=v,
        softmax_scale=softmax_scale,
        causal=causal,
        qv=qv,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        window_size=(-1, -1),
        attention_chunk=attention_chunk,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        deterministic=deterministic,
        sm_margin=sm_margin,
    )
    return attn_out, softmax_lse.permute(0, 2, 1)
|
||||
|
||||
|
||||
@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    qv: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Fake (shape-propagation) implementation of `_flash_attn_forward`.

    Only `q` is inspected; it must be 4-D and is unpacked as
    `(batch_size, seq_len, num_heads, head_dim)`. Most of the other parameters
    are not yet used within diffusers, so they are ignored to keep the shape
    propagation simple.

    Returns:
        `(out, lse)`: uninitialized tensors, where `out` matches `q` and `lse`
        has shape `(batch_size, seq_len, num_heads)`.
    """
    batch_size, seq_len, num_heads, _ = q.shape
    # The log-sum-exp is allocated as float32 rather than inheriting the
    # bf16/fp16 dtype of `q`, so the fake metadata matches the kernel output.
    # NOTE(review): flash-attention emits LSE in float32; confirm against the
    # installed FA3 build.
    lse = q.new_empty((batch_size, seq_len, num_heads), dtype=torch.float32)
    return torch.empty_like(q), lse
|
||||
|
||||
|
||||
# ===== Helper functions to use attention backends with templated CP autograd functions =====
|
||||
|
||||
|
||||
@@ -995,107 +861,6 @@ def _native_flash_attention_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
def _flash_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    """Forward op for flash-attn 2, for use with the templated autograd functions.

    Calls `_wrapped_flash_attn_forward` and, when `_save_ctx` is set, stores on
    `ctx` the tensors and options that `_flash_attention_backward_op` reads.

    Returns:
        `(out, lse)` when `return_lse` is true, otherwise `out`. `lse` has its
        last two dimensions swapped relative to the kernel output.

    Raises:
        ValueError: if `attn_mask` is given or `enable_gqa` is true.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")

    # Kernel options that are fixed for now.
    window_size = (-1, -1)
    softcap = 0.0
    alibi_slopes = None
    deterministic = False

    needs_grad = query.requires_grad or key.requires_grad or value.requires_grad
    softmax_scale = query.shape[-1] ** (-0.5) if scale is None else scale

    # flash-attn only returns LSE if dropout_p > 0, so a negligible dropout is
    # used whenever the LSE is needed (gradients or context parallelism).
    if needs_grad or (
        _parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1
    ):
        if not dropout_p > 0:
            dropout_p = 1e-30

    with torch.set_grad_enabled(needs_grad):
        out, lse, _s_dmask, rng_state = _wrapped_flash_attn_forward(
            query,
            key,
            value,
            dropout_p,
            softmax_scale,
            is_causal,
            window_size[0],
            window_size[1],
            softcap,
            alibi_slopes,
            return_lse,
        )
        lse = lse.permute(0, 2, 1)

    if _save_ctx:
        ctx.save_for_backward(query, key, value, out, lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.scale = softmax_scale
        ctx.is_causal = is_causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

    if return_lse:
        return out, lse
    return out
|
||||
|
||||
|
||||
def _flash_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    """Backward op for flash-attn 2, paired with `_flash_attention_forward_op`.

    Reads the saved tensors and kernel options from `ctx` and lets
    `_wrapped_flash_attn_backward` fill freshly allocated gradient buffers.
    Extra positional and keyword arguments are accepted and ignored.

    Returns:
        `(grad_query, grad_key, grad_value)`, each truncated along the last
        dimension to the last dimension of `grad_out`.
    """
    query, key, value, out, lse, rng_state = ctx.saved_tensors

    # Output buffers handed to the kernel.
    dq = torch.empty_like(query)
    dk = torch.empty_like(key)
    dv = torch.empty_like(value)

    # The kernel's return value is not needed here.
    _wrapped_flash_attn_backward(
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        dq,
        dk,
        dv,
        ctx.dropout_p,
        ctx.scale,
        ctx.is_causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        ctx.alibi_slopes,
        ctx.deterministic,
        rng_state,
    )

    # Head dimension may have been padded
    head_dim = grad_out.shape[-1]
    return dq[..., :head_dim], dk[..., :head_dim], dv[..., :head_dim]
|
||||
|
||||
|
||||
def _flash_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1327,44 +1092,6 @@ def _flash_attention_3_hub_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _sage_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    """Forward op for Sage attention, for use with the templated autograd functions.

    Calls `sageattn` with the "NHD" tensor layout. Nothing is saved on `ctx`.

    Returns:
        `(out, lse)` when `return_lse` is true (with the last two dimensions of
        `lse` swapped), otherwise the `sageattn` result unchanged.

    Raises:
        ValueError: if `attn_mask` is given, `dropout_p > 0.0`, or `enable_gqa`
            is true.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for Sage attention.")
    if dropout_p > 0.0:
        raise ValueError("`dropout_p` is not yet supported for Sage attention.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for Sage attention.")

    sage_result = sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )
    if not return_lse:
        return sage_result

    out, lse, *_ = sage_result
    return out, lse.permute(0, 2, 1)
|
||||
|
||||
|
||||
def _sage_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -2205,59 +1932,6 @@ def _templated_context_parallel_attention(
|
||||
# ===== Attention backends =====
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=True,
)
def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """flash-attn 2 backend.

    Without a parallel config this calls `flash_attn_func` directly. With one,
    it goes through `_templated_context_parallel_attention` using the flash-attn
    forward and backward ops.

    Returns:
        `(out, lse)` when `return_lse` is true, otherwise `out`.

    Raises:
        ValueError: if `attn_mask` is given.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 2.")

    if _parallel_config is not None:
        result = _templated_context_parallel_attention(
            query,
            key,
            value,
            None,
            dropout_p,
            is_causal,
            scale,
            False,
            return_lse,
            forward_op=_flash_attention_forward_op,
            backward_op=_flash_attention_backward_op,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse = result
            return out, lse
        return result

    result = flash_attn_func(
        q=query,
        k=key,
        v=value,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        return_attn_probs=return_lse,
    )
    if return_lse:
        # Anything returned after the first two values is dropped.
        out, lse, *_ = result
        return out, lse
    return result
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -2369,88 +2043,6 @@ def _flash_varlen_attention_hub(
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_VARLEN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """flash-attn 2 variable-length backend.

    Packs the query along the batch dimension and keeps only the first
    `seqlens_k[b]` key/value positions of each batch element before calling
    `flash_attn_varlen_func`. `_parallel_config` is not used.

    Returns:
        `out` with the batch dimension restored, or `(out, lse)` when
        `return_lse` is true. `lse` is passed through exactly as the kernel
        returns it (packed layout; NOTE(review): confirm consumers expect that).
    """
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    )

    # Drop the trailing key/value positions of each batch element.
    key_valid = [key[b, : seqlens_k[b]] for b in range(batch_size)]
    value_valid = [value[b, : seqlens_k[b]] for b in range(batch_size)]

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat(key_valid, dim=0)
    value_packed = torch.cat(value_valid, dim=0)

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        return_attn_probs=return_lse,
    )

    # With `return_attn_probs=True` the kernel returns a tuple, so it must be
    # unpacked before `unflatten`; calling `unflatten` on the tuple would raise
    # AttributeError.
    lse = None
    if return_lse:
        out, lse, *_ = out
    out = out.unflatten(0, (batch_size, -1))

    return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """flash-attn 3 backend, routed through `_wrapped_flash_attn_3`.

    `_parallel_config` is not used.

    Returns:
        `(out, lse)` when `return_lse` is true, otherwise `out`.

    Raises:
        ValueError: if `attn_mask` is given.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 3.")

    # The wrapped op always produces the LSE; it is discarded when not requested.
    attn_out, softmax_lse = _wrapped_flash_attn_3(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
    )
    if return_lse:
        return attn_out, softmax_lse
    return attn_out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -2587,58 +2179,6 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_VARLEN_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """flash-attn 3 variable-length backend.

    Packs the query along the batch dimension and keeps only the first
    `seqlens_k[b]` key/value positions of each batch element before calling
    `flash_attn_3_varlen_func`. `_parallel_config` is not used.

    Returns:
        `(out, lse)` when `return_lse` is true, otherwise `out`; `out` has its
        batch dimension restored, `lse` is returned as the kernel produced it.
    """
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    )

    # Per-batch key/value slices limited to their valid length, then packed.
    key_packed = torch.cat([key[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)
    value_packed = torch.cat([value[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)
    query_packed = query.flatten(0, 1)

    packed_out, lse, *_ = flash_attn_3_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=scale,
        causal=is_causal,
    )
    out = packed_out.unflatten(0, (batch_size, -1))

    if return_lse:
        return out, lse
    return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.AITER,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -3108,57 +2648,6 @@ def _native_xla_attention(
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=True,
)
def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """Sage attention backend.

    Without a parallel config this calls `sageattn` directly with the "NHD"
    layout. With one, it goes through `_templated_context_parallel_attention`
    using the Sage forward and backward ops and a dropout of 0.0.

    Returns:
        `(out, lse)` when `return_lse` is true, otherwise `out`.

    Raises:
        ValueError: if `attn_mask` is given.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")

    if _parallel_config is not None:
        result = _templated_context_parallel_attention(
            query,
            key,
            value,
            None,
            0.0,
            is_causal,
            scale,
            False,
            return_lse,
            forward_op=_sage_attention_forward_op,
            backward_op=_sage_attention_backward_op,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse = result
            return out, lse
        return result

    result = sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )
    if return_lse:
        # Anything returned after the first two values is dropped.
        out, lse, *_ = result
        return out, lse
    return result
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -3211,169 +2700,6 @@ def _sage_attention_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE_VARLEN,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """Sage attention variable-length backend.

    Packs the query along the batch dimension and keeps only the first
    `seqlens_k[b]` key/value positions of each batch element before calling
    `sageattn_varlen`. `_parallel_config` is not used.

    Returns:
        The attention output with its batch dimension restored.

    Raises:
        ValueError: if `return_lse` is true.
    """
    if return_lse:
        raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    )

    # Per-batch key/value slices limited to their valid length, then packed.
    key_packed = torch.cat([key[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)
    value_packed = torch.cat([value[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)
    query_packed = query.flatten(0, 1)

    packed_out = sageattn_varlen(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
    )
    return packed_out.unflatten(0, (batch_size, -1))
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """Sage attention backend calling `sageattn_qk_int8_pv_fp8_cuda`.

    Uses the "NHD" tensor layout and returns the kernel result unchanged.
    `_parallel_config` is not used.

    Raises:
        ValueError: if `attn_mask` is given.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")

    kernel_options = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "sm_scale": scale,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda(q=query, k=key, v=value, **kernel_options)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """Sage attention backend calling `sageattn_qk_int8_pv_fp8_cuda_sm90`.

    Uses the "NHD" tensor layout and returns the kernel result unchanged.
    `_parallel_config` is not used.

    Raises:
        ValueError: if `attn_mask` is given.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")

    kernel_options = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "sm_scale": scale,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp8_cuda_sm90(q=query, k=key, v=value, **kernel_options)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """Sage attention backend calling `sageattn_qk_int8_pv_fp16_cuda`.

    Uses the "NHD" tensor layout and returns the kernel result unchanged.
    `_parallel_config` is not used.

    Raises:
        ValueError: if `attn_mask` is given.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")

    kernel_options = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "sm_scale": scale,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_cuda(q=query, k=key, v=value, **kernel_options)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    """Sage attention backend calling `sageattn_qk_int8_pv_fp16_triton`.

    Uses the "NHD" tensor layout and returns the kernel result unchanged.
    `_parallel_config` is not used.

    Raises:
        ValueError: if `attn_mask` is given.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")

    kernel_options = {
        "tensor_layout": "NHD",
        "is_causal": is_causal,
        "sm_scale": scale,
        "return_lse": return_lse,
    }
    return sageattn_qk_int8_pv_fp16_triton(q=query, k=key, v=value, **kernel_options)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.XFORMERS,
|
||||
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from .attention import AttentionBackendTesterMixin, AttentionTesterMixin
|
||||
from .attention import AttentionTesterMixin
|
||||
from .cache import (
|
||||
CacheTesterMixin,
|
||||
FasterCacheConfigMixin,
|
||||
@@ -38,7 +38,6 @@ from .training import TrainingTesterMixin
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AttentionBackendTesterMixin",
|
||||
"AttentionTesterMixin",
|
||||
"BaseModelTesterConfig",
|
||||
"BitsAndBytesCompileTesterMixin",
|
||||
|
||||
@@ -14,105 +14,22 @@
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from diffusers.models.attention import AttentionModuleMixin
|
||||
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry, attention_backend
|
||||
from diffusers.models.attention_processor import AttnProcessor
|
||||
from diffusers.utils import is_kernels_available, is_torch_version
|
||||
|
||||
from ...testing_utils import assert_tensors_close, backend_empty_cache, is_attention, torch_device
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Module-level backend parameter sets for AttentionBackendTesterMixin
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_CUDA_AVAILABLE = torch.cuda.is_available()
|
||||
_KERNELS_AVAILABLE = is_kernels_available()
|
||||
|
||||
_PARAM_NATIVE = pytest.param(AttentionBackendName.NATIVE, id="native")
|
||||
|
||||
_PARAM_NATIVE_CUDNN = pytest.param(
|
||||
AttentionBackendName._NATIVE_CUDNN,
|
||||
id="native_cudnn",
|
||||
marks=pytest.mark.skipif(
|
||||
not _CUDA_AVAILABLE,
|
||||
reason="CUDA is required for _native_cudnn backend.",
|
||||
),
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor,
|
||||
)
|
||||
|
||||
_PARAM_FLASH_HUB = pytest.param(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
id="flash_hub",
|
||||
marks=[
|
||||
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for flash_hub backend."),
|
||||
pytest.mark.skipif(
|
||||
not _KERNELS_AVAILABLE,
|
||||
reason="`kernels` package is required for flash_hub backend. Install with `pip install kernels`.",
|
||||
),
|
||||
],
|
||||
from ...testing_utils import (
|
||||
assert_tensors_close,
|
||||
backend_empty_cache,
|
||||
is_attention,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
_PARAM_FLASH_3_HUB = pytest.param(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
id="flash_3_hub",
|
||||
marks=[
|
||||
pytest.mark.skipif(not _CUDA_AVAILABLE, reason="CUDA is required for _flash_3_hub backend."),
|
||||
pytest.mark.skipif(
|
||||
not _KERNELS_AVAILABLE,
|
||||
reason="`kernels` package is required for _flash_3_hub backend. Install with `pip install kernels`.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# All backends under test.
|
||||
_ALL_BACKEND_PARAMS = [_PARAM_NATIVE, _PARAM_NATIVE_CUDNN, _PARAM_FLASH_HUB, _PARAM_FLASH_3_HUB]
|
||||
|
||||
# Backends that only accept bf16/fp16 inputs; models and inputs must be cast before running them.
|
||||
_BF16_REQUIRED_BACKENDS = {
|
||||
AttentionBackendName._NATIVE_CUDNN,
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
}
|
||||
|
||||
# Backends that perform non-deterministic operations and therefore cannot run when
|
||||
# torch.use_deterministic_algorithms(True) is active (e.g. after enable_full_determinism()).
|
||||
_NON_DETERMINISTIC_BACKENDS = {AttentionBackendName._NATIVE_CUDNN}
|
||||
|
||||
|
||||
def _maybe_cast_to_bf16(backend, model, inputs_dict):
    """Return ``(model, inputs_dict)``, cast to bfloat16 when ``backend`` requires it.

    For backends listed in ``_BF16_REQUIRED_BACKENDS`` the model is moved to
    ``torch.bfloat16`` and a new inputs dict is built in which every
    floating-point tensor is cast to the same dtype; all other values are kept
    as they are. For any other backend both arguments are returned untouched.
    """
    if backend in _BF16_REQUIRED_BACKENDS:
        model = model.to(dtype=torch.bfloat16)

        bf16_inputs = {}
        for name, value in inputs_dict.items():
            # Integer/bool tensors and non-tensor values must keep their type.
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                value = value.to(dtype=torch.bfloat16)
            bf16_inputs[name] = value
        inputs_dict = bf16_inputs

    return model, inputs_dict
|
||||
|
||||
|
||||
def _skip_if_backend_requires_nondeterminism(backend):
    """Skip the running test when deterministic mode blocks ``backend``.

    The check is made while the test executes instead of through a
    collection-time ``skipif`` mark: test files typically call
    ``enable_full_determinism()`` at module level *after* the module-level
    ``pytest.param()`` objects in this file have already been evaluated, so a
    collection-time condition could not observe it.
    """
    if backend not in _NON_DETERMINISTIC_BACKENDS:
        return
    if not torch.are_deterministic_algorithms_enabled():
        return

    reason = (
        f"Backend '{backend.value}' performs non-deterministic operations and cannot run "
        f"while `torch.use_deterministic_algorithms(True)` is active."
    )
    pytest.skip(reason)
|
||||
|
||||
|
||||
@is_attention
|
||||
class AttentionTesterMixin:
|
||||
@@ -122,6 +39,7 @@ class AttentionTesterMixin:
|
||||
Tests functionality from AttentionModuleMixin including:
|
||||
- Attention processor management (set/get)
|
||||
- QKV projection fusion/unfusion
|
||||
- Attention backends (XFormers, NPU, etc.)
|
||||
|
||||
Expected from config mixin:
|
||||
- model_class: The model class to test
|
||||
@@ -261,208 +179,3 @@ class AttentionTesterMixin:
|
||||
model.set_attn_processor(wrong_processors)
|
||||
|
||||
assert "number of processors" in str(exc_info.value).lower(), "Error should mention processor count mismatch"
|
||||
|
||||
|
||||
@is_attention
class AttentionBackendTesterMixin:
    """
    Mixin that exercises attention backends on the model supplied by the host test class.

    What is checked:

    1. A backend can be selected through the `attention_backend` context manager and through the
       model's `set_attention_backend()` method, and both give the same output.
    2. Backend outputs do not deviate too much from native SDPA outputs.
    3. The backend works with (regional) compilation.
    4. The previously active backend is restored afterwards.

    Every test is parametrized over `_ALL_BACKEND_PARAMS`.

    Expected from the host test class:
        - model_class: The model class to instantiate.

    Expected methods from the host test class:
        - get_init_dict(): Returns dict of kwargs to construct the model.
        - get_dummy_inputs(): Returns dict of inputs for the model's forward pass.

    Pytest mark: attention
        Use `pytest -m "not attention"` to skip these tests.
    """

    # Tolerances; override in the host class to loosen or tighten the checks.

    # Used by test_output_close_to_native: alternate backends (flash, cuDNN) may accumulate small
    # numerical errors relative to the reference PyTorch SDPA kernel.
    backend_vs_native_atol: float = 1e-2
    backend_vs_native_rtol: float = 1e-2

    # Used by test_compile: regional compilation introduces the same kind of numerical error as the
    # non-compiled backend path, so the same loose tolerance applies.
    compile_vs_native_atol: float = 1e-2
    compile_vs_native_rtol: float = 1e-2

    def setup_method(self):
        """Collect garbage and empty the device cache before each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        """Collect garbage and empty the device cache after each test."""
        gc.collect()
        backend_empty_cache(torch_device)

    def _build_model_and_inputs(self, backend):
        """Create the host model in eval mode on `torch_device` plus dummy inputs, cast for `backend` if needed."""
        init_dict = self.get_init_dict()
        inputs_dict = self.get_dummy_inputs()
        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()
        return _maybe_cast_to_bf16(backend, model, inputs_dict)

    @staticmethod
    def _set_backend_or_skip(model, backend):
        """Select `backend` on `model`, skipping the test if that fails.

        Returns the registry backend that was active before the switch so the caller can restore it.
        """
        previous_backend, _ = _AttentionBackendRegistry.get_active_backend()
        try:
            model.set_attention_backend(backend.value)
        except Exception as e:
            logger.warning("Skipping test for backend '%s': %s", backend.value, e)
            pytest.skip(str(e))
        return previous_backend

    @staticmethod
    def _restore_backend(model, previous_backend):
        """Reset the model's backend and put `previous_backend` back as the registry's active backend."""
        model.reset_attention_backend()
        _AttentionBackendRegistry.set_active_backend(previous_backend)

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_set_attention_backend_matches_context_manager(self, backend):
        """set_attention_backend() and the attention_backend() context manager must yield identical outputs."""
        _skip_if_backend_requires_nondeterminism(backend)
        model, inputs_dict = self._build_model_and_inputs(backend)

        with attention_backend(backend):
            ctx_output = model(**inputs_dict, return_dict=False)[0]

        previous_backend = self._set_backend_or_skip(model, backend)
        try:
            set_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            self._restore_backend(model, previous_backend)

        mismatch_message = (
            f"Output from model.set_attention_backend('{backend.value}') should be identical "
            f"to the output from `with attention_backend('{backend.value}'):`."
        )
        # Both paths select the same backend, so the comparison is exact.
        assert_tensors_close(set_output, ctx_output, atol=0, rtol=0, msg=mismatch_message)

    @torch.no_grad()
    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_output_close_to_native(self, backend):
        """All backends should produce model output numerically close to the native SDPA reference."""
        _skip_if_backend_requires_nondeterminism(backend)
        model, inputs_dict = self._build_model_and_inputs(backend)

        with attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        previous_backend = self._set_backend_or_skip(model, backend)
        try:
            backend_output = model(**inputs_dict, return_dict=False)[0]
        finally:
            self._restore_backend(model, previous_backend)

        assert_tensors_close(
            backend_output,
            native_output,
            atol=self.backend_vs_native_atol,
            rtol=self.backend_vs_native_rtol,
            msg=f"Output from {backend} should be numerically close to native SDPA.",
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_context_manager_switches_and_restores_backend(self, backend):
        """attention_backend() should activate the requested backend and restore the previous one on exit."""
        initial_backend, _ = _AttentionBackendRegistry.get_active_backend()

        with attention_backend(backend):
            active_backend, _ = _AttentionBackendRegistry.get_active_backend()
            assert active_backend == backend, (
                f"Backend should be {backend} inside the context manager, got {active_backend}."
            )

        restored_backend, _ = _AttentionBackendRegistry.get_active_backend()
        assert restored_backend == initial_backend, (
            f"Backend should be restored to {initial_backend} after exiting the context manager, "
            f"got {restored_backend}."
        )

    @pytest.mark.parametrize("backend", _ALL_BACKEND_PARAMS)
    def test_compile(self, backend):
        """
        `torch.compile` tests checking for recompilation, graph breaks, forward can run, etc.
        For speed, we use regional compilation here (`model.compile_repeated_blocks()`
        as opposed to `model.compile`).
        """
        _skip_if_backend_requires_nondeterminism(backend)
        if getattr(self.model_class, "_repeated_blocks", None) is None:
            pytest.skip("Skipping tests as regional compilation is not supported.")

        if backend == AttentionBackendName.NATIVE and not is_torch_version(">=", "2.9.0"):
            pytest.xfail(
                "test_compile with the native backend requires torch >= 2.9.0 for stable "
                "fullgraph compilation with error_on_recompile=True."
            )

        model, inputs_dict = self._build_model_and_inputs(backend)

        # Eager reference output, computed before the requested backend is selected.
        with torch.no_grad(), attention_backend(AttentionBackendName.NATIVE):
            native_output = model(**inputs_dict, return_dict=False)[0]

        previous_backend = self._set_backend_or_skip(model, backend)
        try:
            model.compile_repeated_blocks(fullgraph=True)
            torch.compiler.reset()

            with (
                torch._inductor.utils.fresh_inductor_cache(),
                torch._dynamo.config.patch(error_on_recompile=True),
                torch.no_grad(),
            ):
                compile_output = model(**inputs_dict, return_dict=False)[0]
                # Repeat the forward pass while `error_on_recompile` is patched on.
                model(**inputs_dict, return_dict=False)
        finally:
            self._restore_backend(model, previous_backend)

        assert_tensors_close(
            compile_output,
            native_output,
            atol=self.compile_vs_native_atol,
            rtol=self.compile_vs_native_rtol,
            msg=f"Compiled output with backend '{backend.value}' should be numerically close to eager native SDPA.",
        )
|
||||
|
||||
@@ -25,7 +25,6 @@ from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
AttentionBackendTesterMixin,
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesCompileTesterMixin,
|
||||
@@ -225,10 +224,6 @@ class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterM
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerAttentionBackend(FluxTransformerTesterConfig, AttentionBackendTesterMixin):
|
||||
"""Attention backend tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerContextParallel(FluxTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux Transformer"""
|
||||
|
||||
|
||||
163
tests/others/test_attention_backends.py
Normal file
163
tests/others/test_attention_backends.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
This test suite currently exists for maintainers only; it is not run in our CI at the moment.
|
||||
|
||||
Once attention backends become more mature, we can consider including this in our CI.
|
||||
|
||||
To run this test suite:
|
||||
|
||||
```bash
|
||||
export RUN_ATTENTION_BACKEND_TESTS=yes
|
||||
|
||||
pytest tests/others/test_attention_backends.py
|
||||
```
|
||||
|
||||
Tests were conducted on an H100 with PyTorch 2.8.0 (CUDA 12.9). Slices for the compilation tests in
|
||||
"native" variants were obtained with a torch nightly version (2.10.0.dev20250924+cu128).
|
||||
|
||||
Tests for aiter backend were conducted and slices for the aiter backend tests collected on a MI355X
|
||||
with torch 2025-09-25 nightly version (ad2f7315ca66b42497047bb7951f696b50f1e81b) and
|
||||
aiter 0.1.5.post4.dev20+ga25e55e79.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
|
||||
# Opt-in gate for the whole module. The suite stays skipped unless RUN_ATTENTION_BACKEND_TESTS is set to
# something other than a recognised "off" value. Matching ignores case and surrounding whitespace, so
# `False`, `0`, `no`, `off` or an empty value keep the suite skipped just like the default `false`;
# any other value (e.g. `yes`) enables it.
_DISABLED_ENV_VALUES = {"", "0", "false", "no", "off"}

pytestmark = pytest.mark.skipif(
    os.getenv("RUN_ATTENTION_BACKEND_TESTS", "false").strip().lower() in _DISABLED_ENV_VALUES,
    reason="Feature not mature enough.",
)
|
||||
from diffusers import FluxPipeline # noqa: E402
|
||||
from diffusers.utils import is_torch_version # noqa: E402
|
||||
|
||||
|
||||
# fmt: off
# Expected output slices (first 8 + last 8 flattened values) per attention backend, eager execution.
_FORWARD_SLICES = {
    "flash_hub": [0.0820, 0.0859, 0.0918, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2461, 0.2422, 0.2695],
    "_flash_3_hub": [0.0820, 0.0859, 0.0938, 0.1016, 0.0977, 0.0996, 0.1016, 0.1016, 0.2188, 0.2246, 0.2344, 0.2480, 0.2539, 0.2480, 0.2441, 0.2715],
    "native": [0.0820, 0.0859, 0.0938, 0.1016, 0.0957, 0.0996, 0.0996, 0.1016, 0.2188, 0.2266, 0.2363, 0.2500, 0.2539, 0.2480, 0.2461, 0.2734],
    "_native_cudnn": [0.0781, 0.0840, 0.0879, 0.0957, 0.0898, 0.0957, 0.0957, 0.0977, 0.2168, 0.2246, 0.2324, 0.2500, 0.2539, 0.2480, 0.2441, 0.2695],
    "aiter": [0.0781, 0.0820, 0.0879, 0.0957, 0.0898, 0.0938, 0.0957, 0.0957, 0.2285, 0.2363, 0.2461, 0.2637, 0.2695, 0.2617, 0.2617, 0.2891],
}

# Expected output slices per attention backend when the transformer is compiled.
_COMPILE_SLICES = {
    "flash_hub": [0.0410, 0.0410, 0.0449, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2324, 0.2422, 0.2539, 0.2734, 0.2832, 0.2812, 0.2773, 0.3047],
    "_flash_3_hub": [0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0625, 0.0605, 0.2344, 0.2461, 0.2578, 0.2734, 0.2852, 0.2812, 0.2773, 0.3047],
    "native": [0.0410, 0.0410, 0.0449, 0.0508, 0.0508, 0.0605, 0.0605, 0.0605, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2773, 0.3066],
    "_native_cudnn": [0.0410, 0.0410, 0.0430, 0.0508, 0.0488, 0.0586, 0.0605, 0.0586, 0.2344, 0.2461, 0.2578, 0.2773, 0.2871, 0.2832, 0.2793, 0.3086],
    "aiter": [0.0391, 0.0391, 0.0430, 0.0488, 0.0469, 0.0566, 0.0586, 0.0566, 0.2402, 0.2539, 0.2637, 0.2812, 0.2930, 0.2910, 0.2891, 0.3164],
}

# (backend_name, expected_slice) tuples, in the order the backends are listed above.
FORWARD_CASES = [
    (backend_name, torch.tensor(values, dtype=torch.bfloat16))
    for backend_name, values in _FORWARD_SLICES.items()
]

# (backend_name, expected_slice, error_on_recompile) tuples; every case runs with error_on_recompile=True.
COMPILE_CASES = [
    (backend_name, torch.tensor(values, dtype=torch.bfloat16), True)
    for backend_name, values in _COMPILE_SLICES.items()
]
# fmt: on
|
||||
|
||||
# Keyword arguments passed to every pipeline call in this module (the generator is supplied per call).
INFER_KW = {
    "prompt": "dance doggo dance",
    "height": 256,
    "width": 256,
    "num_inference_steps": 2,
    "guidance_scale": 3.5,
    "max_sequence_length": 128,
    "output_type": "pt",
}
|
||||
|
||||
|
||||
def _backend_is_probably_supported(pipe, name: str):
    """Try to select attention backend ``name`` on ``pipe.transformer``.

    Returns the tuple ``(pipe, True)`` when ``set_attention_backend`` succeeds
    and the bare value ``False`` when it raises any ``Exception``. The two
    outcomes therefore have different types; check with
    ``isinstance(result, bool)`` before indexing the result.
    """
    try:
        pipe.transformer.set_attention_backend(name)
    except Exception:
        # Any failure is reported as "not supported" rather than propagated.
        return False
    else:
        return pipe, True
|
||||
|
||||
|
||||
def _check_if_slices_match(output, expected_slice):
    """Assert that the head and tail of ``output.images`` match ``expected_slice``.

    ``output.images`` is detached, moved to CPU and flattened; its first 8 and
    last 8 values are concatenated and compared with ``expected_slice`` via
    ``torch.allclose(..., atol=1e-4)``.

    Both sides are converted to float32 on CPU before comparing, so a dtype
    difference between the pipeline output and the reference slice is compared
    by value instead of making ``torch.allclose`` raise.

    Raises:
        AssertionError: if the slices are not close; the message contains both
            slices to make the mismatch easy to diagnose.
    """
    flat_images = output.images.detach().cpu().flatten()
    generated_slice = torch.cat([flat_images[:8], flat_images[-8:]]).float()
    reference_slice = expected_slice.detach().cpu().float()

    assert torch.allclose(generated_slice, reference_slice, atol=1e-4), (
        f"Generated slice {generated_slice.tolist()} does not match "
        f"expected slice {reference_slice.tolist()} (atol=1e-4)."
    )
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def device():
    """Session-scoped fixture returning the ``cuda:0`` device; skips when CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    pytest.skip("CUDA is required for these tests.")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def pipe(device):
    """Session-scoped fixture: the FLUX.1-dev pipeline in bfloat16 on ``device``, progress bar disabled."""
    flux_pipeline = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    flux_pipeline = flux_pipeline.to(device)
    flux_pipeline.set_progress_bar_config(disable=True)
    return flux_pipeline
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "backend_name,expected_slice",
    FORWARD_CASES,
    ids=[name for name, _ in FORWARD_CASES],
)
def test_forward(pipe, backend_name, expected_slice):
    """Run the pipeline with ``backend_name`` selected and compare its output slice with ``expected_slice``."""
    support = _backend_is_probably_supported(pipe, backend_name)
    # A bare bool means the backend could not be selected on this machine.
    if isinstance(support, bool):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    modified_pipe = support[0]
    result = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))
    _check_if_slices_match(result, expected_slice)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "backend_name,expected_slice,error_on_recompile",
    COMPILE_CASES,
    ids=[name for name, *_ in COMPILE_CASES],
)
def test_forward_with_compile(pipe, backend_name, expected_slice, error_on_recompile):
    """Compile the transformer with ``backend_name`` selected and compare the output slice with ``expected_slice``."""
    needs_newer_torch = "native" in backend_name and error_on_recompile and not is_torch_version(">=", "2.9.0")
    if needs_newer_torch:
        pytest.xfail(f"Test with {backend_name=} is compatible with a higher version of torch.")

    support = _backend_is_probably_supported(pipe, backend_name)
    # A bare bool means the backend could not be selected on this machine.
    if isinstance(support, bool):
        pytest.xfail(f"Backend '{backend_name}' not supported in this environment.")

    modified_pipe = support[0]
    modified_pipe.transformer.compile(fullgraph=True)
    torch.compiler.reset()

    with (
        torch._inductor.utils.fresh_inductor_cache(),
        torch._dynamo.config.patch(error_on_recompile=error_on_recompile),
    ):
        result = modified_pipe(**INFER_KW, generator=torch.manual_seed(0))

    _check_if_slices_match(result, expected_slice)
|
||||
Reference in New Issue
Block a user