Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-06 12:34:13 +08:00)
[core] start varlen variants for attn backend kernels. (#12765)
* start varlen variants for attn backend kernels.
* maybe unflatten heads.
* updates
* remove unused function.
* doc
* up
@@ -56,8 +56,8 @@ export_to_video(video, "output.mp4", fps=15)
 
 - HunyuanVideo1.5 uses attention masks with variable-length sequences. For best performance, we recommend using an attention backend that handles padding efficiently.
 
-  - **H100/H800:** `_flash_3_hub` or `_flash_varlen_3`
-  - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen`
+  - **H100/H800:** `_flash_3_hub` or `_flash_3_varlen_hub`
+  - **A100/A800/RTX 4090:** `flash_hub` or `flash_varlen_hub`
   - **Other GPUs:** `sage_hub`
 
 Refer to the [Attention backends](../../optimization/attention_backends) guide for more details about using a different backend.
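To make the recommendation above concrete: a minimal sketch of selecting one of these backends, assuming a HunyuanVideo1.5 pipeline has already been loaded as `pipe` and that `set_attention_backend` is exposed on diffusers models (as in recent releases).

```python
# Minimal sketch: `pipe` is an already-loaded HunyuanVideo1.5 pipeline (loading not shown),
# and `set_attention_backend` is assumed to be available on the transformer.
pipe.transformer.set_attention_backend("flash_varlen_hub")       # A100/A800/RTX 4090
# pipe.transformer.set_attention_backend("_flash_3_varlen_hub")  # H100/H800
# pipe.transformer.set_attention_backend("sage_hub")             # other GPUs

video = pipe(prompt="A cat walks on the grass").frames[0]
```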
@@ -141,10 +141,12 @@ Refer to the table below for a complete list of available attention backends and
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
 | `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from kernels |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
+| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
 | `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
 | `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 from kernels |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
 | `sage_hub` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) from kernels |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
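For switching backends only around a single call, a sketch using an `attention_backend` context manager, reusing `pipe` from the sketch above; the name and import path `diffusers.models.attention_dispatch.attention_backend` are assumptions and may differ between releases.

```python
# Hedged sketch: temporarily dispatch attention through the Hub varlen FlashAttention-2 kernel.
# The exact import location of `attention_backend` is an assumption here.
from diffusers.models.attention_dispatch import attention_backend

with attention_backend("flash_varlen_hub"):
    video = pipe(prompt="A cat walks on the grass").frames[0]
```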
@@ -168,10 +168,11 @@ class AttentionBackendName(str, Enum):
     FLASH = "flash"
     FLASH_HUB = "flash_hub"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_VARLEN_HUB = "flash_varlen_hub"
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
-    # _FLASH_VARLEN_3_HUB = "_flash_varlen_3_hub" # not supported yet.
+    _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"
 
     # `aiter`
     AITER = "aiter"
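Because `AttentionBackendName` mixes in `str`, backend names supplied as plain strings can be validated by value and still compare equal to their string form. A trimmed-down copy of the enum (not the full class above) illustrates this:

```python
from enum import Enum

# Trimmed-down copy of the enum above, only to illustrate the str mixin behaviour.
class AttentionBackendName(str, Enum):
    FLASH_VARLEN_HUB = "flash_varlen_hub"
    _FLASH_3_VARLEN_HUB = "_flash_3_varlen_hub"

backend = AttentionBackendName("flash_varlen_hub")        # lookup by value
print(backend is AttentionBackendName.FLASH_VARLEN_HUB)   # True
print(backend == "flash_varlen_hub")                      # True, thanks to the str mixin
```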
@@ -263,9 +264,17 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
     ),
+    AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn3",
+        function_attr="flash_attn_varlen_func",
+        # revision="fake-ops-return-probs",
+    ),
     AttentionBackendName.FLASH_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
     ),
+    AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
+        repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
+    ),
     AttentionBackendName.SAGE_HUB: _HubKernelConfig(
         repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
     ),
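Each `_HubKernelConfig` entry pairs a Hub repository (`repo_id`) with the attribute to pull from it (`function_attr`). A rough sketch of resolving such an entry into a callable at runtime, assuming the `get_kernel` API of the `kernels` package:

```python
# Hedged sketch of resolving a registry entry like the ones above into a callable.
# `get_kernel` and its behaviour are assumptions about the `kernels` package API.
from kernels import get_kernel

repo_id = "kernels-community/flash-attn2"
function_attr = "flash_attn_varlen_func"

kernel = get_kernel(repo_id)                      # downloads/loads the compiled kernel from the Hub
flash_attn_varlen_func = getattr(kernel, function_attr)
```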
@@ -425,8 +434,13 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )
 
-    # TODO: add support Hub variant of varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB, AttentionBackendName.FLASH_HUB, AttentionBackendName.SAGE_HUB]:
+    elif backend in [
+        AttentionBackendName.FLASH_HUB,
+        AttentionBackendName.FLASH_VARLEN_HUB,
+        AttentionBackendName._FLASH_3_HUB,
+        AttentionBackendName._FLASH_3_VARLEN_HUB,
+        AttentionBackendName.SAGE_HUB,
+    ]:
         if not is_kernels_available():
             raise RuntimeError(
                 f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
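The new branch only guards on the `kernels` package being importable. A stand-in sketch of what `is_kernels_available()` amounts to (the real helper lives in diffusers' import utilities):

```python
import importlib.util

def is_kernels_available() -> bool:
    # Stand-in sketch: the idea is simply "can the `kernels` package be imported?"
    return importlib.util.find_spec("kernels") is not None

print(is_kernels_available())
```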
@@ -1387,6 +1401,63 @@ def _flash_attention_hub(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_varlen_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    dropout_p: float = 0.0,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_VARLEN_HUB].kernel_fn
+    out = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
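The new backend relies on `_prepare_for_flash_attn_or_sage_varlen` for the cumulative sequence lengths that varlen kernels consume. The self-contained sketch below only illustrates the `cu_seqlens` format (offsets into the packed token dimension), assuming full-length queries and a boolean key padding mask; it is not a reimplementation of the helper.

```python
import torch
import torch.nn.functional as F

# Illustration of the `cu_seqlens` format varlen kernels expect: for per-sample K/V lengths
# [3, 2], the cumulative tensor is [0, 3, 5], i.e. offsets into the packed token dimension.
def make_cu_seqlens(attn_mask: torch.Tensor, seq_len_q: int):
    batch_size, _ = attn_mask.shape
    seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32)
    seqlens_k = attn_mask.sum(dim=-1).to(torch.int32)  # valid K/V tokens per sample
    cu_seqlens_q = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
    cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
    return seqlens_k, (cu_seqlens_q, cu_seqlens_k), (int(seqlens_q.max()), int(seqlens_k.max()))

mask = torch.tensor([[True, True, True, False], [True, True, False, False]])
print(make_cu_seqlens(mask, seq_len_q=4))
# cu_seqlens_k -> tensor([0, 3, 5], dtype=torch.int32)
```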
@@ -1509,6 +1580,60 @@ def _flash_attention_3_hub(
     return (out[0], out[1]) if return_attn_probs else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName._FLASH_3_VARLEN_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _flash_attention_3_varlen_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attn_mask: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    is_causal: bool = False,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    batch_size, seq_len_q, _, _ = query.shape
+    _, seq_len_kv, _, _ = key.shape
+
+    if attn_mask is not None:
+        attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
+
+    (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
+        _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
+    )
+
+    key_valid, value_valid = [], []
+    for b in range(batch_size):
+        valid_len = seqlens_k[b]
+        key_valid.append(key[b, :valid_len])
+        value_valid.append(value[b, :valid_len])
+
+    query_packed = query.flatten(0, 1)
+    key_packed = torch.cat(key_valid, dim=0)
+    value_packed = torch.cat(value_valid, dim=0)
+
+    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
+    out, lse, *_ = func(
+        q=query_packed,
+        k=key_packed,
+        v=value_packed,
+        cu_seqlens_q=cu_seqlens_q,
+        cu_seqlens_k=cu_seqlens_k,
+        max_seqlen_q=max_seqlen_q,
+        max_seqlen_k=max_seqlen_k,
+        softmax_scale=scale,
+        causal=is_causal,
+    )
+    out = out.unflatten(0, (batch_size, -1))
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName._FLASH_VARLEN_3,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
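Both new functions use the same packing scheme: queries are flattened across the batch, keys/values are truncated to their valid lengths and concatenated, and the kernel output is unflattened back into a batched layout. A standalone sketch with made-up shapes and random data:

```python
import torch

# Standalone sketch of the packing used above, with made-up shapes and random data.
batch_size, seq_len, heads, dim = 2, 4, 8, 64
query = torch.randn(batch_size, seq_len, heads, dim)
key = torch.randn(batch_size, seq_len, heads, dim)
value = torch.randn(batch_size, seq_len, heads, dim)
seqlens_k = [3, 2]                                  # valid K/V tokens per sample

query_packed = query.flatten(0, 1)                  # (batch * seq_len, heads, dim)
key_packed = torch.cat([key[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)
value_packed = torch.cat([value[b, : seqlens_k[b]] for b in range(batch_size)], dim=0)

# A varlen kernel would run on the packed tensors here; use the query layout as a placeholder.
out = torch.randn_like(query_packed)
out = out.unflatten(0, (batch_size, -1))            # back to (batch, seq_len, heads, dim)
print(query_packed.shape, key_packed.shape, value_packed.shape, out.shape)
```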