Mirror of https://github.com/huggingface/diffusers.git

Compare commits: cache-docs ... sage-kerne

6 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 23c173ea58 | |
| | 3688c9d443 | |
| | d3441340b9 | |
| | 18c3e8ee0c | |
| | f630dab8a2 | |
| | e9ea1c5b2c | |
```diff
@@ -17,7 +17,8 @@ import functools
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
 
```
```diff
@@ -83,12 +84,20 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
+    from ..utils.sage_utils import _get_sage_attn_fn_for_device
 
-    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func
+
+    sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
+    sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
+    sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
+    sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])
+
 else:
     flash_attn_3_func_hub = None
+    sage_attn_func_hub = None
 
 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
```
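The wiring above resolves the Sage entry point from the Hub kernel module by name and then pre-binds its architecture-specific defaults. Below is a minimal, self-contained sketch of that `getattr` + `functools.partial` pattern, using a stand-in namespace instead of the real Hub kernel (which needs a CUDA GPU and the `kernels` package); the function name and kwargs mirror the sm90 entry of the dispatch table introduced later in this diff:

```python
from functools import partial
from types import SimpleNamespace

# Stand-in for the module object returned by the Hub; the real one exposes the
# SageAttention entry points such as `sageattn_qk_int8_pv_fp8_cuda_sm90`.
fake_kernel_module = SimpleNamespace(
    sageattn_qk_int8_pv_fp8_cuda_sm90=lambda q, k, v, **kwargs: (q, kwargs)
)

# Stand-in for what `_get_sage_attn_fn_for_device()` would return on sm90 hardware.
dispatch_entry = {
    "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
    "kwargs": {"tensor_layout": "NHD", "is_causal": False, "pv_accum_dtype": "fp32+fp32"},
}

# Same resolution pattern as the diff: look the function up by name, then
# pre-bind the architecture-specific keyword arguments once at import time.
sage_fn = getattr(fake_kernel_module, dispatch_entry["func"])
sage_fn = partial(sage_fn, **dispatch_entry["kwargs"])

_, bound_kwargs = sage_fn("q", "k", "v")
print(bound_kwargs)  # {'tensor_layout': 'NHD', 'is_causal': False, 'pv_accum_dtype': 'fp32+fp32'}
```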
```diff
@@ -162,10 +171,6 @@ logger = get_logger(__name__) # pylint: disable=invalid-name
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet
 
-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-
 
 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
```
```diff
@@ -190,6 +195,7 @@ class AttentionBackendName(str, Enum):
 
     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
    SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
```
```diff
@@ -1756,6 +1762,31 @@ def _sage_attention(
     return (out, lse) if return_lse else out
 
 
+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    if _parallel_config is None:
+        out = sage_attn_func_hub(q=query, k=key, v=value)
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        raise NotImplementedError("SAGE attention doesn't yet support parallelism.")
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
```
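For context, a hypothetical end-to-end usage sketch of the `SAGE_HUB` backend registered above, not taken from this diff. Assumptions: `DIFFUSERS_ENABLE_HUB_KERNELS` is honored when set to a truthy string before `diffusers` is imported, the `set_attention_backend` helper on the model accepts the new `"sage_hub"` name, and the machine has a supported CUDA GPU, the `kernels` package, and Hub access:

```python
import os

# Assumption: a truthy value enables the Hub-kernel code path shown above.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Assumption: the helper accepts the backend name registered in this diff.
pipe.transformer.set_attention_backend("sage_hub")

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
```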
```diff
@@ -6,18 +6,25 @@ logger = get_logger(__name__)
 
 
 _DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
+_KERNEL_REVISION = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
+    _DEFAULT_HUB_ID_SAGE: "compile",
+}
 
 
-def _get_fa3_from_hub():
+def _get_kernel_from_hub(kernel_id):
     if not is_kernels_available():
         return None
     else:
         from kernels import get_kernel
 
         try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
+            if kernel_id not in _KERNEL_REVISION:
+                raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
+            kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
+            return kernel_hub
         except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
             raise
```
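A minimal sketch of what `_get_kernel_from_hub` does under the hood, assuming the `kernels` package is installed and the pinned repo/revision pair from `_KERNEL_REVISION` is still available on the Hub:

```python
from kernels import get_kernel

# Fetch the Sage attention kernel at the revision pinned in _KERNEL_REVISION.
sage_kernel = get_kernel("kernels-community/sage_attention", revision="compile")

# The returned module exposes the compiled SageAttention entry points.
print([name for name in dir(sage_kernel) if name.startswith("sageattn")])
```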
src/diffusers/utils/sage_utils.py (new file, 137 lines)

@@ -0,0 +1,137 @@
```python
"""
Copyright (c) 2024 by SageAttention, The HuggingFace team.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
License. You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
"""

"""
Modified from
https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
"""


import torch  # noqa


SAGE_ATTENTION_DISPATCH = {
    "sm80": {
        "func": "sageattn_qk_int8_pv_fp16_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32",
        },
    },
    "sm89": {
        "func": "sageattn_qk_int8_pv_fp8_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp16",
        },
    },
    "sm90": {
        "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp32",
        },
    },
    "sm120": {
        "func": "sageattn_qk_int8_pv_fp8_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "qk_quant_gran": "per_warp",
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp16",
        },
    },
}


def get_cuda_version():
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major, minor
    else:
        raise EnvironmentError("CUDA not found.")


def get_cuda_arch_versions():
    if not torch.cuda.is_available():
        raise EnvironmentError("CUDA not found.")
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cuda_archs.append(f"sm{major}{minor}")
    return cuda_archs


# Unlike the actual implementation, we just maintain function names rather than actual
# implementations.
def _get_sage_attn_fn_for_device():
    """
    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
    capability.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".

    is_causal : bool
        Whether to apply a causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False.

    sm_scale : Optional[float]
        The scale used in softmax; if not provided, it will be set to ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log-sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
        ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same CUDA device.
    """
    device_index = torch.cuda.current_device()
    arch = get_cuda_arch_versions()[device_index]
    return SAGE_ATTENTION_DISPATCH[arch]
```
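A short illustration of how the dispatch table resolves to a concrete kernel entry, assuming this branch is installed and the current device is one of the listed architectures (sm80, sm89, sm90, sm120); any other compute capability raises a `KeyError` because it has no entry in `SAGE_ATTENTION_DISPATCH`:

```python
import torch

from diffusers.utils.sage_utils import SAGE_ATTENTION_DISPATCH, get_cuda_arch_versions

# e.g. "sm90" on H100, "sm80" on A100.
arch = get_cuda_arch_versions()[torch.cuda.current_device()]
entry = SAGE_ATTENTION_DISPATCH[arch]

print(entry["func"], entry["kwargs"]["pv_accum_dtype"])
```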