Compare commits

...

7 Commits

| Author | SHA1 | Message | Date |
|--------|------|---------|------|
| Sayak Paul | d72bbcc758 | Merge branch 'main' into fa-hub | 2025-10-22 09:03:36 +05:30 |
| Sayak Paul | 029975e54f | Merge branch 'main' into fa-hub | 2025-10-08 09:14:28 +05:30 |
| Sayak Paul | b0fc7af941 | Merge branch 'main' into fa-hub | 2025-10-06 10:28:00 +05:30 |
| Sayak Paul | 474b99597c | Merge branch 'main' into fa-hub | 2025-10-03 11:25:37 +05:30 |
| sayakpaul | 1b96ed7df3 | up | 2025-09-26 11:10:00 +05:30 |
| sayakpaul | d252c02d1e | support fa (2) through kernels. | 2025-09-25 13:24:09 +05:30 |
| sayakpaul | c386f220ea | up | 2025-09-25 13:05:39 +05:30 |
3 changed files with 65 additions and 20 deletions

View File

@@ -138,10 +138,11 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from `kernels` |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
-| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from `kernels` |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
 | `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
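Usage sketch (not part of the diff): how the new `flash_hub` backend would be selected from user code. The `DIFFUSERS_ENABLE_HUB_KERNELS` env var and the `flash_hub` name come from this PR; the `set_attention_backend` helper, the Flux checkpoint, and the prompt are assumptions for illustration.

```python
# Sketch only: enabling the FA2 Hub backend end to end.
# Assumes a CUDA GPU, `pip install kernels`, and an example Flux checkpoint.
import os

# Must be set before diffusers is imported; the Hub kernel is bound at module import time.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # example checkpoint, not prescribed by this PR
    torch_dtype=torch.bfloat16,
).to("cuda")

# "flash_hub" is the backend name added in this diff (FlashAttention-2 via `kernels`);
# set_attention_backend is the existing diffusers helper used for the other backends in the table.
pipe.transformer.set_attention_backend("flash_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
image.save("cat.png")
```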

View File

@@ -83,12 +83,15 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    fa3_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+    fa_interface_hub = _get_fa_from_hub()
+    flash_attn_func_hub = fa_interface_hub.flash_attn_func
 else:
     flash_attn_3_func_hub = None
+    flash_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -173,6 +176,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_HUB = "flash_hub"
+    # FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet.
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
@@ -403,15 +408,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
)
# TODO: add support Hub variant of FA3 varlen later
elif backend in [AttentionBackendName._FLASH_3_HUB]:
# TODO: add support Hub variant of FA and FA3 varlen later
elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
if not DIFFUSERS_ENABLE_HUB_KERNELS:
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
)
if not is_kernels_available():
raise RuntimeError(
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
)
elif backend in [
@@ -1228,6 +1233,35 @@ def _flash_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    lse = None
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
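For reference on the keyword mapping inside `_flash_attention_hub` (diffusers' `scale`, `is_causal`, and `return_lse` versus flash-attn's `softmax_scale`, `causal`, and `return_attn_probs`), here is a standalone sketch that pulls the same `kernels-community/flash-attn` kernel and calls its `flash_attn_func` directly. The shapes, dtype, and device are illustrative assumptions, not part of the diff.

```python
# Standalone sketch; assumes a CUDA GPU and `pip install kernels`.
import torch
from kernels import get_kernel

fa = get_kernel("kernels-community/flash-attn")  # same hub id as _DEFAULT_HUB_IDS["fa"]

# flash-attn layout: (batch, seq_len, num_heads, head_dim), bf16 or fp16
q = torch.randn(2, 1024, 8, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = fa.flash_attn_func(
    q=q,
    k=k,
    v=v,
    dropout_p=0.0,           # diffusers arg: dropout_p
    softmax_scale=None,      # diffusers arg: scale (defaults to 1/sqrt(head_dim))
    causal=False,            # diffusers arg: is_causal
    return_attn_probs=True,  # diffusers arg: return_lse
)
out, lse, *_ = out  # same unpacking as _flash_attention_hub when return_lse=True
print(out.shape, lse.shape)
```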

View File

@@ -2,22 +2,32 @@ from ..utils import get_logger
 from .import_utils import is_kernels_available
 
 
+if is_kernels_available():
+    from kernels import get_kernel
+
 logger = get_logger(__name__)
 
-
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}
 
 
-def _get_fa3_from_hub():
+def _get_from_hub(key: str):
     if not is_kernels_available():
         return None
-    else:
-        from kernels import get_kernel
 
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise
+
+
+def _get_fa3_from_hub():
+    return _get_from_hub("fa3")
+
+
+def _get_fa_from_hub():
+    return _get_from_hub("fa")
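Both public helpers now route through the shared `_get_from_hub`, with `_DEFAULT_HUB_IDS` keeping per-kernel options (such as the temporary FA3 `revision` pin) next to the hub id. A short consumption sketch, assuming the private `diffusers.utils.kernels_utils` module path, `kernels` installed, and hardware the fetched kernels support:

```python
# Sketch only: exercising the helpers directly (attention_dispatch does the same at import time).
# Both return None when the `kernels` package is not installed.
from diffusers.utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub

fa_interface = _get_fa_from_hub()    # fetches kernels-community/flash-attn
fa3_interface = _get_fa3_from_hub()  # fetches kernels-community/flash-attn3 at the pinned revision

# The loaded kernel modules expose flash_attn_func, which attention_dispatch binds
# to flash_attn_func_hub and flash_attn_3_func_hub respectively.
print(fa_interface.flash_attn_func, fa3_interface.flash_attn_func)
```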