@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
 _REQUIRED_XLA_VERSION = "2.2"
 _REQUIRED_XFORMERS_VERSION = "0.0.29"
 
+logger = get_logger(__name__)  # pylint: disable=invalid-name
+
 _CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
 _CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
 _CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
 
 if _CAN_USE_FLASH_ATTN:
-    from flash_attn import flash_attn_func, flash_attn_varlen_func
-    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    try:
+        from flash_attn import flash_attn_func, flash_attn_varlen_func
+        from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+    except (ImportError, OSError, RuntimeError) as e:
+        # Handle ABI mismatch or other import failures gracefully.
+        # This can happen when flash_attn was compiled against a different PyTorch version.
+        logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
+        _CAN_USE_FLASH_ATTN = False
+        flash_attn_func = None
+        flash_attn_varlen_func = None
+        _wrapped_flash_attn_backward = None
+        _wrapped_flash_attn_forward = None
 else:
     flash_attn_func = None
     flash_attn_varlen_func = None
@@ -83,26 +95,47 @@ else:
 
 if _CAN_USE_FLASH_ATTN_3:
-    from flash_attn_interface import flash_attn_func as flash_attn_3_func
-    from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    try:
+        from flash_attn_interface import flash_attn_func as flash_attn_3_func
+        from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_FLASH_ATTN_3 = False
+        flash_attn_3_func = None
+        flash_attn_3_varlen_func = None
 else:
     flash_attn_3_func = None
     flash_attn_3_varlen_func = None
 
 if _CAN_USE_AITER_ATTN:
-    from aiter import flash_attn_func as aiter_flash_attn_func
+    try:
+        from aiter import flash_attn_func as aiter_flash_attn_func
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_AITER_ATTN = False
+        aiter_flash_attn_func = None
 else:
     aiter_flash_attn_func = None
 
 if _CAN_USE_SAGE_ATTN:
-    from sageattention import (
-        sageattn,
-        sageattn_qk_int8_pv_fp8_cuda,
-        sageattn_qk_int8_pv_fp8_cuda_sm90,
-        sageattn_qk_int8_pv_fp16_cuda,
-        sageattn_qk_int8_pv_fp16_triton,
-        sageattn_varlen,
-    )
+    try:
+        from sageattention import (
+            sageattn,
+            sageattn_qk_int8_pv_fp8_cuda,
+            sageattn_qk_int8_pv_fp8_cuda_sm90,
+            sageattn_qk_int8_pv_fp16_cuda,
+            sageattn_qk_int8_pv_fp16_triton,
+            sageattn_varlen,
+        )
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_SAGE_ATTN = False
+        sageattn = None
+        sageattn_qk_int8_pv_fp8_cuda = None
+        sageattn_qk_int8_pv_fp8_cuda_sm90 = None
+        sageattn_qk_int8_pv_fp16_cuda = None
+        sageattn_qk_int8_pv_fp16_triton = None
+        sageattn_varlen = None
 else:
     sageattn = None
     sageattn_qk_int8_pv_fp16_cuda = None
@@ -113,26 +146,48 @@ else:
 
 if _CAN_USE_FLEX_ATTN:
-    # We cannot import the flex_attention function from the package directly because it is expected (from the
-    # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
-    # compiled function.
-    import torch.nn.attention.flex_attention as flex_attention
+    try:
+        # We cannot import the flex_attention function from the package directly because it is expected (from the
+        # pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
+        # compiled function.
+        import torch.nn.attention.flex_attention as flex_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_FLEX_ATTN = False
+        flex_attention = None
 else:
     flex_attention = None
 
 if _CAN_USE_NPU_ATTN:
-    from torch_npu import npu_fusion_attention
+    try:
+        from torch_npu import npu_fusion_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_NPU_ATTN = False
+        npu_fusion_attention = None
 else:
     npu_fusion_attention = None
 
 if _CAN_USE_XLA_ATTN:
-    from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    try:
+        from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_XLA_ATTN = False
+        xla_flash_attention = None
 else:
     xla_flash_attention = None
 
 if _CAN_USE_XFORMERS_ATTN:
-    import xformers.ops as xops
+    try:
+        import xformers.ops as xops
+    except (ImportError, OSError, RuntimeError) as e:
+        logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
+        _CAN_USE_XFORMERS_ATTN = False
+        xops = None
 else:
     xops = None
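
The comment in the flex_attention hunk above explains why the module, rather than the function, is imported: callers that resolve the attribute at call time will see a torch.compile-wrapped kernel if the user installs one. A minimal sketch of that user-side pattern, assuming PyTorch 2.5+ and not taken from this patch:

    import torch
    import torch.nn.attention.flex_attention as flex_attention

    # Replace the module attribute with a compiled version. Code that looks up
    # flex_attention.flex_attention at call time now gets the compiled kernel,
    # whereas code that did `from ... import flex_attention` would keep the
    # original, uncompiled function bound at import time.
    flex_attention.flex_attention = torch.compile(flex_attention.flex_attention)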
@@ -158,8 +213,6 @@ else:
     _register_fake = register_fake_no_op
 
-
-logger = get_logger(__name__)  # pylint: disable=invalid-name
 
 # TODO(aryan): Add support for the following:
 # - Sage Attention++
 # - block sparse, radial and other attention methods
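
For context on how these guards are consumed, here is a minimal dispatch sketch, not taken from this patch; the helper name `attention_with_fallback`, the tensor layout, and the use of PyTorch's scaled_dot_product_attention as the fallback are assumptions:

    import torch
    import torch.nn.functional as F

    def attention_with_fallback(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        # q, k, v: (batch, seq_len, num_heads, head_dim), the layout flash_attn expects.
        if _CAN_USE_FLASH_ATTN and flash_attn_func is not None:
            return flash_attn_func(q, k, v)
        # Native fallback: SDPA expects (batch, num_heads, seq_len, head_dim).
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        return F.scaled_dot_product_attention(q, k, v).transpose(1, 2)

Because each except path above rebinds both the `_CAN_USE_*` flag and the imported symbols to None, a check like this stays cheap and never re-attempts a failed import.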