mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-22 18:56:30 +08:00
Compare commits
2 Commits
transforme
...
remove-non
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5e87c38b29 | ||
|
|
a80b19218b |
@@ -5472,6 +5472,10 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
|
||||
if is_peft_format:
|
||||
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
|
||||
|
||||
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if is_ai_toolkit:
|
||||
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
|
||||
|
||||
@@ -34,12 +34,7 @@ from ..utils import (
|
||||
get_logger,
|
||||
is_aiter_available,
|
||||
is_aiter_version,
|
||||
is_flash_attn_3_available,
|
||||
is_flash_attn_available,
|
||||
is_flash_attn_version,
|
||||
is_kernels_available,
|
||||
is_sageattention_available,
|
||||
is_sageattention_version,
|
||||
is_torch_npu_available,
|
||||
is_torch_version,
|
||||
is_torch_xla_available,
|
||||
@@ -55,62 +50,23 @@ from ._modeling_parallel import gather_size_by_comm
|
||||
if TYPE_CHECKING:
|
||||
from ._modeling_parallel import ParallelConfig
|
||||
|
||||
_REQUIRED_FLASH_VERSION = "2.6.3"
|
||||
_REQUIRED_AITER_VERSION = "0.1.5"
|
||||
_REQUIRED_SAGE_VERSION = "2.1.1"
|
||||
_REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
_CAN_USE_SAGE_ATTN = is_sageattention_available() and is_sageattention_version(">=", _REQUIRED_SAGE_VERSION)
|
||||
_CAN_USE_FLEX_ATTN = is_torch_version(">=", _REQUIRED_FLEX_VERSION)
|
||||
_CAN_USE_NPU_ATTN = is_torch_npu_available()
|
||||
_CAN_USE_XLA_ATTN = is_torch_xla_available() and is_torch_xla_version(">=", _REQUIRED_XLA_VERSION)
|
||||
_CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _REQUIRED_XFORMERS_VERSION)
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
else:
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
_wrapped_flash_attn_backward = None
|
||||
_wrapped_flash_attn_forward = None
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN_3:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
else:
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
sageattn_qk_int8_pv_fp16_triton = None
|
||||
sageattn_qk_int8_pv_fp8_cuda = None
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||
sageattn_varlen = None
|
||||
|
||||
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
@@ -136,27 +92,6 @@ if _CAN_USE_XFORMERS_ATTN:
|
||||
else:
|
||||
xops = None
|
||||
|
||||
# Version guard for PyTorch compatibility - custom_op was added in PyTorch 2.4
|
||||
if torch.__version__ >= "2.4.0":
|
||||
_custom_op = torch.library.custom_op
|
||||
_register_fake = torch.library.register_fake
|
||||
else:
|
||||
|
||||
def custom_op_no_op(name, fn=None, /, *, mutates_args, device_types=None, schema=None):
|
||||
def wrap(func):
|
||||
return func
|
||||
|
||||
return wrap if fn is None else fn
|
||||
|
||||
def register_fake_no_op(op, fn=None, /, *, lib=None, _stacklevel=1):
|
||||
def wrap(func):
|
||||
return func
|
||||
|
||||
return wrap if fn is None else fn
|
||||
|
||||
_custom_op = custom_op_no_op
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -304,11 +239,11 @@ def attention_backend(backend: str | AttentionBackendName = AttentionBackendName
|
||||
"""
|
||||
Context manager to set the active attention backend.
|
||||
"""
|
||||
if backend not in _AttentionBackendRegistry._backends:
|
||||
raise ValueError(f"Backend {backend} is not registered.")
|
||||
|
||||
backend = AttentionBackendName(backend)
|
||||
_check_attention_backend_requirements(backend)
|
||||
|
||||
if backend not in _AttentionBackendRegistry._backends:
|
||||
raise ValueError(f"Backend {backend} is not registered.")
|
||||
_maybe_download_kernel_for_backend(backend)
|
||||
|
||||
old_backend = _AttentionBackendRegistry._active_backend
|
||||
@@ -442,16 +377,32 @@ def _check_shape(
|
||||
|
||||
def _check_attention_backend_requirements(backend: AttentionBackendName) -> None:
|
||||
if backend in [AttentionBackendName.FLASH, AttentionBackendName.FLASH_VARLEN]:
|
||||
if not _CAN_USE_FLASH_ATTN:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `flash-attn>={_REQUIRED_FLASH_VERSION}`."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"The '{backend.value}' attention backend has been removed. "
|
||||
f"Please use 'flash_hub' or 'flash_varlen_hub' instead, which load the flash-attn kernel from the Hub. "
|
||||
f"Install the required package with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend in [AttentionBackendName._FLASH_3, AttentionBackendName._FLASH_VARLEN_3]:
|
||||
if not _CAN_USE_FLASH_ATTN_3:
|
||||
raise RuntimeError(
|
||||
f"Flash Attention 3 backend '{backend.value}' is not usable because of missing package or the version is too old. Please build FA3 beta release from source."
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"The '{backend.value}' attention backend has been removed. "
|
||||
f"Please use '_flash_3_hub' or '_flash_3_varlen_hub' instead, which load the flash-attn-3 kernel from the Hub. "
|
||||
f"Install the required package with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
AttentionBackendName.SAGE,
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
|
||||
]:
|
||||
raise RuntimeError(
|
||||
f"The '{backend.value}' attention backend has been removed. "
|
||||
f"Please use 'sage_hub' instead, which loads the SageAttention kernel from the Hub. "
|
||||
f"Install the required package with `pip install kernels`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
@@ -471,19 +422,6 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Aiter Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `aiter>={_REQUIRED_AITER_VERSION}`."
|
||||
)
|
||||
|
||||
elif backend in [
|
||||
AttentionBackendName.SAGE,
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
|
||||
]:
|
||||
if not _CAN_USE_SAGE_ATTN:
|
||||
raise RuntimeError(
|
||||
f"Sage Attention backend '{backend.value}' is not usable because of missing package or the version is too old. Please install `sageattention>={_REQUIRED_SAGE_VERSION}`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.FLEX:
|
||||
if not _CAN_USE_FLEX_ATTN:
|
||||
raise RuntimeError(
|
||||
@@ -652,78 +590,6 @@ def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
raise
|
||||
|
||||
|
||||
# ===== torch op registrations =====
|
||||
# Registrations are required for fullgraph tracing compatibility
|
||||
# TODO: this is only required because the beta release FA3 does not have it. There is a PR adding
|
||||
# this but it was never merged: https://github.com/Dao-AILab/flash-attention/pull/1590
|
||||
@_custom_op("_diffusers_flash_attn_3::_flash_attn_forward", mutates_args=(), device_types="cuda")
|
||||
def _wrapped_flash_attn_3(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
softmax_scale: float | None = None,
|
||||
causal: bool = False,
|
||||
qv: torch.Tensor | None = None,
|
||||
q_descale: torch.Tensor | None = None,
|
||||
k_descale: torch.Tensor | None = None,
|
||||
v_descale: torch.Tensor | None = None,
|
||||
attention_chunk: int = 0,
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: bool | None = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||
window_size = (-1, -1)
|
||||
out, lse, *_ = flash_attn_3_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
qv=qv,
|
||||
q_descale=q_descale,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
window_size=window_size,
|
||||
attention_chunk=attention_chunk,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
)
|
||||
lse = lse.permute(0, 2, 1)
|
||||
return out, lse
|
||||
|
||||
|
||||
@_register_fake("_diffusers_flash_attn_3::_flash_attn_forward")
|
||||
def _(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
softmax_scale: float | None = None,
|
||||
causal: bool = False,
|
||||
qv: torch.Tensor | None = None,
|
||||
q_descale: torch.Tensor | None = None,
|
||||
k_descale: torch.Tensor | None = None,
|
||||
v_descale: torch.Tensor | None = None,
|
||||
attention_chunk: int = 0,
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: bool | None = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
window_size = (-1, -1) # noqa: F841
|
||||
# A lot of the parameters here are not yet used in any way within diffusers.
|
||||
# We can safely ignore for now and keep the fake op shape propagation simple.
|
||||
batch_size, seq_len, num_heads, head_dim = q.shape
|
||||
lse_shape = (batch_size, seq_len, num_heads)
|
||||
return torch.empty_like(q), q.new_empty(lse_shape)
|
||||
|
||||
|
||||
# ===== Helper functions to use attention backends with templated CP autograd functions =====
|
||||
|
||||
|
||||
@@ -995,107 +861,6 @@ def _native_flash_attention_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
# Adapted from: https://github.com/Dao-AILab/flash-attention/blob/fd2fc9d85c8e54e5c20436465bca709bc1a6c5a1/flash_attn/flash_attn_interface.py#L807
|
||||
def _flash_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn 2.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 2.")
|
||||
|
||||
# Hardcoded for now
|
||||
window_size = (-1, -1)
|
||||
softcap = 0.0
|
||||
alibi_slopes = None
|
||||
deterministic = False
|
||||
grad_enabled = any(x.requires_grad for x in (query, key, value))
|
||||
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
# flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
|
||||
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
|
||||
dropout_p = dropout_p if dropout_p > 0 else 1e-30
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
return_lse,
|
||||
)
|
||||
lse = lse.permute(0, 2, 1)
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.alibi_slopes = alibi_slopes
|
||||
ctx.deterministic = deterministic
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
query, key, value, out, lse, rng_state = ctx.saved_tensors
|
||||
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
|
||||
|
||||
lse_d = _wrapped_flash_attn_backward( # noqa: F841
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
grad_query,
|
||||
grad_key,
|
||||
grad_value,
|
||||
ctx.dropout_p,
|
||||
ctx.scale,
|
||||
ctx.is_causal,
|
||||
ctx.window_size[0],
|
||||
ctx.window_size[1],
|
||||
ctx.softcap,
|
||||
ctx.alibi_slopes,
|
||||
ctx.deterministic,
|
||||
rng_state,
|
||||
)
|
||||
|
||||
# Head dimension may have been padded
|
||||
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1327,44 +1092,6 @@ def _flash_attention_3_hub_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _sage_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
|
||||
if dropout_p > 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
|
||||
|
||||
out = sageattn(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
lse = lse.permute(0, 2, 1)
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -2205,59 +1932,6 @@ def _templated_context_parallel_attention(
|
||||
# ===== Attention backends =====
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
if _parallel_config is None:
|
||||
out = flash_attn_func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_flash_attention_forward_op,
|
||||
backward_op=_flash_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -2369,88 +2043,6 @@ def _flash_varlen_attention_hub(
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_VARLEN,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
)
|
||||
def _flash_varlen_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len_q, _, _ = query.shape
|
||||
_, seq_len_kv, _, _ = key.shape
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
|
||||
|
||||
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
|
||||
_prepare_for_flash_attn_or_sage_varlen(
|
||||
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
|
||||
)
|
||||
)
|
||||
|
||||
key_valid, value_valid = [], []
|
||||
for b in range(batch_size):
|
||||
valid_len = seqlens_k[b]
|
||||
key_valid.append(key[b, :valid_len])
|
||||
value_valid.append(value[b, :valid_len])
|
||||
|
||||
query_packed = query.flatten(0, 1)
|
||||
key_packed = torch.cat(key_valid, dim=0)
|
||||
value_packed = torch.cat(value_valid, dim=0)
|
||||
|
||||
out = flash_attn_varlen_func(
|
||||
q=query_packed,
|
||||
k=key_packed,
|
||||
v=value_packed,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
)
|
||||
def _flash_attention_3(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
out, lse = _wrapped_flash_attn_3(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -2587,58 +2179,6 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
)
|
||||
def _flash_varlen_attention_3(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len_q, _, _ = query.shape
|
||||
_, seq_len_kv, _, _ = key.shape
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
|
||||
|
||||
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
|
||||
_prepare_for_flash_attn_or_sage_varlen(
|
||||
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
|
||||
)
|
||||
)
|
||||
|
||||
key_valid, value_valid = [], []
|
||||
for b in range(batch_size):
|
||||
valid_len = seqlens_k[b]
|
||||
key_valid.append(key[b, :valid_len])
|
||||
value_valid.append(value[b, :valid_len])
|
||||
|
||||
query_packed = query.flatten(0, 1)
|
||||
key_packed = torch.cat(key_valid, dim=0)
|
||||
value_packed = torch.cat(value_valid, dim=0)
|
||||
|
||||
out, lse, *_ = flash_attn_3_varlen_func(
|
||||
q=query_packed,
|
||||
k=key_packed,
|
||||
v=value_packed,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.AITER,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -3108,57 +2648,6 @@ def _native_xla_attention(
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _sage_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
lse = None
|
||||
if _parallel_config is None:
|
||||
out = sageattn(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_sage_attention_forward_op,
|
||||
backward_op=_sage_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
@@ -3211,169 +2700,6 @@ def _sage_attention_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_VARLEN,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
)
|
||||
def _sage_varlen_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Sage varlen backend does not support setting `return_lse=True`.")
|
||||
|
||||
batch_size, seq_len_q, _, _ = query.shape
|
||||
_, seq_len_kv, _, _ = key.shape
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
|
||||
|
||||
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
|
||||
_prepare_for_flash_attn_or_sage_varlen(
|
||||
batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
|
||||
)
|
||||
)
|
||||
|
||||
key_valid, value_valid = [], []
|
||||
for b in range(batch_size):
|
||||
valid_len = seqlens_k[b]
|
||||
key_valid.append(key[b, :valid_len])
|
||||
value_valid.append(value[b, :valid_len])
|
||||
|
||||
query_packed = query.flatten(0, 1)
|
||||
key_packed = torch.cat(key_valid, dim=0)
|
||||
value_packed = torch.cat(value_valid, dim=0)
|
||||
|
||||
out = sageattn_varlen(
|
||||
q=query_packed,
|
||||
k=key_packed,
|
||||
v=value_packed,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
cu_seqlens_k=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
)
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA,
|
||||
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
|
||||
)
|
||||
def _sage_qk_int8_pv_fp8_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP8_CUDA_SM90,
|
||||
constraints=[_check_device_cuda_atleast_smXY(9, 0), _check_shape],
|
||||
)
|
||||
def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda_sm90(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_CUDA,
|
||||
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
|
||||
)
|
||||
def _sage_qk_int8_pv_fp16_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._SAGE_QK_INT8_PV_FP16_TRITON,
|
||||
constraints=[_check_device_cuda_atleast_smXY(8, 0), _check_shape],
|
||||
)
|
||||
def _sage_qk_int8_pv_fp16_triton_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_triton(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.XFORMERS,
|
||||
constraints=[_check_attn_mask_or_causal, _check_device, _check_shape],
|
||||
|
||||
Reference in New Issue
Block a user