mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-12 07:24:32 +08:00)

Compare commits: cache-docs ... sage-kerne (6 commits)

| SHA1 |
|---|
| 23c173ea58 |
| 3688c9d443 |
| d3441340b9 |
| 18c3e8ee0c |
| f630dab8a2 |
| e9ea1c5b2c |
@@ -17,7 +17,8 @@ import functools
 import inspect
 import math
 from enum import Enum
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
+from functools import partial
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

 import torch

@@ -83,12 +84,20 @@ if DIFFUSERS_ENABLE_HUB_KERNELS:
         raise ImportError(
             "To use FA3 kernel for your hardware from the Hub, the `kernels` library must be installed. Install with `pip install kernels`."
         )
-    from ..utils.kernels_utils import _get_fa3_from_hub
+    from ..utils.kernels_utils import _DEFAULT_HUB_ID_FA3, _DEFAULT_HUB_ID_SAGE, _get_kernel_from_hub
+    from ..utils.sage_utils import _get_sage_attn_fn_for_device

-    flash_attn_interface_hub = _get_fa3_from_hub()
+    flash_attn_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_FA3)
     flash_attn_3_func_hub = flash_attn_interface_hub.flash_attn_func

+    sage_interface_hub = _get_kernel_from_hub(_DEFAULT_HUB_ID_SAGE)
+    sage_fn_with_kwargs = _get_sage_attn_fn_for_device()
+    sage_attn_func_hub = getattr(sage_interface_hub, sage_fn_with_kwargs["func"])
+    sage_attn_func_hub = partial(sage_attn_func_hub, **sage_fn_with_kwargs["kwargs"])
+
 else:
     flash_attn_3_func_hub = None
+    sage_attn_func_hub = None

 if _CAN_USE_SAGE_ATTN:
     from sageattention import (
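For readers following the wiring above: the Hub kernel module exposes several `sageattn_*` functions, and the dispatch entry returned by `_get_sage_attn_fn_for_device()` carries a function name plus its default keyword arguments; `getattr` and `functools.partial` turn that into a ready-to-call attention function. Below is a minimal, self-contained sketch of that pattern. The `_fake_sage_kernel` and `fake_hub_module` names are hypothetical stand-ins for the real compiled Hub kernel; the `scale` argument to SDPA assumes PyTorch >= 2.1.

```python
# Minimal sketch of the getattr + partial pattern used in the diff above,
# with a stand-in object in place of the real Hub kernel module.
from functools import partial
from types import SimpleNamespace

import torch


def _fake_sage_kernel(q, k, v, tensor_layout="NHD", is_causal=False, sm_scale=None, return_lse=False, pv_accum_dtype="fp32"):
    # Stand-in for a compiled SageAttention kernel: assumes NHD input and just
    # runs SDPA; the extra kwargs are accepted and ignored.
    q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # NHD -> [batch, heads, seq, dim] for SDPA
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=is_causal, scale=sm_scale)
    return out.transpose(1, 2)  # back to NHD


fake_hub_module = SimpleNamespace(sageattn_qk_int8_pv_fp8_cuda_sm90=_fake_sage_kernel)
dispatch_entry = {
    "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
    "kwargs": {"tensor_layout": "NHD", "is_causal": False, "sm_scale": None, "return_lse": False, "pv_accum_dtype": "fp32+fp32"},
}

# Same two steps as in the diff: resolve the function by name, then bind its defaults.
fn = getattr(fake_hub_module, dispatch_entry["func"])
sage_attn_func = partial(fn, **dispatch_entry["kwargs"])

q = k = v = torch.randn(1, 16, 8, 64)  # NHD: [batch, seq_len, heads, head_dim]
print(sage_attn_func(q=q, k=k, v=v).shape)  # torch.Size([1, 16, 8, 64])
```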
@@ -162,10 +171,6 @@ logger = get_logger(__name__)  # pylint: disable=invalid-name
 # - CP with sage attention, flex, xformers, other missing backends
 # - Add support for normal and CP training with backends that don't support it yet

-_SAGE_ATTENTION_PV_ACCUM_DTYPE = Literal["fp32", "fp32+fp32"]
-_SAGE_ATTENTION_QK_QUANT_GRAN = Literal["per_thread", "per_warp"]
-_SAGE_ATTENTION_QUANTIZATION_BACKEND = Literal["cuda", "triton"]
-

 class AttentionBackendName(str, Enum):
     # EAGER = "eager"
@@ -190,6 +195,7 @@ class AttentionBackendName(str, Enum):

     # `sageattention`
     SAGE = "sage"
+    SAGE_HUB = "sage_hub"
     SAGE_VARLEN = "sage_varlen"
     _SAGE_QK_INT8_PV_FP8_CUDA = "_sage_qk_int8_pv_fp8_cuda"
     _SAGE_QK_INT8_PV_FP8_CUDA_SM90 = "_sage_qk_int8_pv_fp8_cuda_sm90"
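With the enum entry in place, the new backend can be selected like any other attention backend. The sketch below is hedged: it assumes the `DIFFUSERS_ENABLE_HUB_KERNELS` environment flag gates the Hub-kernel branch (as in the hunk above), that this branch's diffusers exposes `ModelMixin.set_attention_backend()` as in recent releases, and that "yes" is an accepted truthy value; the model id, prompt, and step count are placeholders.

```python
# Hedged usage sketch: opting into the "sage_hub" attention backend.
import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"  # set before importing diffusers

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Route the transformer's attention through the Hub-fetched SageAttention kernel.
pipe.transformer.set_attention_backend("sage_hub")
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
```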
@@ -1756,6 +1762,31 @@ def _sage_attention(
     return (out, lse) if return_lse else out


+@_AttentionBackendRegistry.register(
+    AttentionBackendName.SAGE_HUB,
+    constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
+    supports_context_parallel=False,
+)
+def _sage_attention_hub(
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    return_lse: bool = False,
+    _parallel_config: Optional["ParallelConfig"] = None,
+) -> torch.Tensor:
+    lse = None
+    if _parallel_config is None:
+        out = sage_attn_func_hub(q=query, k=key, v=value)
+        if return_lse:
+            out, lse, *_ = out
+    else:
+        raise NotImplementedError("SAGE attention doesn't yet support parallelism.")
+
+    return (out, lse) if return_lse else out
+
+
 @_AttentionBackendRegistry.register(
     AttentionBackendName.SAGE_VARLEN,
     constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
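The backend above calls the partial-bound Hub function with keyword `q`/`k`/`v` tensors in the NHD layout and, when `return_lse` is requested, unpacks a tuple with `out, lse, *_ = out`. A small sketch of that return convention, using a hypothetical `_dummy_attn` in place of the Hub kernel so it runs on CPU:

```python
# Sketch of the return_lse convention handled by _sage_attention_hub.
import torch


def _dummy_attn(q, k, v, return_lse=False):
    out = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    ).transpose(1, 2)
    if not return_lse:
        return out
    lse = torch.zeros(q.shape[0], q.shape[2], q.shape[1])  # [batch, heads, seq_len] placeholder
    return out, lse


q = k = v = torch.randn(2, 128, 8, 64)  # NHD: [batch, seq_len, heads, head_dim]

out = _dummy_attn(q, k, v)                      # return_lse=False -> a single tensor
result = _dummy_attn(q, k, v, return_lse=True)  # return_lse=True  -> a tuple
out, lse, *_ = result                           # same unpacking as in the backend above
print(out.shape, lse.shape)
```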
@@ -6,18 +6,25 @@ logger = get_logger(__name__)


 _DEFAULT_HUB_ID_FA3 = "kernels-community/flash-attn3"
+_DEFAULT_HUB_ID_SAGE = "kernels-community/sage_attention"
+_KERNEL_REVISION = {
+    # TODO: temporary revision for now. Remove when merged upstream into `main`.
+    _DEFAULT_HUB_ID_FA3: "fake-ops-return-probs",
+    _DEFAULT_HUB_ID_SAGE: "compile",
+}


-def _get_fa3_from_hub():
+def _get_kernel_from_hub(kernel_id):
     if not is_kernels_available():
         return None
     else:
         from kernels import get_kernel

         try:
-            # TODO: temporary revision for now. Remove when merged upstream into `main`.
-            flash_attn_3_hub = get_kernel(_DEFAULT_HUB_ID_FA3, revision="fake-ops-return-probs")
-            return flash_attn_3_hub
+            if kernel_id not in _KERNEL_REVISION:
+                raise NotImplementedError(f"{kernel_id} is not implemented in Diffusers.")
+            kernel_hub = get_kernel(kernel_id, revision=_KERNEL_REVISION.get(kernel_id))
+            return kernel_hub
         except Exception as e:
-            logger.error(f"An error occurred while fetching kernel '{_DEFAULT_HUB_ID_FA3}' from the Hub: {e}")
+            logger.error(f"An error occurred while fetching kernel '{kernel_id}' from the Hub: {e}")
             raise
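The helper above is a thin wrapper around `kernels.get_kernel` that pins each repo to the revision listed in `_KERNEL_REVISION` (temporary pins per the TODO). A hedged sketch of fetching the same kernels directly with the `kernels` library; it needs `pip install kernels`, network access, and a supported CUDA GPU:

```python
# Direct use of the `kernels` library, mirroring _get_kernel_from_hub.
from kernels import get_kernel

fa3 = get_kernel("kernels-community/flash-attn3", revision="fake-ops-return-probs")
sage = get_kernel("kernels-community/sage_attention", revision="compile")

# The returned objects expose the compiled ops as attributes, e.g. the
# flash_attn_func looked up for flash_attn_3_func_hub in the first hunk.
print(hasattr(fa3, "flash_attn_func"))
```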
src/diffusers/utils/sage_utils.py (new file, 137 lines)
@@ -0,0 +1,137 @@
"""
Copyright (c) 2024 by SageAttention, The HuggingFace team.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the
License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
"""

"""
Modified from
https://github.com/thu-ml/SageAttention/blob/68de3797d163b89d28f9a38026c3b7313f6940d2/sageattention/core.py
"""


import torch  # noqa


SAGE_ATTENTION_DISPATCH = {
    "sm80": {
        "func": "sageattn_qk_int8_pv_fp16_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32",
        },
    },
    "sm89": {
        "func": "sageattn_qk_int8_pv_fp8_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp16",
        },
    },
    "sm90": {
        "func": "sageattn_qk_int8_pv_fp8_cuda_sm90",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp32",
        },
    },
    "sm120": {
        "func": "sageattn_qk_int8_pv_fp8_cuda",
        "kwargs": {
            "tensor_layout": "NHD",
            "is_causal": False,
            "qk_quant_gran": "per_warp",
            "sm_scale": None,
            "return_lse": False,
            "pv_accum_dtype": "fp32+fp16",
        },
    },
}


def get_cuda_version():
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        return major, minor
    else:
        raise EnvironmentError("CUDA not found.")


def get_cuda_arch_versions():
    if not torch.cuda.is_available():
        raise EnvironmentError("CUDA not found.")
    cuda_archs = []
    for i in range(torch.cuda.device_count()):
        major, minor = torch.cuda.get_device_capability(i)
        cuda_archs.append(f"sm{major}{minor}")
    return cuda_archs


# Unlike the actual implementation, we just maintain function names rather than actual
# implementations.
def _get_sage_attn_fn_for_device():
    """
    Automatically selects the appropriate implementation of the SageAttention kernel based on the GPU compute
    capability.

    Parameters
    ----------
    q : torch.Tensor
        The query tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    k : torch.Tensor
        The key tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    v : torch.Tensor
        The value tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``.

    tensor_layout : str
        The tensor layout, either "HND" or "NHD". Default: "HND".

    is_causal : bool
        Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. Default: False.

    sm_scale : Optional[float]
        The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.

    return_lse : bool
        Whether to return the log sum of the exponentiated attention weights. Used for cases like Ring Attention.
        Default: False.

    Returns
    -------
    torch.Tensor
        The output tensor. Shape:
        - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``.
        - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``.

    torch.Tensor
        The logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax normalization factor). Shape:
        ``[batch_size, num_qo_heads, qo_len]``. Only returned if `return_lse` is True.

    Note
    ----
    - ``num_qo_heads`` must be divisible by ``num_kv_heads``.
    - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16``.
    - All tensors must be on the same cuda device.
    """
    device_index = torch.cuda.current_device()
    arch = get_cuda_arch_versions()[device_index]
    return SAGE_ATTENTION_DISPATCH[arch]
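The dispatch in this new file hinges on mapping a device's compute capability to an "smXY" key of `SAGE_ATTENTION_DISPATCH`. A hedged sketch of that lookup; the fake capability tuple stands in for `torch.cuda.get_device_capability()` so it runs without a GPU, and the import assumes this branch of diffusers is installed so `diffusers.utils.sage_utils` exists:

```python
# How a compute capability becomes a dispatch entry.
from diffusers.utils.sage_utils import SAGE_ATTENTION_DISPATCH

fake_capability = (9, 0)  # e.g. an H100 reports compute capability (9, 0)
arch = f"sm{fake_capability[0]}{fake_capability[1]}"  # -> "sm90", as in get_cuda_arch_versions()

entry = SAGE_ATTENTION_DISPATCH[arch]
print(entry["func"])                      # sageattn_qk_int8_pv_fp8_cuda_sm90
print(entry["kwargs"]["pv_accum_dtype"])  # fp32+fp32
```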