Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-07 04:54:47 +08:00.

Compare commits: custom-blo...fa-hub (7 commits)
Commits (SHA1):

- d72bbcc758
- 029975e54f
- b0fc7af941
- 474b99597c
- 1b96ed7df3
- d252c02d1e
- c386f220ea
@@ -138,10 +138,11 @@ Refer to the table below for a complete list of available attention backends and
 | `_native_npu` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | NPU-optimized attention |
 | `_native_xla` | [PyTorch native](https://docs.pytorch.org/docs/stable/generated/torch.nn.attention.SDPBackend.html#torch.nn.attention.SDPBackend) | XLA-optimized attention |
 | `flash` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 |
+| `flash_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-2 from `kernels` |
 | `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
 | `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
 | `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
-| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
+| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from `kernels` |
 | `sage` | [SageAttention](https://github.com/thu-ml/SageAttention) | Quantized attention (INT8 QK) |
 | `sage_varlen` | [SageAttention](https://github.com/thu-ml/SageAttention) | Variable length SageAttention |
 | `_sage_qk_int8_pv_fp8_cuda` | [SageAttention](https://github.com/thu-ml/SageAttention) | INT8 QK + FP8 PV (CUDA) |
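The row being added documents the new `flash_hub` backend, which serves FlashAttention-2 through the `kernels` Hub integration. Below is a minimal usage sketch, not taken from this diff: it assumes a CUDA GPU, `pip install kernels`, that `DIFFUSERS_ENABLE_HUB_KERNELS` is checked at import time (as the module-level code in this diff suggests), and that your diffusers build exposes the `set_attention_backend` helper documented alongside this table; the FLUX checkpoint is only an illustrative choice.

```python
# Hedged sketch: opting into the new `flash_hub` backend.
# Assumptions: CUDA GPU, `pip install kernels`, and a diffusers build that
# exposes `set_attention_backend` on its transformer models.
import os

# The env var is read when diffusers' attention dispatch module is imported,
# so set it before importing diffusers.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint, not from this diff
    torch_dtype=torch.bfloat16,
).to("cuda")

# Route the transformer's attention through FlashAttention-2 served from the Hub.
pipe.transformer.set_attention_backend("flash_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
image.save("cat.png")
```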
@@ -83,12 +83,15 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _get_fa3_from_hub, _get_fa_from_hub

-    flash_attn_interface_hub = _get_fa3_from_hub()
-    flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+    fa3_interface_hub = _get_fa3_from_hub()
+    flash_attn_3_func_hub = fa3_interface_hub.flash_attn_func
+    fa_interface_hub = _get_fa_from_hub()
+    flash_attn_func_hub = fa_interface_hub.flash_attn_func
 else:
     flash_attn_3_func_hub = None
+    flash_attn_func_hub = None

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
@@ -173,6 +176,8 @@ class AttentionBackendName(str, Enum):
     # `flash-attn`
     FLASH = "flash"
     FLASH_VARLEN = "flash_varlen"
+    FLASH_HUB = "flash_hub"
+    # FLASH_VARLEN_HUB = "flash_varlen_hub" # not supported yet.
     _FLASH_3 = "_flash_3"
     _FLASH_VARLEN_3 = "_flash_varlen_3"
     _FLASH_3_HUB = "_flash_3_hub"
@@ -403,15 +408,15 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
                 f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
             )

-    # TODO: add support Hub variant of FA3 varlen later
-    elif backend in [AttentionBackendName._FLASH_3_HUB]:
+    # TODO: add support Hub variant of FA and FA3 varlen later
+    elif backend in [AttentionBackendName.FLASH_HUB, AttentionBackendName._FLASH_3_HUB]:
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
             )
         if not is_kernels_available():
             raise RuntimeError(
-                f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
+                f"Flash Attention Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
             )

     elif backend in [
@@ -1228,6 +1233,35 @@ def _flash_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.FLASH_HUB,
+    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+)
+def _flash_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+) -> torch.Tensor:
+    lse = None
+    out = flash_attn_func_hub(
+        q=query,
+        k=key,
+        v=value,
+        dropout_p=dropout_p,
+        softmax_scale=scale,
+        causal=is_causal,
+        return_attn_probs=return_lse,
+    )
+    if return_lse:
+        out, lse, *_ = out
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.FLASH_VARLEN,
     constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
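To make the new wrapper's argument mapping concrete, here is a hedged, standalone sketch of calling `_flash_attention_hub` directly with dummy tensors (it is a private helper, so this is purely illustrative). It assumes a CUDA device, bf16 inputs to satisfy the `_check_qkv_dtype_bf16_or_fp16` constraint, the `(batch, seq_len, num_heads, head_dim)` layout FlashAttention expects, and that `DIFFUSERS_ENABLE_HUB_KERNELS` was set before importing diffusers so `flash_attn_func_hub` resolves.

```python
# Hedged sketch: exercising the FLASH_HUB wrapper directly.
# Assumptions: CUDA device, bf16 dtype, DIFFUSERS_ENABLE_HUB_KERNELS set before
# importing diffusers, and the (batch, seq_len, num_heads, head_dim) layout.
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers.models.attention_dispatch import _flash_attention_hub

batch, seq_len, num_heads, head_dim = 2, 128, 8, 64
q = torch.randn(batch, seq_len, num_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# dropout_p, scale, and return_lse map onto the Hub kernel's dropout_p,
# softmax_scale, and return_attn_probs arguments, as the wrapper above shows.
out = _flash_attention_hub(q, k, v, is_causal=True)
print(out.shape)  # same (batch, seq_len, num_heads, head_dim) layout as the inputs
```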
@@ -2,22 +2,32 @@ from ..utils import get_logger
 from .import_utils import is_kernels_available


+if is_kernels_available():
+    from kernels import get_kernel
+
 logger = get_logger(__name__)

-
-_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_IDS = {
+    "fa3": ("kernels-community/flash-attn3", {"revision": "fake-ops-return-probs"}),
+    "fa": ("kernels-community/flash-attn", {}),
+}
+
+
+def _get_from_hub(key: str):
+    if not is_kernels_available():
+        return None
+
+    hub_id, kwargs = _DEFAULT_HUB_IDS[key]
+    try:
+        return get_kernel(hub_id, **kwargs)
+    except Exception as e:
+        logger.error(f"An error occurred while fetching kernel '{hub_id}' from the Hub: {e}")
+        raise


 def _get_fa3_from_hub():
-    if not is_kernels_available():
-        return None
-    else:
-        from kernels import get_kernel
-
-        try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
-        except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
-            raise
+    return _get_from_hub("fa3")
+
+
+def _get_fa_from_hub():
+    return _get_from_hub("fa")
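To show how the consolidated helper is meant to be consumed, here is a hedged sketch. It assumes `pip install kernels`, network access to the Hub, and the repo ids pinned in `_DEFAULT_HUB_IDS`; nothing here beyond those assumptions is added to the diff itself.

```python
# Hedged sketch: resolving both FlashAttention kernels through the new helpers.
# _get_from_hub returns None when the `kernels` package is unavailable and
# logs then re-raises any fetch error.
from diffusers.utils.kernels_utils import _get_fa_from_hub, _get_fa3_from_hub

fa_interface = _get_fa_from_hub()    # kernels-community/flash-attn
fa3_interface = _get_fa3_from_hub()  # kernels-community/flash-attn3 at the pinned revision

if fa_interface is not None and fa3_interface is not None:
    # Both kernel modules expose `flash_attn_func`, which the dispatcher wraps.
    print(fa_interface.flash_attn_func)
    print(fa3_interface.flash_attn_func)
```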