@@ -38,6 +38,7 @@ from ..utils import (
     is_flash_attn_available,
     is_flash_attn_version,
     is_kernels_available,
+    is_kernels_version,
     is_sageattention_available,
     is_sageattention_version,
     is_torch_npu_available,
@@ -62,6 +63,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
 _REQUIRED_XLA_VERSION = "2.2"
 _REQUIRED_XFORMERS_VERSION = "0.0.29"
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
 _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -73,8 +76,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
 
 
 if _CAN_USE_FLASH_ATTN:
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    try:
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    except (ImportError, OSError, RuntimeError) as e:
+        # Handle ABI mismatch or other import failures gracefully.
+        # This can happen when flash_attn was compiled against a different PyTorch version.
+        logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
+        _CAN_USE_FLASH_ATTN = False
+        flash_attn_func = None
+        flash_attn_varlen_func = None
+        _wrapped_flash_attn_backward = None
+        _wrapped_flash_attn_forward = None
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None
@@ -83,26 +96,47 @@ else:
 if _CAN_USE_FLASH_ATTN_3:
-    from flash_attn_interface import flash_attn_func as flash_attn_3_func
-    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    try:
+        from flash_attn_interface import flash_attn_func as flash_attn_3_func
+        from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_FLASH_ATTN_3 = False
+        flash_attn_3_func = None
+        flash_attn_3_varlen_func = None
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
 
 if _CAN_USE_AITER_ATTN:
-    from aiter import flash_attn_func as aiter_flash_attn_func
+    try:
+        from aiter import flash_attn_func as aiter_flash_attn_func
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_AITER_ATTN = False
+        aiter_flash_attn_func = None
 else:
     aiter_flash_attn_func = None
 
 
 if _CAN_USE_SAGE_ATTN:
-    from sageattention import (
-        sageattn,
-        sageattn_qk_int8_pv_fp8_cuda,
-        sageattn_qk_int8_pv_fp8_cuda_sm90,
-        sageattn_qk_int8_pv_fp16_cuda,
-        sageattn_qk_int8_pv_fp16_triton,
-        sageattn_varlen,
-    )
+    try:
+        from sageattention import (
+            sageattn,
+            sageattn_qk_int8_pv_fp8_cuda,
+            sageattn_qk_int8_pv_fp8_cuda_sm90,
+            sageattn_qk_int8_pv_fp16_cuda,
+            sageattn_qk_int8_pv_fp16_triton,
+            sageattn_varlen,
+        )
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_SAGE_ATTN = False
+        sageattn = None
+        sageattn_qk_int8_pv_fp8_cuda = None
+        sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+        sageattn_qk_int8_pv_fp16_cuda = None
+        sageattn_qk_int8_pv_fp16_triton = None
+        sageattn_varlen = None
 else:
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
@@ -113,26 +147,48 @@ else:
 
 if _CAN_USE_FLEX_ATTN:
-    # We cannot import the flex_attention function from the package directly because it is expected (from the
-    # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
-    # compiled function.
-    import torch.nn.attention.flex_attention as flex_attention
+    try:
+        # We cannot import the flex_attention function from the package directly because it is expected (from the
+        # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+        # compiled function.
+        import torch.nn.attention.flex_attention as flex_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_FLEX_ATTN = False
+        flex_attention = None
 else:
     flex_attention = None
 
 
 if _CAN_USE_NPU_ATTN:
-    from torch_npu import npu_fusion_attention
+    try:
+        from torch_npu import npu_fusion_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_NPU_ATTN = False
+        npu_fusion_attention = None
 else:
     npu_fusion_attention = None
 
 
 if _CAN_USE_XLA_ATTN:
-    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    try:
+        from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_XLA_ATTN = False
+        xla_flash_attention = None
 else:
     xla_flash_attention = None
 
 
 if _CAN_USE_XFORMERS_ATTN:
-    import xformers.ops as xops
+    try:
+        import xformers.ops as xops
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_XFORMERS_ATTN = False
+        xops = None
 else:
     xops = None
@@ -158,8 +214,6 @@ else:
     _register_fake = register_fake_no_op
 
 
-logger = get_logger(__name__)  # pylint: disable=invalid-name
-
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
@@ -265,6 +319,7 @@ class _HubKernelConfig:
     repo_id: str
     function_attr: str
     revision: str | None = None
+    version: int | None = None
    kernel_fn: Callable | None = None
     wrapped_forward_attr: str | None = None
     wrapped_backward_attr: str | None = None
@@ -274,27 +329,31 @@ class _HubKernelConfig:
 
 # Registry for hub-based attention kernels
 _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
-    # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
     ),
     AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3",
         function_attr="flash_attn_varlen_func",
-        # revision="fake-ops-return-probs",
+        version=1,
     ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2",
         function_attr="flash_attn_func",
+        version=1,
         revision=None,
         wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
         wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
     ),
     AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+        repo_id="kernels-community/flash-attn2",
+        function_attr="flash_attn_varlen_func",
+        version=1,
     ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
-        repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
+        repo_id="kernels-community/sage-attention",
+        function_attr="sageattn",
+        version=1,
     ),
 }
 
@@ -464,6 +523,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )
+        if not is_kernels_version(">=", "0.12"):
+            raise RuntimeError(
+                f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
+            )
 
     elif backend == AttentionBackendName.AITER:
         if not _CAN_USE_AITER_ATTN: