Compare commits

...

10 Commits

Author SHA1 Message Date
sayakpaul
6e9f81fa03 up 2025-08-27 13:36:57 +02:00
Sayak Paul
548f56e428 Merge branch 'main' into fa3-from-kernels 2025-08-27 10:08:50 +02:00
Sayak Paul
595ae6bda9 Merge branch 'main' into fa3-from-kernels 2025-08-26 17:01:52 +02:00
sayakpaul
4e69d42287 up 2025-08-26 16:54:48 +02:00
sayakpaul
2bb3796569 up 2025-08-26 12:07:46 +02:00
sayakpaul
87d08798de up 2025-08-26 12:02:49 +02:00
sayakpaul
bc40971210 change to Hub. 2025-08-26 11:33:43 +02:00
sayakpaul
ac43e8497f Merge branch 'main' into fa3-from-kernels 2025-08-26 10:18:26 +02:00
sayakpaul
a0177ebfec up 2025-08-25 18:53:02 +02:00
sayakpaul
827fc1599a feat: try loading fa3 using kernels when available. 2025-08-25 18:06:30 +02:00
3 changed files with 111 additions and 2 deletions

View File

@@ -26,6 +26,7 @@ from ..utils import (
is_flash_attn_3_available,
is_flash_attn_available,
is_flash_attn_version,
is_kernels_available,
is_sageattention_available,
is_sageattention_version,
is_torch_npu_available,
@@ -131,7 +132,6 @@ else:
_custom_op = custom_op_no_op
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): Add support for the following:
@@ -144,6 +144,9 @@ _SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
flash_attn_3_hub_func = None
__fa3_hub_loaded = False
class AttentionBackendName(str, Enum):
# EAGER = "eager"
@@ -153,6 +156,8 @@ class AttentionBackendName(str, Enum):
FLASH_VARLEN = "flash_varlen"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
# _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
# PyTorch native
FLEX = "flex"
@@ -207,6 +212,22 @@ class _AttentionBackendRegistry:
return list(cls._backends.keys())
def _ensure_fa3_hub_loaded():
global __fa3_hub_loaded
if __fa3_hub_loaded:
return
from ..utils.kernels_utils import _get_fa3_from_hub
fa3_hub_module = _get_fa3_from_hub() # doesn't retrigger download if already available.
if fa3_hub_module is None:
raise RuntimeError(
"Failed to load FlashAttention-3 kernels from the Hub. Please ensure the wheel is available for your platform."
)
global flash_attn_3_hub_func
flash_attn_3_hub_func = fa3_hub_module.flash_attn_func
__fa3_hub_loaded = True
@contextlib.contextmanager
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
"""
@@ -351,6 +372,13 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)
# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend in [
AttentionBackendName.SAGE,
AttentionBackendName.SAGE_VARLEN,
@@ -514,6 +542,22 @@ def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torc
return torch.empty_like(query), query.new_empty(lse_shape)
# @_custom_op("vllm_flash_attn3::flash_attn", mutates_args=(), device_types="cuda")
# def _wrapped_flash_attn_3_hub(
# query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
# ) -> Tuple[torch.Tensor, torch.Tensor]:
# out, lse = flash_attn_3_hub_func(query, key, value)
# lse = lse.permute(0, 2, 1)
# return out, lse
# @_register_fake("vllm_flash_attn3::flash_attn")
# def _(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
# batch_size, seq_len, num_heads, head_dim = query.shape
# lse_shape = (batch_size, seq_len, num_heads)
# return torch.empty_like(query), query.new_empty(lse_shape)
# ===== Attention backends =====
@@ -657,6 +701,41 @@ def _flash_attention_3(
return (out, lse) if return_attn_probs else out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
)
def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
softcap: float = 0.0,
deterministic: bool = False,
return_attn_probs: bool = False,
) -> torch.Tensor:
out, lse, *_ = flash_attn_3_hub_func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
return (out, lse) if return_attn_probs else out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -595,7 +595,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
attention as backend.
"""
from .attention import AttentionModuleMixin
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
from .attention_dispatch import (
AttentionBackendName,
_check_attention_backend_requirements,
_ensure_fa3_hub_loaded,
)
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
from .attention_processor import Attention, MochiAttention
@@ -608,6 +612,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
backend = AttentionBackendName(backend)
_check_attention_backend_requirements(backend)
# TODO: clean this once it gets exhausted.
if "_flash_3_hub" in backend:
# We ensure it's preloaded to reduce overhead and also to avoid compilation errors.
_ensure_fa3_hub_loaded()
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
for module in self.modules():

View File

@@ -0,0 +1,22 @@
from ..utils import get_logger
from .import_utils import is_kernels_available
logger = get_logger(__name__)
_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
def _get_fa3_from_hub():
if not is_kernels_available():
return None
else:
from kernels import get_kernel
try:
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops")
return flash_attn_3_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
raise