mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-21 02:10:42 +08:00
Compare commits
6 Commits
klein-peft
...
update-ker
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b7ed4c8dc | ||
|
|
67f4691cab | ||
|
|
e10fe61303 | ||
|
|
348350cf24 | ||
|
|
af35e3806c | ||
|
|
d6bc647932 |
@@ -38,6 +38,7 @@ from ..utils import (
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_version,
|
||||
is_kernels_available,
|
||||
is_kernels_version,
|
||||
is_sageattention_available,
|
||||
is_sageattention_version,
|
||||
is_torch_npu_available,
|
||||
@@ -265,6 +266,7 @@ class _HubKernelConfig:
|
||||
repo_id: str
|
||||
function_attr: str
|
||||
revision: str | None = None
|
||||
version: int | None = None
|
||||
kernel_fn: Callable | None = None
|
||||
wrapped_forward_attr: str | None = None
|
||||
wrapped_backward_attr: str | None = None
|
||||
@@ -274,27 +276,31 @@ class _HubKernelConfig:
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", version=1
|
||||
),
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3",
|
||||
function_attr="flash_attn_varlen_func",
|
||||
# revision="fake-ops-return-probs",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2",
|
||||
function_attr="flash_attn_func",
|
||||
version=1,
|
||||
revision=None,
|
||||
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
|
||||
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
|
||||
),
|
||||
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
||||
repo_id="kernels-community/flash-attn2",
|
||||
function_attr="flash_attn_varlen_func",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.SAGE_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/sage_attention", function_attr="sageattn", revision=None
|
||||
repo_id="kernels-community/sage-attention",
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -464,6 +470,10 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' is not usable because the `kernels` package isn't available. Please install it with `pip install kernels`."
|
||||
)
|
||||
if not is_kernels_version(">=", "0.12"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
|
||||
@@ -86,6 +86,7 @@ from .import_utils import (
|
||||
is_inflect_available,
|
||||
is_invisible_watermark_available,
|
||||
is_kernels_available,
|
||||
is_kernels_version,
|
||||
is_kornia_available,
|
||||
is_librosa_available,
|
||||
is_matplotlib_available,
|
||||
|
||||
@@ -724,6 +724,22 @@ def is_transformers_version(operation: str, version: str):
|
||||
return compare_versions(parse(_transformers_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_kernels_version(operation: str, version: str):
|
||||
"""
|
||||
Compares the current Kernels version to a given reference with an operation.
|
||||
|
||||
Args:
|
||||
operation (`str`):
|
||||
A string representation of an operator, such as `">"` or `"<="`
|
||||
version (`str`):
|
||||
A version string
|
||||
"""
|
||||
if not _kernels_available:
|
||||
return False
|
||||
return compare_versions(parse(_kernels_version), operation, version)
|
||||
|
||||
|
||||
@cache
|
||||
def is_hf_hub_version(operation: str, version: str):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user