@@ -34,13 +34,7 @@ from ..utils import (
    get_logger,
    is_aiter_available,
    is_aiter_version,
    is_flash_attn_3_available,
    is_flash_attn_available,
    is_flash_attn_version,
    is_kernels_available,
    is_kernels_version,
    is_sageattention_available,
    is_sageattention_version,
    is_torch_npu_available,
    is_torch_version,
    is_torch_xla_available,
@@ -56,62 +50,23 @@ from ._modeling_parallel import gather_size_by_comm
if TYPE_CHECKING:
    from ._modeling_parallel import ParallelConfig

_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_AITER_VERSION = "0.1.5"
_REQUIRED_SAGE_VERSION = "2.1.1"
_REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"

_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
_CAN_USE_NPU_ATTN = is_torch_npu_available()
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)


if _CAN_USE_FLASH_ATTN:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
    flash_attn_func = None
    flash_attn_varlen_func = None
    _wrapped_flash_attn_backward = None
    _wrapped_flash_attn_forward = None


if _CAN_USE_FLASH_ATTN_3:
    from flash_attn_interface import flash_attn_func as flash_attn_3_func
    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
    flash_attn_3_func = None
    flash_attn_3_varlen_func = None

if _CAN_USE_AITER_ATTN:
    from aiter import flash_attn_func as aiter_flash_attn_func
else:
    aiter_flash_attn_func = None

if _CAN_USE_SAGE_ATTN:
    from sageattention import (
        sageattn,
        sageattn_qk_int8_pv_fp8_cuda,
        sageattn_qk_int8_pv_fp8_cuda_sm90,
        sageattn_qk_int8_pv_fp16_cuda,
        sageattn_qk_int8_pv_fp16_triton,
        sageattn_varlen,
    )
else:
    sageattn = None
    sageattn_qk_int8_pv_fp16_cuda = None
    sageattn_qk_int8_pv_fp16_triton = None
    sageattn_qk_int8_pv_fp8_cuda = None
    sageattn_qk_int8_pv_fp8_cuda_sm90 = None
    sageattn_varlen = None


if _CAN_USE_FLEX_ATTN:
    # We cannot import the flex_attention function from the package directly because it is expected (from the
@@ -137,27 +92,6 @@ if _CAN_USE_XFORMERS_ATTN:
else:
    xops = None

# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
if torch.__version__ >= "2.4.0":
    _custom_op = torch.library.custom_op
    _register_fake = torch.library.register_fake
else:

    def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
        def wrap(func):
            return func

        return wrap if fn is None else fn

    _custom_op = custom_op_no_op
    _register_fake = register_fake_no_op
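
# Note (minimal sketch, not exercised by this diff): on torch < 2.4 the no-op shims above turn
# `@_custom_op(...)` and `@_register_fake(...)` into plain pass-through decorators, so the op
# registrations further below still import cleanly; only on torch >= 2.4 do they actually register
# a custom op with torch.library. The behaviour being relied on is roughly:
#
#     @_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
#     def fn(...): ...
#     # torch >= 2.4: `fn` becomes a registered custom op; torch < 2.4: `fn` is returned unchanged.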


logger = get_logger(__name__)  # pylint: disable=invalid-name

@@ -266,7 +200,6 @@ class _HubKernelConfig:
    repo_id: str
    function_attr: str
    revision: str | None = None
    version: int | None = None
    kernel_fn: Callable | None = None
    wrapped_forward_attr: str | None = None
    wrapped_backward_attr: str | None = None
@@ -276,31 +209,27 @@ class _HubKernelConfig:

# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
    # TODO: temporary revision for now. Remove when merged upstream into `main`.
    AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
    ),
    AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn3",
        function_attr="flash_attn_varlen_func",
        version=1,
        # revision="fake-ops-return-probs",
    ),
    AttentionBackendName.FLASH_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn2",
        function_attr="flash_attn_func",
        version=1,
        revision=None,
        wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
        wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
    ),
    AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
        repo_id="kernels-community/flash-attn2",
        function_attr="flash_attn_varlen_func",
        version=1,
        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
    ),
    AttentionBackendName.SAGE_HUB: _HubKernelConfig(
        repo_id="kernels-community/sage-attention",
        function_attr="sageattn",
        version=1,
        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
    ),
}
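
# Rough sketch (an assumption, not shown in this diff) of how an entry above is meant to be consumed:
# `_maybe_download_kernel_for_backend` (further below) resolves the config through the `kernels`
# package, roughly along the lines of
#
#     from kernels import get_kernel
#     kernel_module = get_kernel(config.repo_id, revision=config.revision)  # newer `kernels` releases also accept a version pin
#     config.kernel_fn = getattr(kernel_module, config.function_attr)
#
# The exact pinning argument (`revision=` vs. `version=`) depends on the installed `kernels` version.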

@@ -310,11 +239,11 @@ def attention_backend(backend: str | AttentionBackendName = AttentionBackendName
    """
    Context manager to set the active attention backend.
    """
    if backend not in _AttentionBackendRegistry._backends:
        raise ValueError(f"Backend {backend} is not registered.")

    backend = AttentionBackendName(backend)
    _check_attention_backend_requirements(backend)

    if backend not in _AttentionBackendRegistry._backends:
        raise ValueError(f"Backend {backend} is not registered.")
    _maybe_download_kernel_for_backend(backend)

    old_backend = _AttentionBackendRegistry._active_backend
@@ -448,16 +377,32 @@ def _check_shape(

def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
    if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
        if not _CAN_USE_FLASH_ATTN:
            raise RuntimeError(
                f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
            )
        raise RuntimeError(
            f"The '{backend.value}' attention backend has been removed. "
            f"Please use 'flash_hub' or 'flash_varlen_hub' instead, which load the flash-attn kernel from the Hub. "
            f"Install the required package with `pip install kernels`."
        )

    elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
        if not _CAN_USE_FLASH_ATTN_3:
            raise RuntimeError(
                f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
            )
        raise RuntimeError(
            f"The '{backend.value}' attention backend has been removed. "
            f"Please use '_flash_3_hub' or '_flash_3_varlen_hub' instead, which load the flash-attn-3 kernel from the Hub. "
            f"Install the required package with `pip install kernels`."
        )

    elif backend in [
        AttentionBackendName.SAGE,
        AttentionBackendName.SAGE_VARLEN,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    ]:
        raise RuntimeError(
            f"The '{backend.value}' attention backend has been removed. "
            f"Please use 'sage_hub' instead, which loads the SageAttention kernel from the Hub. "
            f"Install the required package with `pip install kernels`."
        )

    elif backend in [
        AttentionBackendName.FLASH_HUB,
@@ -470,10 +415,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
            raise RuntimeError(
                f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
            )
        if not is_kernels_version(">=", "0.12"):
            raise RuntimeError(
                f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
            )

    elif backend == AttentionBackendName.AITER:
        if not _CAN_USE_AITER_ATTN:
@@ -481,19 +422,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
            )

    elif backend in [
        AttentionBackendName.SAGE,
        AttentionBackendName.SAGE_VARLEN,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
        AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    ]:
        if not _CAN_USE_SAGE_ATTN:
            raise RuntimeError(
                f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
            )

    elif backend == AttentionBackendName.FLEX:
        if not _CAN_USE_FLEX_ATTN:
            raise RuntimeError(
@@ -662,78 +590,6 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
        raise


# ===== torch op registrations =====
# Registrations are required for fullgraph tracing compatibility
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
def _wrapped_flash_attn_3(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    qv: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    # Hardcoded for now because pytorch does not support tuple/int type hints
    window_size = (-1, -1)
    out, lse, *_ = flash_attn_3_func(
        q=q,
        k=k,
        v=v,
        softmax_scale=softmax_scale,
        causal=causal,
        qv=qv,
        q_descale=q_descale,
        k_descale=k_descale,
        v_descale=v_descale,
        window_size=window_size,
        attention_chunk=attention_chunk,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        deterministic=deterministic,
        sm_margin=sm_margin,
    )
    lse = lse.permute(0, 2, 1)
    return out, lse


@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
def _(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale: float | None = None,
    causal: bool = False,
    qv: torch.Tensor | None = None,
    q_descale: torch.Tensor | None = None,
    k_descale: torch.Tensor | None = None,
    v_descale: torch.Tensor | None = None,
    attention_chunk: int = 0,
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
) -> tuple[torch.Tensor, torch.Tensor]:
    window_size = (-1, -1)  # noqa: F841
    # A lot of the parameters here are not yet used in any way within diffusers.
    # We can safely ignore for now and keep the fake op shape propagation simple.
    batch_size, seq_len, num_heads, head_dim = q.shape
    lse_shape = (batch_size, seq_len, num_heads)
    return torch.empty_like(q), q.new_empty(lse_shape)
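
# Shape note (illustrative): under fullgraph tracing only the fake implementation above runs, so it
# only has to report output metadata. For q of shape (batch, seq_len, heads, head_dim) it yields an
# empty tensor of the same shape plus an lse placeholder of shape (batch, seq_len, heads), matching
# what the real `_wrapped_flash_attn_3` returns after the `lse.permute(0, 2, 1)` above.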


# ===== Helper functions to use attention backends with templated CP autograd functions =====


@@ -1005,107 +861,6 @@ def _native_flash_attention_backward_op(
    return grad_query, grad_key, grad_value


# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
def _flash_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")

    # Hardcoded for now
    window_size = (-1, -1)
    softcap = 0.0
    alibi_slopes = None
    deterministic = False
    grad_enabled = any(x.requires_grad for x in (query, key, value))

    if scale is None:
        scale = query.shape[-1] ** (-0.5)

    # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
    if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
        dropout_p = dropout_p if dropout_p > 0 else 1e-30

    with torch.set_grad_enabled(grad_enabled):
        out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
            query,
            key,
            value,
            dropout_p,
            scale,
            is_causal,
            window_size[0],
            window_size[1],
            softcap,
            alibi_slopes,
            return_lse,
        )
        lse = lse.permute(0, 2, 1)

    if _save_ctx:
        ctx.save_for_backward(query, key, value, out, lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.scale = scale
        ctx.is_causal = is_causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

    return (out, lse) if return_lse else out


def _flash_attention_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    query, key, value, out, lse, rng_state = ctx.saved_tensors
    grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

    lse_d = _wrapped_flash_attn_backward(  # noqa: F841
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        grad_query,
        grad_key,
        grad_value,
        ctx.dropout_p,
        ctx.scale,
        ctx.is_causal,
        ctx.window_size[0],
        ctx.window_size[1],
        ctx.softcap,
        ctx.alibi_slopes,
        ctx.deterministic,
        rng_state,
    )

    # Head dimension may have been padded
    grad_query = grad_query[..., : grad_out.shape[-1]]
    grad_key = grad_key[..., : grad_out.shape[-1]]
    grad_value = grad_value[..., : grad_out.shape[-1]]

    return grad_query, grad_key, grad_value


def _flash_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
@@ -1337,44 +1092,6 @@ def _flash_attention_3_hub_backward_op(
    return grad_query, grad_key, grad_value


def _sage_attention_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for Sage attention.")
    if dropout_p > 0.0:
        raise ValueError("`dropout_p` is not yet supported for Sage attention.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for Sage attention.")

    out = sageattn(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )
    lse = None
    if return_lse:
        out, lse, *_ = out
        lse = lse.permute(0, 2, 1)

    return (out, lse) if return_lse else out


def _sage_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
@@ -2215,59 +1932,6 @@ def _templated_context_parallel_attention(
# ===== Attention backends =====


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=True,
)
def _flash_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    lse = None
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 2.")

    if _parallel_config is None:
        out = flash_attn_func(
            q=query,
            k=key,
            v=value,
            dropout_p=dropout_p,
            softmax_scale=scale,
            causal=is_causal,
            return_attn_probs=return_lse,
        )
        if return_lse:
            out, lse, *_ = out
    else:
        out = _templated_context_parallel_attention(
            query,
            key,
            value,
            None,
            dropout_p,
            is_causal,
            scale,
            False,
            return_lse,
            forward_op=_flash_attention_forward_op,
            backward_op=_flash_attention_backward_op,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse = out

    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_HUB,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -2379,88 +2043,6 @@ def _flash_varlen_attention_hub(
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_VARLEN,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    )

    key_valid, value_valid = [], []
    for b in range(batch_size):
        valid_len = seqlens_k[b]
        key_valid.append(key[b, :valid_len])
        value_valid.append(value[b, :valid_len])

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat(key_valid, dim=0)
    value_packed = torch.cat(value_valid, dim=0)

    out = flash_attn_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        dropout_p=dropout_p,
        softmax_scale=scale,
        causal=is_causal,
        return_attn_probs=return_lse,
    )
    out = out.unflatten(0, (batch_size, -1))

    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 3.")

    out, lse = _wrapped_flash_attn_3(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
    )
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_3_HUB,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -2597,58 +2179,6 @@ def _flash_attention_3_varlen_hub(
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_VARLEN_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_varlen_attention_3(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    )

    key_valid, value_valid = [], []
    for b in range(batch_size):
        valid_len = seqlens_k[b]
        key_valid.append(key[b, :valid_len])
        value_valid.append(value[b, :valid_len])

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat(key_valid, dim=0)
    value_packed = torch.cat(value_valid, dim=0)

    out, lse, *_ = flash_attn_3_varlen_func(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=scale,
        causal=is_causal,
    )
    out = out.unflatten(0, (batch_size, -1))

    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.AITER,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -3118,57 +2648,6 @@ def _native_xla_attention(
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=True,
)
def _sage_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    lse = None
    if _parallel_config is None:
        out = sageattn(
            q=query,
            k=key,
            v=value,
            tensor_layout="NHD",
            is_causal=is_causal,
            sm_scale=scale,
            return_lse=return_lse,
        )
        if return_lse:
            out, lse, *_ = out
    else:
        out = _templated_context_parallel_attention(
            query,
            key,
            value,
            None,
            0.0,
            is_causal,
            scale,
            False,
            return_lse,
            forward_op=_sage_attention_forward_op,
            backward_op=_sage_attention_backward_op,
            _parallel_config=_parallel_config,
        )
        if return_lse:
            out, lse = out

    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE_HUB,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
@@ -3221,169 +2700,6 @@ def _sage_attention_hub(
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.SAGE_VARLEN,
    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _sage_varlen_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if return_lse:
        raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")

    batch_size, seq_len_q, _, _ = query.shape
    _, seq_len_kv, _, _ = key.shape

    if attn_mask is not None:
        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen(
            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
        )
    )

    key_valid, value_valid = [], []
    for b in range(batch_size):
        valid_len = seqlens_k[b]
        key_valid.append(key[b, :valid_len])
        value_valid.append(value[b, :valid_len])

    query_packed = query.flatten(0, 1)
    key_packed = torch.cat(key_valid, dim=0)
    value_packed = torch.cat(value_valid, dim=0)

    out = sageattn_varlen(
        q=query_packed,
        k=key_packed,
        v=value_packed,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        is_causal=is_causal,
        sm_scale=scale,
    )
    out = out.unflatten(0, (batch_size, -1))

    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp8_cuda(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
    constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
)
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp8_cuda_sm90(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_cuda_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp16_cuda(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
    constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
)
def _sage_qk_int8_pv_fp16_triton_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    is_causal: bool = False,
    scale: float | None = None,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for sage attention")
    return sageattn_qk_int8_pv_fp16_triton(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )


@_AttentionBackendRegistry.register(
    AttentionBackendName.XFORMERS,
    constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],