mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-19 19:04:49 +08:00
Compare commits
5 Commits
qwenimage-
...
enable-cp-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dfbd4857b2 | ||
|
|
9bd83616bf | ||
|
|
f732ff1144 | ||
|
|
7a8f85b047 | ||
|
|
82d20e64a5 |
@@ -256,6 +256,10 @@ class _HubKernelConfig:
|
|||||||
function_attr: str
|
function_attr: str
|
||||||
revision: Optional[str] = None
|
revision: Optional[str] = None
|
||||||
kernel_fn: Optional[Callable] = None
|
kernel_fn: Optional[Callable] = None
|
||||||
|
wrapped_forward_attr: Optional[str] = None
|
||||||
|
wrapped_backward_attr: Optional[str] = None
|
||||||
|
wrapped_forward_fn: Optional[Callable] = None
|
||||||
|
wrapped_backward_fn: Optional[Callable] = None
|
||||||
|
|
||||||
|
|
||||||
# Registry for hub-based attention kernels
|
# Registry for hub-based attention kernels
|
||||||
@@ -270,7 +274,11 @@ _HUB_KERNELS_REGISTRY: Dict["AttentionBackendName", _HubKernelConfig] = {
|
|||||||
# revision="fake-ops-return-probs",
|
# revision="fake-ops-return-probs",
|
||||||
),
|
),
|
||||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
repo_id="kernels-community/flash-attn2",
|
||||||
|
function_attr="flash_attn_func",
|
||||||
|
revision=None,
|
||||||
|
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
|
||||||
|
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
|
||||||
),
|
),
|
||||||
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
||||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
||||||
@@ -594,22 +602,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
|||||||
|
|
||||||
|
|
||||||
# ===== Helpers for downloading kernels =====
|
# ===== Helpers for downloading kernels =====
|
||||||
|
def _resolve_kernel_attr(module, attr_path: str):
|
||||||
|
target = module
|
||||||
|
for attr in attr_path.split("."):
|
||||||
|
if not hasattr(target, attr):
|
||||||
|
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
|
||||||
|
target = getattr(target, attr)
|
||||||
|
return target
|
||||||
|
|
||||||
|
|
||||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||||
if backend not in _HUB_KERNELS_REGISTRY:
|
if backend not in _HUB_KERNELS_REGISTRY:
|
||||||
return
|
return
|
||||||
config = _HUB_KERNELS_REGISTRY[backend]
|
config = _HUB_KERNELS_REGISTRY[backend]
|
||||||
|
|
||||||
if config.kernel_fn is not None:
|
needs_kernel = config.kernel_fn is None
|
||||||
|
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
|
||||||
|
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
|
||||||
|
|
||||||
|
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from kernels import get_kernel
|
from kernels import get_kernel
|
||||||
|
|
||||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||||
kernel_func = getattr(kernel_module, config.function_attr)
|
if needs_kernel:
|
||||||
|
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
|
||||||
|
|
||||||
# Cache the downloaded kernel function in the config object
|
if needs_wrapped_forward:
|
||||||
config.kernel_fn = kernel_func
|
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
|
||||||
|
|
||||||
|
if needs_wrapped_backward:
|
||||||
|
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||||
@@ -1060,6 +1085,231 @@ def _flash_attention_backward_op(
|
|||||||
return grad_query, grad_key, grad_value
|
return grad_query, grad_key, grad_value
|
||||||
|
|
||||||
|
|
||||||
|
def _flash_attention_hub_forward_op(
|
||||||
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
dropout_p: float = 0.0,
|
||||||
|
is_causal: bool = False,
|
||||||
|
scale: Optional[float] = None,
|
||||||
|
enable_gqa: bool = False,
|
||||||
|
return_lse: bool = False,
|
||||||
|
_save_ctx: bool = True,
|
||||||
|
_parallel_config: Optional["ParallelConfig"] = None,
|
||||||
|
):
|
||||||
|
if attn_mask is not None:
|
||||||
|
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
|
||||||
|
if enable_gqa:
|
||||||
|
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
|
||||||
|
|
||||||
|
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||||
|
wrapped_forward_fn = config.wrapped_forward_fn
|
||||||
|
wrapped_backward_fn = config.wrapped_backward_fn
|
||||||
|
if wrapped_forward_fn is None or wrapped_backward_fn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
|
||||||
|
"for context parallel execution."
|
||||||
|
)
|
||||||
|
|
||||||
|
if scale is None:
|
||||||
|
scale = query.shape[-1] ** (-0.5)
|
||||||
|
|
||||||
|
window_size = (-1, -1)
|
||||||
|
softcap = 0.0
|
||||||
|
alibi_slopes = None
|
||||||
|
deterministic = False
|
||||||
|
grad_enabled = any(x.requires_grad for x in (query, key, value))
|
||||||
|
|
||||||
|
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
|
||||||
|
dropout_p = dropout_p if dropout_p > 0 else 1e-30
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(grad_enabled):
|
||||||
|
out, lse, S_dmask, rng_state = wrapped_forward_fn(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
dropout_p,
|
||||||
|
scale,
|
||||||
|
is_causal,
|
||||||
|
window_size[0],
|
||||||
|
window_size[1],
|
||||||
|
softcap,
|
||||||
|
alibi_slopes,
|
||||||
|
return_lse,
|
||||||
|
)
|
||||||
|
lse = lse.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
if _save_ctx:
|
||||||
|
ctx.save_for_backward(query, key, value, out, lse, rng_state)
|
||||||
|
ctx.dropout_p = dropout_p
|
||||||
|
ctx.scale = scale
|
||||||
|
ctx.is_causal = is_causal
|
||||||
|
ctx.window_size = window_size
|
||||||
|
ctx.softcap = softcap
|
||||||
|
ctx.alibi_slopes = alibi_slopes
|
||||||
|
ctx.deterministic = deterministic
|
||||||
|
|
||||||
|
return (out, lse) if return_lse else out
|
||||||
|
|
||||||
|
|
||||||
|
def _flash_attention_hub_backward_op(
|
||||||
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
|
grad_out: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||||
|
wrapped_backward_fn = config.wrapped_backward_fn
|
||||||
|
if wrapped_backward_fn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
|
||||||
|
)
|
||||||
|
|
||||||
|
query, key, value, out, lse, rng_state = ctx.saved_tensors
|
||||||
|
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
|
||||||
|
|
||||||
|
_ = wrapped_backward_fn(
|
||||||
|
grad_out,
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
out,
|
||||||
|
lse,
|
||||||
|
grad_query,
|
||||||
|
grad_key,
|
||||||
|
grad_value,
|
||||||
|
ctx.dropout_p,
|
||||||
|
ctx.scale,
|
||||||
|
ctx.is_causal,
|
||||||
|
ctx.window_size[0],
|
||||||
|
ctx.window_size[1],
|
||||||
|
ctx.softcap,
|
||||||
|
ctx.alibi_slopes,
|
||||||
|
ctx.deterministic,
|
||||||
|
rng_state,
|
||||||
|
)
|
||||||
|
|
||||||
|
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||||
|
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||||
|
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||||
|
|
||||||
|
return grad_query, grad_key, grad_value
|
||||||
|
|
||||||
|
|
||||||
|
def _flash_attention_3_hub_forward_op(
|
||||||
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
dropout_p: float = 0.0,
|
||||||
|
is_causal: bool = False,
|
||||||
|
scale: Optional[float] = None,
|
||||||
|
enable_gqa: bool = False,
|
||||||
|
return_lse: bool = False,
|
||||||
|
_save_ctx: bool = True,
|
||||||
|
_parallel_config: Optional["ParallelConfig"] = None,
|
||||||
|
*,
|
||||||
|
window_size: Tuple[int, int] = (-1, -1),
|
||||||
|
softcap: float = 0.0,
|
||||||
|
num_splits: int = 1,
|
||||||
|
pack_gqa: Optional[bool] = None,
|
||||||
|
deterministic: bool = False,
|
||||||
|
sm_margin: int = 0,
|
||||||
|
):
|
||||||
|
if attn_mask is not None:
|
||||||
|
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
|
||||||
|
if dropout_p != 0.0:
|
||||||
|
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
|
||||||
|
if enable_gqa:
|
||||||
|
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
||||||
|
|
||||||
|
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||||
|
out = func(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
softmax_scale=scale,
|
||||||
|
causal=is_causal,
|
||||||
|
qv=None,
|
||||||
|
q_descale=None,
|
||||||
|
k_descale=None,
|
||||||
|
v_descale=None,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
num_splits=num_splits,
|
||||||
|
pack_gqa=pack_gqa,
|
||||||
|
deterministic=deterministic,
|
||||||
|
sm_margin=sm_margin,
|
||||||
|
return_attn_probs=return_lse,
|
||||||
|
)
|
||||||
|
|
||||||
|
lse = None
|
||||||
|
if return_lse:
|
||||||
|
out, lse = out
|
||||||
|
lse = lse.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
if _save_ctx:
|
||||||
|
ctx.save_for_backward(query, key, value)
|
||||||
|
ctx.scale = scale
|
||||||
|
ctx.is_causal = is_causal
|
||||||
|
ctx._hub_kernel = func
|
||||||
|
|
||||||
|
return (out, lse) if return_lse else out
|
||||||
|
|
||||||
|
|
||||||
|
def _flash_attention_3_hub_backward_op(
|
||||||
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
|
grad_out: torch.Tensor,
|
||||||
|
*args,
|
||||||
|
window_size: Tuple[int, int] = (-1, -1),
|
||||||
|
softcap: float = 0.0,
|
||||||
|
num_splits: int = 1,
|
||||||
|
pack_gqa: Optional[bool] = None,
|
||||||
|
deterministic: bool = False,
|
||||||
|
sm_margin: int = 0,
|
||||||
|
):
|
||||||
|
query, key, value = ctx.saved_tensors
|
||||||
|
kernel_fn = ctx._hub_kernel
|
||||||
|
with torch.enable_grad():
|
||||||
|
query_r = query.detach().requires_grad_(True)
|
||||||
|
key_r = key.detach().requires_grad_(True)
|
||||||
|
value_r = value.detach().requires_grad_(True)
|
||||||
|
|
||||||
|
out = kernel_fn(
|
||||||
|
q=query_r,
|
||||||
|
k=key_r,
|
||||||
|
v=value_r,
|
||||||
|
softmax_scale=ctx.scale,
|
||||||
|
causal=ctx.is_causal,
|
||||||
|
qv=None,
|
||||||
|
q_descale=None,
|
||||||
|
k_descale=None,
|
||||||
|
v_descale=None,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
num_splits=num_splits,
|
||||||
|
pack_gqa=pack_gqa,
|
||||||
|
deterministic=deterministic,
|
||||||
|
sm_margin=sm_margin,
|
||||||
|
return_attn_probs=False,
|
||||||
|
)
|
||||||
|
if isinstance(out, tuple):
|
||||||
|
out = out[0]
|
||||||
|
|
||||||
|
grad_query, grad_key, grad_value = torch.autograd.grad(
|
||||||
|
out,
|
||||||
|
(query_r, key_r, value_r),
|
||||||
|
grad_out,
|
||||||
|
retain_graph=False,
|
||||||
|
allow_unused=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return grad_query, grad_key, grad_value
|
||||||
|
|
||||||
|
|
||||||
def _sage_attention_forward_op(
|
def _sage_attention_forward_op(
|
||||||
ctx: torch.autograd.function.FunctionCtx,
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -1106,6 +1356,46 @@ def _sage_attention_backward_op(
|
|||||||
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
|
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
|
||||||
|
|
||||||
|
|
||||||
|
def _sage_attention_hub_forward_op(
|
||||||
|
ctx: torch.autograd.function.FunctionCtx,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
|
dropout_p: float = 0.0,
|
||||||
|
is_causal: bool = False,
|
||||||
|
scale: Optional[float] = None,
|
||||||
|
enable_gqa: bool = False,
|
||||||
|
return_lse: bool = False,
|
||||||
|
_save_ctx: bool = True,
|
||||||
|
_parallel_config: Optional["ParallelConfig"] = None,
|
||||||
|
):
|
||||||
|
if attn_mask is not None:
|
||||||
|
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
|
||||||
|
if dropout_p > 0.0:
|
||||||
|
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
|
||||||
|
if enable_gqa:
|
||||||
|
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
|
||||||
|
|
||||||
|
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||||
|
out = func(
|
||||||
|
q=query,
|
||||||
|
k=key,
|
||||||
|
v=value,
|
||||||
|
tensor_layout="NHD",
|
||||||
|
is_causal=is_causal,
|
||||||
|
sm_scale=scale,
|
||||||
|
return_lse=return_lse,
|
||||||
|
)
|
||||||
|
|
||||||
|
lse = None
|
||||||
|
if return_lse:
|
||||||
|
out, lse, *_ = out
|
||||||
|
lse = lse.permute(0, 2, 1).contiguous()
|
||||||
|
|
||||||
|
return (out, lse) if return_lse else out
|
||||||
|
|
||||||
|
|
||||||
# ===== Context parallel =====
|
# ===== Context parallel =====
|
||||||
|
|
||||||
|
|
||||||
@@ -1463,7 +1753,7 @@ def _flash_attention(
|
|||||||
@_AttentionBackendRegistry.register(
|
@_AttentionBackendRegistry.register(
|
||||||
AttentionBackendName.FLASH_HUB,
|
AttentionBackendName.FLASH_HUB,
|
||||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||||
supports_context_parallel=False,
|
supports_context_parallel=True,
|
||||||
)
|
)
|
||||||
def _flash_attention_hub(
|
def _flash_attention_hub(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -1477,17 +1767,35 @@ def _flash_attention_hub(
|
|||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
lse = None
|
lse = None
|
||||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||||
out = func(
|
if _parallel_config is None:
|
||||||
q=query,
|
out = func(
|
||||||
k=key,
|
q=query,
|
||||||
v=value,
|
k=key,
|
||||||
dropout_p=dropout_p,
|
v=value,
|
||||||
softmax_scale=scale,
|
dropout_p=dropout_p,
|
||||||
causal=is_causal,
|
softmax_scale=scale,
|
||||||
return_attn_probs=return_lse,
|
causal=is_causal,
|
||||||
)
|
return_attn_probs=return_lse,
|
||||||
if return_lse:
|
)
|
||||||
out, lse, *_ = out
|
if return_lse:
|
||||||
|
out, lse, *_ = out
|
||||||
|
else:
|
||||||
|
out = _templated_context_parallel_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
None,
|
||||||
|
dropout_p,
|
||||||
|
is_causal,
|
||||||
|
scale,
|
||||||
|
False,
|
||||||
|
return_lse,
|
||||||
|
forward_op=_flash_attention_hub_forward_op,
|
||||||
|
backward_op=_flash_attention_hub_backward_op,
|
||||||
|
_parallel_config=_parallel_config,
|
||||||
|
)
|
||||||
|
if return_lse:
|
||||||
|
out, lse = out
|
||||||
|
|
||||||
return (out, lse) if return_lse else out
|
return (out, lse) if return_lse else out
|
||||||
|
|
||||||
@@ -1630,7 +1938,7 @@ def _flash_attention_3(
|
|||||||
@_AttentionBackendRegistry.register(
|
@_AttentionBackendRegistry.register(
|
||||||
AttentionBackendName._FLASH_3_HUB,
|
AttentionBackendName._FLASH_3_HUB,
|
||||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||||
supports_context_parallel=False,
|
supports_context_parallel=True,
|
||||||
)
|
)
|
||||||
def _flash_attention_3_hub(
|
def _flash_attention_3_hub(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -1644,31 +1952,65 @@ def _flash_attention_3_hub(
|
|||||||
return_attn_probs: bool = False,
|
return_attn_probs: bool = False,
|
||||||
_parallel_config: Optional["ParallelConfig"] = None,
|
_parallel_config: Optional["ParallelConfig"] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if _parallel_config:
|
|
||||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
|
||||||
|
|
||||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||||
out = func(
|
if _parallel_config is None:
|
||||||
q=query,
|
out = func(
|
||||||
k=key,
|
q=query,
|
||||||
v=value,
|
k=key,
|
||||||
softmax_scale=scale,
|
v=value,
|
||||||
causal=is_causal,
|
softmax_scale=scale,
|
||||||
qv=None,
|
causal=is_causal,
|
||||||
q_descale=None,
|
qv=None,
|
||||||
k_descale=None,
|
q_descale=None,
|
||||||
v_descale=None,
|
k_descale=None,
|
||||||
|
v_descale=None,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
num_splits=1,
|
||||||
|
pack_gqa=None,
|
||||||
|
deterministic=deterministic,
|
||||||
|
sm_margin=0,
|
||||||
|
return_attn_probs=return_attn_probs,
|
||||||
|
)
|
||||||
|
return (out[0], out[1]) if return_attn_probs else out
|
||||||
|
|
||||||
|
forward_op = functools.partial(
|
||||||
|
_flash_attention_3_hub_forward_op,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=softcap,
|
softcap=softcap,
|
||||||
num_splits=1,
|
num_splits=1,
|
||||||
pack_gqa=None,
|
pack_gqa=None,
|
||||||
deterministic=deterministic,
|
deterministic=deterministic,
|
||||||
sm_margin=0,
|
sm_margin=0,
|
||||||
return_attn_probs=return_attn_probs,
|
|
||||||
)
|
)
|
||||||
# When `return_attn_probs` is True, the above returns a tuple of
|
backward_op = functools.partial(
|
||||||
# actual outputs and lse.
|
_flash_attention_3_hub_backward_op,
|
||||||
return (out[0], out[1]) if return_attn_probs else out
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
num_splits=1,
|
||||||
|
pack_gqa=None,
|
||||||
|
deterministic=deterministic,
|
||||||
|
sm_margin=0,
|
||||||
|
)
|
||||||
|
out = _templated_context_parallel_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
is_causal,
|
||||||
|
scale,
|
||||||
|
False,
|
||||||
|
return_attn_probs,
|
||||||
|
forward_op=forward_op,
|
||||||
|
backward_op=backward_op,
|
||||||
|
_parallel_config=_parallel_config,
|
||||||
|
)
|
||||||
|
if return_attn_probs:
|
||||||
|
out, lse = out
|
||||||
|
return out, lse
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@_AttentionBackendRegistry.register(
|
@_AttentionBackendRegistry.register(
|
||||||
@@ -2217,7 +2559,7 @@ def _sage_attention(
|
|||||||
@_AttentionBackendRegistry.register(
|
@_AttentionBackendRegistry.register(
|
||||||
AttentionBackendName.SAGE_HUB,
|
AttentionBackendName.SAGE_HUB,
|
||||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||||
supports_context_parallel=False,
|
supports_context_parallel=True,
|
||||||
)
|
)
|
||||||
def _sage_attention_hub(
|
def _sage_attention_hub(
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -2242,6 +2584,23 @@ def _sage_attention_hub(
|
|||||||
)
|
)
|
||||||
if return_lse:
|
if return_lse:
|
||||||
out, lse, *_ = out
|
out, lse, *_ = out
|
||||||
|
else:
|
||||||
|
out = _templated_context_parallel_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
None,
|
||||||
|
0.0,
|
||||||
|
is_causal,
|
||||||
|
scale,
|
||||||
|
False,
|
||||||
|
return_lse,
|
||||||
|
forward_op=_sage_attention_hub_forward_op,
|
||||||
|
backward_op=_sage_attention_backward_op,
|
||||||
|
_parallel_config=_parallel_config,
|
||||||
|
)
|
||||||
|
if return_lse:
|
||||||
|
out, lse = out
|
||||||
|
|
||||||
return (out, lse) if return_lse else out
|
return (out, lse) if return_lse else out
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user