|
|
|
|
@@ -38,7 +38,6 @@ from ..utils import (
|
|
|
|
|
is_flash_attn_available,
|
|
|
|
|
is_flash_attn_version,
|
|
|
|
|
is_kernels_available,
|
|
|
|
|
is_kernels_version,
|
|
|
|
|
is_sageattention_available,
|
|
|
|
|
is_sageattention_version,
|
|
|
|
|
is_torch_npu_available,
|
|
|
|
|
@@ -266,41 +265,28 @@ class _HubKernelConfig:
|
|
|
|
|
repo_id: str
|
|
|
|
|
function_attr: str
|
|
|
|
|
revision: str | None = None
|
|
|
|
|
version: int | None = None
|
|
|
|
|
kernel_fn: Callable | None = None
|
|
|
|
|
wrapped_forward_attr: str | None = None
|
|
|
|
|
wrapped_backward_attr: str | None = None
|
|
|
|
|
wrapped_forward_fn: Callable | None = None
|
|
|
|
|
wrapped_backward_fn: Callable | None = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Registry for hub-based attention kernels
|
|
|
|
|
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
|
|
|
|
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
|
|
|
|
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
|
|
|
|
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
|
|
|
|
|
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
|
|
|
|
),
|
|
|
|
|
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
|
|
|
|
|
repo_id="kernels-community/flash-attn3",
|
|
|
|
|
function_attr="flash_attn_varlen_func",
|
|
|
|
|
version=1,
|
|
|
|
|
# revision="fake-ops-return-probs",
|
|
|
|
|
),
|
|
|
|
|
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
|
|
|
|
repo_id="kernels-community/flash-attn2",
|
|
|
|
|
function_attr="flash_attn_func",
|
|
|
|
|
version=1,
|
|
|
|
|
revision=None,
|
|
|
|
|
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
|
|
|
|
|
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
|
|
|
|
|
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
|
|
|
|
),
|
|
|
|
|
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
|
|
|
|
repo_id="kernels-community/flash-attn2",
|
|
|
|
|
function_attr="flash_attn_varlen_func",
|
|
|
|
|
version=1,
|
|
|
|
|
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
|
|
|
|
),
|
|
|
|
|
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
|
|
|
|
|
repo_id="kernels-community/sage-attention",
|
|
|
|
|
function_attr="sageattn",
|
|
|
|
|
version=1,
|
|
|
|
|
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
|
|
|
|
|
),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
@@ -470,10 +456,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
|
|
|
|
)
|
|
|
|
|
if not is_kernels_version(">=", "0.12"):
|
|
|
|
|
raise RuntimeError(
|
|
|
|
|
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
elif backend == AttentionBackendName.AITER:
|
|
|
|
|
if not _CAN_USE_AITER_ATTN:
|
|
|
|
|
@@ -623,39 +605,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ===== Helpers for downloading kernels =====
|
|
|
|
|
def _resolve_kernel_attr(module, attr_path: str):
    """Resolve a dotted attribute path starting from a kernel module.

    Args:
        module: Object the lookup starts from (normally a loaded kernel module).
        attr_path: Dot-separated attribute names, e.g. ``"pkg.func"``. Each component is looked up
            on the object produced by the previous component.

    Returns:
        The object found at the end of ``attr_path``.

    Raises:
        AttributeError: If any component of ``attr_path`` cannot be found. The message names the
            root module and the full requested path.
    """
    target = module
    for attr in attr_path.split("."):
        try:
            target = getattr(target, attr)
        except AttributeError as err:
            # Fall back to repr() so that building the error message cannot itself fail on a root
            # object that has no `__name__`.
            module_name = getattr(module, "__name__", repr(module))
            raise AttributeError(
                f"Kernel module '{module_name}' does not define attribute path '{attr_path}'."
            ) from err
    return target
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
|
|
|
|
if backend not in _HUB_KERNELS_REGISTRY:
|
|
|
|
|
return
|
|
|
|
|
config = _HUB_KERNELS_REGISTRY[backend]
|
|
|
|
|
|
|
|
|
|
needs_kernel = config.kernel_fn is None
|
|
|
|
|
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
|
|
|
|
|
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
|
|
|
|
|
|
|
|
|
|
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
|
|
|
|
|
if config.kernel_fn is not None:
|
|
|
|
|
return
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from kernels import get_kernel
|
|
|
|
|
|
|
|
|
|
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
|
|
|
|
if needs_kernel:
|
|
|
|
|
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
|
|
|
|
|
kernel_func = getattr(kernel_module, config.function_attr)
|
|
|
|
|
|
|
|
|
|
if needs_wrapped_forward:
|
|
|
|
|
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
|
|
|
|
|
|
|
|
|
|
if needs_wrapped_backward:
|
|
|
|
|
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
|
|
|
|
|
# Cache the downloaded kernel function in the config object
|
|
|
|
|
config.kernel_fn = kernel_func
|
|
|
|
|
|
|
|
|
|
except Exception as e:
|
|
|
|
|
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
|
|
|
|
@@ -1106,237 +1071,6 @@ def _flash_attention_backward_op(
|
|
|
|
|
return grad_query, grad_key, grad_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _flash_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    """Run the wrapped forward primitive cached on the `FLASH_HUB` registry entry.

    Args:
        ctx: Autograd context. Only written to when `_save_ctx` is True.
        query: Query tensor, passed positionally to the kernel.
        key: Key tensor, passed positionally to the kernel.
        value: Value tensor, passed positionally to the kernel.
        attn_mask: Must be None; a mask is rejected.
        dropout_p: Dropout probability handed to the kernel. A non-positive value is replaced by
            `1e-30` when any input requires grad or the context-parallel world size is above 1.
        is_causal: Passed through to the kernel.
        scale: Softmax scale. Defaults to `query.shape[-1] ** (-0.5)` when None.
        enable_gqa: Must be False.
        return_lse: Passed to the kernel as its last positional argument, and selects whether
            `(out, lse)` or just `out` is returned.
        _save_ctx: When True, tensors and kernel options are stored on `ctx`.
        _parallel_config: Only `context_parallel_config._world_size` is read from it.

    Returns:
        `out`, or `(out, lse)` when `return_lse` is True. `lse` has its last two dimensions swapped
        relative to what the kernel returned.

    Raises:
        ValueError: If `attn_mask` is given or `enable_gqa` is True.
        RuntimeError: If the registry entry has no wrapped forward or backward function.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")

    hub_config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
    forward_kernel = hub_config.wrapped_forward_fn
    if forward_kernel is None or hub_config.wrapped_backward_fn is None:
        raise RuntimeError(
            "Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
            "for context parallel execution."
        )

    softmax_scale = query.shape[-1] ** (-0.5) if scale is None else scale

    # Kernel options that this op always pins to fixed values.
    window_size = (-1, -1)
    softcap, alibi_slopes, deterministic = 0.0, None, False

    needs_grad = any(tensor.requires_grad for tensor in (query, key, value))
    # The config is only inspected when no input requires grad (short-circuit).
    substitute_zero_dropout = needs_grad or (
        _parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1
    )
    if substitute_zero_dropout and not dropout_p > 0:
        # NOTE(review): presumably a tiny non-zero dropout makes the kernel produce the state that
        # backward needs — confirm against the kernel.
        dropout_p = 1e-30

    with torch.set_grad_enabled(needs_grad):
        out, lse, _, rng_state = forward_kernel(
            query,
            key,
            value,
            dropout_p,
            softmax_scale,
            is_causal,
            *window_size,
            softcap,
            alibi_slopes,
            return_lse,
        )
        # Swap the last two dimensions of lse and make the result contiguous.
        lse = lse.permute(0, 2, 1).contiguous()

    result = (out, lse) if return_lse else out
    if not _save_ctx:
        return result

    # Stash what the backward op reads back from the context.
    ctx.save_for_backward(query, key, value, out, lse, rng_state)
    ctx.dropout_p = dropout_p
    ctx.scale = softmax_scale
    ctx.is_causal = is_causal
    ctx.window_size = window_size
    ctx.softcap = softcap
    ctx.alibi_slopes = alibi_slopes
    ctx.deterministic = deterministic

    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _flash_attention_hub_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    **kwargs,
):
    """Compute query/key/value gradients with the `FLASH_HUB` wrapped backward primitive.

    Args:
        ctx: Autograd context holding six saved tensors (query, key, value, out, lse, rng_state)
            and the attributes `dropout_p`, `scale`, `is_causal`, `window_size`, `softcap`,
            `alibi_slopes` and `deterministic`.
        grad_out: Gradient with respect to the attention output.
        *args: Ignored.
        **kwargs: Ignored.

    Returns:
        Tuple `(grad_query, grad_key, grad_value)`, each sliced along the last dimension to
        `grad_out.shape[-1]`.

    Raises:
        RuntimeError: If the registry entry has no wrapped backward function.
    """
    backward_kernel = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].wrapped_backward_fn
    if backward_kernel is None:
        raise RuntimeError(
            "Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
        )

    query, key, value, out, lse, rng_state = ctx.saved_tensors

    # Pre-allocated buffers passed to the kernel; its return value is discarded and the buffers
    # are what gets returned.
    grad_buffers = [torch.empty_like(tensor) for tensor in (query, key, value)]
    window = ctx.window_size

    backward_kernel(
        grad_out,
        query,
        key,
        value,
        out,
        lse,
        *grad_buffers,
        ctx.dropout_p,
        ctx.scale,
        ctx.is_causal,
        window[0],
        window[1],
        ctx.softcap,
        ctx.alibi_slopes,
        ctx.deterministic,
        rng_state,
    )

    # Keep only the leading `grad_out.shape[-1]` entries of the last dimension.
    last_dim = grad_out.shape[-1]
    grad_query, grad_key, grad_value = (buffer[..., :last_dim] for buffer in grad_buffers)
    return grad_query, grad_key, grad_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _flash_attention_3_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
    *,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    """Run the kernel function cached on the `_FLASH_3_HUB` registry entry.

    Args:
        ctx: Autograd context. Only written to when `_save_ctx` is True.
        query: Query tensor, passed to the kernel as `q`.
        key: Key tensor, passed to the kernel as `k`.
        value: Value tensor, passed to the kernel as `v`.
        attn_mask: Must be None; a mask is rejected.
        dropout_p: Must be `0.0`; dropout is rejected.
        is_causal: Passed to the kernel as `causal`.
        scale: Passed to the kernel as `softmax_scale`.
        enable_gqa: Must be False.
        return_lse: Passed to the kernel as `return_attn_probs`, and selects whether `(out, lse)`
            or just `out` is returned.
        _save_ctx: When True, the inputs, `scale`, `is_causal` and the kernel function are stored
            on `ctx`.
        _parallel_config: Accepted but not read by this op.
        window_size: Passed through to the kernel.
        softcap: Passed through to the kernel.
        num_splits: Passed through to the kernel.
        pack_gqa: Passed through to the kernel.
        deterministic: Passed through to the kernel.
        sm_margin: Passed through to the kernel.

    Returns:
        `out`, or `(out, lse)` when `return_lse` is True. `lse` has its last two dimensions swapped
        relative to what the kernel returned.

    Raises:
        ValueError: If `attn_mask` is given, `dropout_p` is non-zero, or `enable_gqa` is True.
        RuntimeError: If the registry entry has no kernel function loaded.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
    if dropout_p != 0.0:
        raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")

    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
    if func is None:
        # Fail with an explicit error instead of "'NoneType' object is not callable".
        raise RuntimeError(
            f"Hub kernel for backend '{AttentionBackendName._FLASH_3_HUB.value}' is not loaded "
            "(`kernel_fn` is None)."
        )

    out = func(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
        qv=None,
        q_descale=None,
        k_descale=None,
        v_descale=None,
        window_size=window_size,
        softcap=softcap,
        num_splits=num_splits,
        pack_gqa=pack_gqa,
        deterministic=deterministic,
        sm_margin=sm_margin,
        return_attn_probs=return_lse,
    )

    lse = None
    if return_lse:
        # With `return_attn_probs=True` the kernel result is unpacked as an (out, lse) pair.
        out, lse = out
        # Swap the last two dimensions of lse and make the result contiguous.
        lse = lse.permute(0, 2, 1).contiguous()

    if _save_ctx:
        ctx.save_for_backward(query, key, value)
        ctx.scale = scale
        ctx.is_causal = is_causal
        # Stored so the backward op can call the same kernel function again.
        ctx._hub_kernel = func

    return (out, lse) if return_lse else out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _flash_attention_3_hub_backward_op(
    ctx: torch.autograd.function.FunctionCtx,
    grad_out: torch.Tensor,
    *args,
    window_size: tuple[int, int] = (-1, -1),
    softcap: float = 0.0,
    num_splits: int = 1,
    pack_gqa: bool | None = None,
    deterministic: bool = False,
    sm_margin: int = 0,
):
    """Compute query/key/value gradients by re-running the FA3 hub kernel under autograd.

    Args:
        ctx: Autograd context holding three saved tensors (query, key, value) and the attributes
            `scale`, `is_causal` and `_hub_kernel`.
        grad_out: Gradient with respect to the attention output.
        *args: Ignored.
        window_size: Passed through to the kernel.
        softcap: Passed through to the kernel.
        num_splits: Passed through to the kernel.
        pack_gqa: Passed through to the kernel.
        deterministic: Passed through to the kernel.
        sm_margin: Passed through to the kernel.

    Returns:
        Tuple `(grad_query, grad_key, grad_value)` produced by `torch.autograd.grad`.
    """
    query, key, value = ctx.saved_tensors
    hub_kernel = ctx._hub_kernel

    # NOTE: The FA3 hub kernel, unlike the FA2 one, has no separate wrapped forward/backward
    # primitives (its `_HubKernelConfig` sets no `wrapped_forward_attr`/`wrapped_backward_attr`).
    # So the forward pass is executed again with grad enabled and differentiated with
    # `torch.autograd.grad()`. This costs a second forward pass during backward. Once the FA3 hub
    # ships a dedicated fused backward kernel (like `_wrapped_flash_attn_backward` in the FA2 hub),
    # this op can be refactored to mirror `_flash_attention_hub_backward_op`.
    with torch.enable_grad():
        # Detached copies that require grad, used as the inputs to differentiate against.
        leaf_inputs = tuple(tensor.detach().requires_grad_(True) for tensor in (query, key, value))
        leaf_query, leaf_key, leaf_value = leaf_inputs

        recomputed = hub_kernel(
            q=leaf_query,
            k=leaf_key,
            v=leaf_value,
            softmax_scale=ctx.scale,
            causal=ctx.is_causal,
            qv=None,
            q_descale=None,
            k_descale=None,
            v_descale=None,
            window_size=window_size,
            softcap=softcap,
            num_splits=num_splits,
            pack_gqa=pack_gqa,
            deterministic=deterministic,
            sm_margin=sm_margin,
            return_attn_probs=False,
        )
        # If the kernel still returns a tuple, only its first element is differentiated.
        attn_out = recomputed[0] if isinstance(recomputed, tuple) else recomputed

        grad_query, grad_key, grad_value = torch.autograd.grad(
            attn_out,
            leaf_inputs,
            grad_out,
            retain_graph=False,
            allow_unused=False,
        )

    return grad_query, grad_key, grad_value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sage_attention_forward_op(
|
|
|
|
|
ctx: torch.autograd.function.FunctionCtx,
|
|
|
|
|
query: torch.Tensor,
|
|
|
|
|
@@ -1375,46 +1109,6 @@ def _sage_attention_forward_op(
|
|
|
|
|
return (out, lse) if return_lse else out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sage_attention_hub_forward_op(
    ctx: torch.autograd.function.FunctionCtx,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: float | None = None,
    enable_gqa: bool = False,
    return_lse: bool = False,
    _save_ctx: bool = True,
    _parallel_config: "ParallelConfig" | None = None,
):
    """Run the kernel function cached on the `SAGE_HUB` registry entry.

    Args:
        ctx: Accepted but not used by this op.
        query: Query tensor, passed to the kernel as `q`.
        key: Key tensor, passed to the kernel as `k`.
        value: Value tensor, passed to the kernel as `v`.
        attn_mask: Must be None; a mask is rejected.
        dropout_p: Must not be greater than `0.0`; dropout is rejected.
        is_causal: Passed through to the kernel.
        scale: Passed to the kernel as `sm_scale`.
        enable_gqa: Must be False.
        return_lse: Passed through to the kernel, and selects whether `(out, lse)` or just `out`
            is returned.
        _save_ctx: Accepted but not used by this op.
        _parallel_config: Accepted but not used by this op.

    Returns:
        `out`, or `(out, lse)` when `return_lse` is True. `lse` has its last two dimensions swapped
        relative to what the kernel returned.

    Raises:
        ValueError: If `attn_mask` is given, `dropout_p` is positive, or `enable_gqa` is True.
        RuntimeError: If the registry entry has no kernel function loaded.
    """
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not yet supported for Sage attention.")
    if dropout_p > 0.0:
        raise ValueError("`dropout_p` is not yet supported for Sage attention.")
    if enable_gqa:
        raise ValueError("`enable_gqa` is not yet supported for Sage attention.")

    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
    if func is None:
        # Fail with an explicit error instead of "'NoneType' object is not callable".
        raise RuntimeError(
            f"Hub kernel for backend '{AttentionBackendName.SAGE_HUB.value}' is not loaded "
            "(`kernel_fn` is None)."
        )

    out = func(
        q=query,
        k=key,
        v=value,
        tensor_layout="NHD",
        is_causal=is_causal,
        sm_scale=scale,
        return_lse=return_lse,
    )

    if not return_lse:
        return out

    # With `return_lse=True` the first two items of the kernel result are used as (out, lse);
    # any further items are dropped.
    out, lse, *_ = out
    # Swap the last two dimensions of lse and make the result contiguous.
    lse = lse.permute(0, 2, 1).contiguous()
    return out, lse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sage_attention_backward_op(
|
|
|
|
|
ctx: torch.autograd.function.FunctionCtx,
|
|
|
|
|
grad_out: torch.Tensor,
|
|
|
|
|
@@ -2271,7 +1965,7 @@ def _flash_attention(
|
|
|
|
|
@_AttentionBackendRegistry.register(
|
|
|
|
|
AttentionBackendName.FLASH_HUB,
|
|
|
|
|
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
|
|
|
|
supports_context_parallel=True,
|
|
|
|
|
supports_context_parallel=False,
|
|
|
|
|
)
|
|
|
|
|
def _flash_attention_hub(
|
|
|
|
|
query: torch.Tensor,
|
|
|
|
|
@@ -2289,35 +1983,17 @@ def _flash_attention_hub(
|
|
|
|
|
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
|
|
|
|
|
|
|
|
|
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
|
|
|
|
if _parallel_config is None:
|
|
|
|
|
out = func(
|
|
|
|
|
q=query,
|
|
|
|
|
k=key,
|
|
|
|
|
v=value,
|
|
|
|
|
dropout_p=dropout_p,
|
|
|
|
|
softmax_scale=scale,
|
|
|
|
|
causal=is_causal,
|
|
|
|
|
return_attn_probs=return_lse,
|
|
|
|
|
)
|
|
|
|
|
if return_lse:
|
|
|
|
|
out, lse, *_ = out
|
|
|
|
|
else:
|
|
|
|
|
out = _templated_context_parallel_attention(
|
|
|
|
|
query,
|
|
|
|
|
key,
|
|
|
|
|
value,
|
|
|
|
|
None,
|
|
|
|
|
dropout_p,
|
|
|
|
|
is_causal,
|
|
|
|
|
scale,
|
|
|
|
|
False,
|
|
|
|
|
return_lse,
|
|
|
|
|
forward_op=_flash_attention_hub_forward_op,
|
|
|
|
|
backward_op=_flash_attention_hub_backward_op,
|
|
|
|
|
_parallel_config=_parallel_config,
|
|
|
|
|
)
|
|
|
|
|
if return_lse:
|
|
|
|
|
out, lse = out
|
|
|
|
|
out = func(
|
|
|
|
|
q=query,
|
|
|
|
|
k=key,
|
|
|
|
|
v=value,
|
|
|
|
|
dropout_p=dropout_p,
|
|
|
|
|
softmax_scale=scale,
|
|
|
|
|
causal=is_causal,
|
|
|
|
|
return_attn_probs=return_lse,
|
|
|
|
|
)
|
|
|
|
|
if return_lse:
|
|
|
|
|
out, lse, *_ = out
|
|
|
|
|
|
|
|
|
|
return (out, lse) if return_lse else out
|
|
|
|
|
|
|
|
|
|
@@ -2464,7 +2140,7 @@ def _flash_attention_3(
|
|
|
|
|
@_AttentionBackendRegistry.register(
|
|
|
|
|
AttentionBackendName._FLASH_3_HUB,
|
|
|
|
|
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
|
|
|
|
supports_context_parallel=True,
|
|
|
|
|
supports_context_parallel=False,
|
|
|
|
|
)
|
|
|
|
|
def _flash_attention_3_hub(
|
|
|
|
|
query: torch.Tensor,
|
|
|
|
|
@@ -2479,68 +2155,33 @@ def _flash_attention_3_hub(
|
|
|
|
|
return_attn_probs: bool = False,
|
|
|
|
|
_parallel_config: "ParallelConfig" | None = None,
|
|
|
|
|
) -> torch.Tensor:
|
|
|
|
|
if _parallel_config:
|
|
|
|
|
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
|
|
|
|
if attn_mask is not None:
|
|
|
|
|
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
|
|
|
|
|
|
|
|
|
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
|
|
|
|
if _parallel_config is None:
|
|
|
|
|
out = func(
|
|
|
|
|
q=query,
|
|
|
|
|
k=key,
|
|
|
|
|
v=value,
|
|
|
|
|
softmax_scale=scale,
|
|
|
|
|
causal=is_causal,
|
|
|
|
|
qv=None,
|
|
|
|
|
q_descale=None,
|
|
|
|
|
k_descale=None,
|
|
|
|
|
v_descale=None,
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
softcap=softcap,
|
|
|
|
|
num_splits=1,
|
|
|
|
|
pack_gqa=None,
|
|
|
|
|
deterministic=deterministic,
|
|
|
|
|
sm_margin=0,
|
|
|
|
|
return_attn_probs=return_attn_probs,
|
|
|
|
|
)
|
|
|
|
|
return (out[0], out[1]) if return_attn_probs else out
|
|
|
|
|
|
|
|
|
|
forward_op = functools.partial(
|
|
|
|
|
_flash_attention_3_hub_forward_op,
|
|
|
|
|
out = func(
|
|
|
|
|
q=query,
|
|
|
|
|
k=key,
|
|
|
|
|
v=value,
|
|
|
|
|
softmax_scale=scale,
|
|
|
|
|
causal=is_causal,
|
|
|
|
|
qv=None,
|
|
|
|
|
q_descale=None,
|
|
|
|
|
k_descale=None,
|
|
|
|
|
v_descale=None,
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
softcap=softcap,
|
|
|
|
|
num_splits=1,
|
|
|
|
|
pack_gqa=None,
|
|
|
|
|
deterministic=deterministic,
|
|
|
|
|
sm_margin=0,
|
|
|
|
|
return_attn_probs=return_attn_probs,
|
|
|
|
|
)
|
|
|
|
|
backward_op = functools.partial(
|
|
|
|
|
_flash_attention_3_hub_backward_op,
|
|
|
|
|
window_size=window_size,
|
|
|
|
|
softcap=softcap,
|
|
|
|
|
num_splits=1,
|
|
|
|
|
pack_gqa=None,
|
|
|
|
|
deterministic=deterministic,
|
|
|
|
|
sm_margin=0,
|
|
|
|
|
)
|
|
|
|
|
out = _templated_context_parallel_attention(
|
|
|
|
|
query,
|
|
|
|
|
key,
|
|
|
|
|
value,
|
|
|
|
|
None,
|
|
|
|
|
0.0,
|
|
|
|
|
is_causal,
|
|
|
|
|
scale,
|
|
|
|
|
False,
|
|
|
|
|
return_attn_probs,
|
|
|
|
|
forward_op=forward_op,
|
|
|
|
|
backward_op=backward_op,
|
|
|
|
|
_parallel_config=_parallel_config,
|
|
|
|
|
)
|
|
|
|
|
if return_attn_probs:
|
|
|
|
|
out, lse = out
|
|
|
|
|
return out, lse
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
# When `return_attn_probs` is True, the above returns a tuple of
|
|
|
|
|
# actual outputs and lse.
|
|
|
|
|
return (out[0], out[1]) if return_attn_probs else out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@_AttentionBackendRegistry.register(
|
|
|
|
|
@@ -3172,7 +2813,7 @@ def _sage_attention(
|
|
|
|
|
@_AttentionBackendRegistry.register(
|
|
|
|
|
AttentionBackendName.SAGE_HUB,
|
|
|
|
|
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
|
|
|
|
supports_context_parallel=True,
|
|
|
|
|
supports_context_parallel=False,
|
|
|
|
|
)
|
|
|
|
|
def _sage_attention_hub(
|
|
|
|
|
query: torch.Tensor,
|
|
|
|
|
@@ -3200,23 +2841,6 @@ def _sage_attention_hub(
|
|
|
|
|
)
|
|
|
|
|
if return_lse:
|
|
|
|
|
out, lse, *_ = out
|
|
|
|
|
else:
|
|
|
|
|
out = _templated_context_parallel_attention(
|
|
|
|
|
query,
|
|
|
|
|
key,
|
|
|
|
|
value,
|
|
|
|
|
None,
|
|
|
|
|
0.0,
|
|
|
|
|
is_causal,
|
|
|
|
|
scale,
|
|
|
|
|
False,
|
|
|
|
|
return_lse,
|
|
|
|
|
forward_op=_sage_attention_hub_forward_op,
|
|
|
|
|
backward_op=_sage_attention_backward_op,
|
|
|
|
|
_parallel_config=_parallel_config,
|
|
|
|
|
)
|
|
|
|
|
if return_lse:
|
|
|
|
|
out, lse = out
|
|
|
|
|
|
|
|
|
|
return (out, lse) if return_lse else out
|
|
|
|
|
|
|
|
|
|
|