mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-06 12:34:13 +08:00
[core] Refactor hub attn kernels (#12475)
* refactor how attention kernels from hub are used. * up * refactor according to Dhruv's ideas. Co-authored-by: Dhruv Nair <dhruv@huggingface.co> * empty Co-authored-by: Dhruv Nair <dhruv@huggingface.co> * empty Co-authored-by: Dhruv Nair <dhruv@huggingface.co> * empty Co-authored-by: dn6 <dhruv@huggingface.co> * up --------- Co-authored-by: Dhruv Nair <dhruv@huggingface.co> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ import contextlib
|
||||
import functools
|
||||
import inspect
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
||||
@@ -42,7 +43,7 @@ from ..utils import (
|
||||
is_xformers_available,
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS, DIFFUSERS_ENABLE_HUB_KERNELS
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -82,24 +83,11 @@ else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
if not is_kernels_available():
|
||||
raise ImportError(
|
||||
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
|
||||
)
|
||||
from ..utils.kernels_utils import _get_fa3_from_hub
|
||||
|
||||
flash_attn_interface_hub = _get_fa3_from_hub()
|
||||
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
|
||||
else:
|
||||
flash_attn_3_func_hub = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
@@ -261,6 +249,25 @@ class _AttentionBackendRegistry:
|
||||
return supports_context_parallel
|
||||
|
||||
|
||||
@dataclass
|
||||
class _HubKernelConfig:
|
||||
"""Configuration for downloading and using a hub-based attention kernel."""
|
||||
|
||||
repo_id: str
|
||||
function_attr: str
|
||||
revision: Optional[str] = None
|
||||
kernel_fn: Optional[Callable] = None
|
||||
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
_HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBackendName.NATIVE):
|
||||
"""
|
||||
@@ -415,13 +422,9 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
|
||||
# TODO: add support Hub variant of FA3 varlen later
|
||||
elif backend in [AttentionBackendName._FLASH_3_HUB]:
|
||||
if not DIFFUSERS_ENABLE_HUB_KERNELS:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
|
||||
)
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 Hub backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
@@ -571,6 +574,29 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
return q_idx >= kv_idx
|
||||
|
||||
|
||||
# ===== Helpers for downloading kernels =====
|
||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
if backend not in _HUB_KERNELS_REGISTRY:
|
||||
return
|
||||
config = _HUB_KERNELS_REGISTRY[backend]
|
||||
|
||||
if config.kernel_fn is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||
kernel_func = getattr(kernel_module, config.function_attr)
|
||||
|
||||
# Cache the downloaded kernel function in the config object
|
||||
config.kernel_fn = kernel_func
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||
raise
|
||||
|
||||
|
||||
# ===== torch op registrations =====
|
||||
# Registrations are required for fullgraph tracing compatibility
|
||||
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
|
||||
@@ -1418,7 +1444,8 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
out = flash_attn_3_func_hub(
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
|
||||
@@ -595,7 +595,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
attention as backend.
|
||||
"""
|
||||
from .attention import AttentionModuleMixin
|
||||
from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
|
||||
from .attention_dispatch import (
|
||||
AttentionBackendName,
|
||||
_check_attention_backend_requirements,
|
||||
_maybe_download_kernel_for_backend,
|
||||
)
|
||||
|
||||
# TODO: the following will not be required when everything is refactored to AttentionModuleMixin
|
||||
from .attention_processor import Attention, MochiAttention
|
||||
@@ -606,8 +610,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
available_backends = {x.value for x in AttentionBackendName.__members__.values()}
|
||||
if backend not in available_backends:
|
||||
raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
|
||||
|
||||
backend = AttentionBackendName(backend)
|
||||
_check_attention_backend_requirements(backend)
|
||||
_maybe_download_kernel_for_backend(backend)
|
||||
|
||||
attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
|
||||
for module in self.modules():
|
||||
|
||||
@@ -46,7 +46,6 @@ DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0").upper() in ENV_V
|
||||
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
|
||||
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").upper() in ENV_VARS_TRUE_VALUES
|
||||
DIFFUSERS_ENABLE_HUB_KERNELS = os.environ.get("DIFFUSERS_ENABLE_HUB_KERNELS", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
from ..utils import get_logger
|
||||
from .import_utils import is_kernels_available
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
|
||||
|
||||
|
||||
def _get_fa3_from_hub():
|
||||
if not is_kernels_available():
|
||||
return None
|
||||
else:
|
||||
from kernels import get_kernel
|
||||
|
||||
try:
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
|
||||
return flash_attn_3_hub
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
|
||||
raise
|
||||
@@ -7,7 +7,6 @@ To run this test suite:
|
||||
|
||||
```bash
|
||||
export RUN_ATTENTION_BACKEND_TESTS=yes
|
||||
export DIFFUSERS_ENABLE_HUB_KERNELS=yes
|
||||
|
||||
pytest tests/others/test_attention_backends.py
|
||||
```
|
||||
|
||||
Reference in New Issue
Block a user