Compare commits

...

6 Commits

Author SHA1 Message Date
Sayak Paul
23c173ea58 Merge branch 'main' into sage-kernels 2025-10-13 10:47:20 +05:30
Sayak Paul
3688c9d443 Merge branch 'main' into sage-kernels 2025-10-08 09:35:09 +05:30
sayakpaul
d3441340b9 support automatic dispatch. 2025-10-07 18:40:12 +05:30
Sayak Paul
18c3e8ee0c Merge branch 'main' into sage-kernels 2025-10-07 14:59:01 +05:30
Sayak Paul
f630dab8a2 Merge branch 'main' into sage-kernels 2025-10-06 19:15:00 +05:30
sayakpaul
e9ea1c5b2c up 2025-10-06 10:47:12 +05:30
3 changed files with 187 additions and 12 deletions

View File

@@ -17,7 +17,8 @@ import functools
import inspect
import math
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -83,12 +84,20 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
raise ImportError(
"To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
)
from ..utils.kernels_utils import _get_fa3_from_hub
from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
from ..utils.sage_utils import _get_sage_attn_fn_for_device
flash_attn_interface_hub = _get_fa3_from_hub()
flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])
else:
flash_attn_3_func_hub = None
sage_attn_func_hub = None
if _CAN_USE_SAGE_ATTN:
from sageattention import (
@@ -162,10 +171,6 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
# - CP with sage attention, flex, xformers, other missing backends
# - Add support for normal and CP training with backends that don't support it yet
_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
class AttentionBackendName(str, Enum):
# EAGER = "eager"
@@ -190,6 +195,7 @@ class AttentionBackendName(str, Enum):
# `sageattention`
SAGE = "sage"
SAGE_HUB = "sage_hub"
SAGE_VARLEN = "sage_varlen"
_SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
_SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
@@ -1756,6 +1762,31 @@ def _sage_attention(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _sage_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if _parallel_config is None:
out = sage_attn_func_hub(q=query, k=key, v=value)
if return_lse:
out, lse, *_ = out
else:
raise NotImplementedError("SAGE attention doesn't yet support parallelism.")
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_VARLEN,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -6,18 +6,25 @@ logger = get_logger(__name__)
_DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
_KERNEL_REVISION = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
_DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
_DEFAULT_HUB_ID_SAGE: "compile",
}
def _get_fa3_from_hub():
def _get_kernel_from_hub(kernel_id):
if not is_kernels_available():
return None
else:
from kernels import get_kernel
try:
# TODO: temporary revision for now. Remove when merged upstream into `main`.
flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
return flash_attn_3_hub
if kernel_id not in _KERNEL_REVISION:
raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
return kernel_hub
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
raise

View File

@@ -0,0 +1,137 @@
"""
Copyright (c) 2024 by SageAttention, The HuggingFace team.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
"""
"""
Modified from
https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
"""
import torch # noqa
SAGE_ATTENTION_DISPATCH = {
"sm80": {
"func": "sageattn_qk_int8_pv_fp16_cuda",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32",
},
},
"sm89": {
"func": "sageattn_qk_int8_pv_fp8_cuda",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32+fp16",
},
},
"sm90": {
"func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32+fp32",
},
},
"sm120": {
"func": "sageattn_qk_int8_pv_fp8_cuda",
"kwargs": {
"tensor_layout": "NHD",
"is_causal": False,
"qk_quant_gran": "per_warp",
"sm_scale": None,
"return_lse": False,
"pv_accum_dtype": "fp32+fp16",
},
},
}
def get_cuda_version():
if torch.cuda.is_available():
major, minor = torch.cuda.get_device_capability()
return major, minor
else:
raise EnvironmentError("CUDA not found.")
def get_cuda_arch_versions():
if not torch.cuda.is_available():
EnvironmentError("CUDA not found.")
cuda_archs = []
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
cuda_archs.append(f"sm{major}{minor}")
return cuda_archs
# Unlike the actual implementation, we just maintain function names rather than actual
# implementations.
def _get_sage_attn_fn_for_device():
"""
Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
capability.
Parameters ---------- q : torch.Tensor
The query tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.
tensor_layout : str
The tensor layout, either "HND" or "NHD". Default: "HND".
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False.
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
return_lse : bool
Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
Default: False.
Returns ------- torch.Tensor
The output tensor. Shape:
- If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
- If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.
torch.Tensor
The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.
Note ----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``
- All tensors must be on the same cuda device.
"""
device_index = torch.cuda.current_device()
arch = get_cuda_arch_versions()[device_index]
return SAGE_ATTENTION_DISPATCH[arch]