From e976dc9affc0944669ee9f02252f307a5e782c2f Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 19 Feb 2026 18:52:13 +0530
Subject: [PATCH] use dedicated wrappers from FA3 for context parallel

The FA3 hub kernel now exposes `flash_attn_interface._flash_attn_forward`
and `flash_attn_interface._flash_attn_backward`, so register them as
`wrapped_forward_attr`/`wrapped_backward_attr` in `_HubKernelConfig` and
call them directly. The forward op saves `out` and `softmax_lse` on the
ctx, and the backward op feeds them to the fused backward kernel instead
of re-running the forward pass under `torch.enable_grad()` and
differentiating through it.
---
 src/diffusers/models/attention_dispatch.py | 155 ++++++++++++----------
 1 file changed, 90 insertions(+), 65 deletions(-)

diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 90ffcac80d..065c9fc51a 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -276,7 +276,11 @@ class _HubKernelConfig:
 _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
     # TODO: temporary revision for now. Remove when merged upstream into `main`.
     AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
-        repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
+        repo_id="kernels-community/flash-attn3",
+        function_attr="flash_attn_func",
+        revision="fake-ops-return-probs",
+        wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
+        wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
     ),
     AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
         repo_id="kernels-community/flash-attn3",
@@ -1237,36 +1241,62 @@ def _flash_attention_3_hub_forward_op(
     if enable_gqa:
         raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
 
-    func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
-    out = func(
-        q=query,
-        k=key,
-        v=value,
-        softmax_scale=scale,
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
+    wrapped_forward_fn = config.wrapped_forward_fn
+    if wrapped_forward_fn is None:
+        raise RuntimeError(
+            "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
+            "for context parallel execution."
+        )
+
+    if scale is None:
+        scale = query.shape[-1] ** (-0.5)
+
+    out, softmax_lse, *_ = wrapped_forward_fn(
+        query,
+        key,
+        value,
+        None,  # k_new
+        None,  # v_new
+        None,  # qv
+        None,  # out
+        None,  # cu_seqlens_q
+        None,  # cu_seqlens_k
+        None,  # cu_seqlens_k_new
+        None,  # seqused_q
+        None,  # seqused_k
+        None,  # max_seqlen_q
+        None,  # max_seqlen_k
+        None,  # page_table
+        None,  # kv_batch_idx
+        None,  # leftpad_k
+        None,  # rotary_cos
+        None,  # rotary_sin
+        None,  # seqlens_rotary
+        None,  # q_descale
+        None,  # k_descale
+        None,  # v_descale
+        scale,  # softmax_scale
         causal=is_causal,
-        qv=None,
-        q_descale=None,
-        k_descale=None,
-        v_descale=None,
-        window_size=window_size,
+        window_size_left=window_size[0],
+        window_size_right=window_size[1],
+        attention_chunk=0,
         softcap=softcap,
         num_splits=num_splits,
         pack_gqa=pack_gqa,
-        deterministic=deterministic,
         sm_margin=sm_margin,
-        return_attn_probs=return_lse,
     )
-    lse = None
-    if return_lse:
-        out, lse = out
-        lse = lse.permute(0, 2, 1).contiguous()
+    lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
 
     if _save_ctx:
-        ctx.save_for_backward(query, key, value)
+        ctx.save_for_backward(query, key, value, out, softmax_lse)
         ctx.scale = scale
         ctx.is_causal = is_causal
-        ctx._hub_kernel = func
+        ctx.window_size = window_size
+        ctx.softcap = softcap
+        ctx.deterministic = deterministic
+        ctx.sm_margin = sm_margin
 
     return (out, lse) if return_lse else out
 
@@ -1275,54 +1305,49 @@ def _flash_attention_3_hub_backward_op(
     ctx: torch.autograd.function.FunctionCtx,
     grad_out: torch.Tensor,
     *args,
-    window_size: tuple[int, int] = (-1, -1),
-    softcap: float = 0.0,
-    num_splits: int = 1,
-    pack_gqa: bool | None = None,
-    deterministic: bool = False,
-    sm_margin: int = 0,
+    **kwargs,
 ):
-    query, key, value = ctx.saved_tensors
-    kernel_fn = ctx._hub_kernel
-    # NOTE: Unlike the FA2 hub kernel, the FA3 hub kernel does not expose separate wrapped forward/backward
-    # primitives (no `wrapped_forward_attr`/`wrapped_backward_attr` in its `_HubKernelConfig`). We
-    # therefore rerun the forward pass under `torch.enable_grad()` and differentiate through it with
-    # `torch.autograd.grad()`. This is a second forward pass during backward; it can be avoided once
-    # the FA3 hub exposes a dedicated fused backward kernel (analogous to `_wrapped_flash_attn_backward`
-    # in the FA2 hub), at which point this can be refactored to match `_flash_attention_hub_backward_op`.
-    with torch.enable_grad():
-        query_r = query.detach().requires_grad_(True)
-        key_r = key.detach().requires_grad_(True)
-        value_r = value.detach().requires_grad_(True)
-
-        out = kernel_fn(
-            q=query_r,
-            k=key_r,
-            v=value_r,
-            softmax_scale=ctx.scale,
-            causal=ctx.is_causal,
-            qv=None,
-            q_descale=None,
-            k_descale=None,
-            v_descale=None,
-            window_size=window_size,
-            softcap=softcap,
-            num_splits=num_splits,
-            pack_gqa=pack_gqa,
-            deterministic=deterministic,
-            sm_margin=sm_margin,
-            return_attn_probs=False,
+    config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
+    wrapped_backward_fn = config.wrapped_backward_fn
+    if wrapped_backward_fn is None:
+        raise RuntimeError(
+            "Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
+            "for context parallel execution."
         )
-        if isinstance(out, tuple):
-            out = out[0]
-        grad_query, grad_key, grad_value = torch.autograd.grad(
-            out,
-            (query_r, key_r, value_r),
-            grad_out,
-            retain_graph=False,
-            allow_unused=False,
-        )
+    query, key, value, out, softmax_lse = ctx.saved_tensors
+    grad_query = torch.empty_like(query)
+    grad_key = torch.empty_like(key)
+    grad_value = torch.empty_like(value)
+
+    wrapped_backward_fn(
+        grad_out,
+        query,
+        key,
+        value,
+        out,
+        softmax_lse,
+        None,  # cu_seqlens_q
+        None,  # cu_seqlens_k
+        None,  # seqused_q
+        None,  # seqused_k
+        None,  # max_seqlen_q
+        None,  # max_seqlen_k
+        grad_query,
+        grad_key,
+        grad_value,
+        ctx.scale,
+        ctx.is_causal,
+        ctx.window_size[0],
+        ctx.window_size[1],
+        ctx.softcap,
+        ctx.deterministic,
+        ctx.sm_margin,
+    )
+
+    grad_query = grad_query[..., : grad_out.shape[-1]]
+    grad_key = grad_key[..., : grad_out.shape[-1]]
+    grad_value = grad_value[..., : grad_out.shape[-1]]
 
     return grad_query, grad_key, grad_value
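
Note for reviewers: below is a minimal, self-contained sketch of the
save-LSE-then-fused-backward pattern this patch adopts. It is illustrative
only: `toy_forward_op`/`toy_backward_op` are plain-PyTorch stand-ins
(hypothetical names, not the FA3 hub's `_flash_attn_forward`/
`_flash_attn_backward`), and `ToyFusedAttention` is not part of diffusers.
It shows why saving `out` and the softmax log-sum-exp in forward lets the
backward op compute gradients directly, with no second forward pass under
`torch.enable_grad()`.

    import torch


    def toy_forward_op(ctx, q, k, v, *, scale, is_causal, _save_ctx):
        # Stand-in for the wrapped forward: computes attention plus the per-row
        # log-sum-exp (LSE) so backward can skip recomputing softmax normalizers.
        scores = (q @ k.transpose(-2, -1)) * scale
        if is_causal:
            mask = torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool, device=q.device).tril()
            scores = scores.masked_fill(~mask, float("-inf"))
        lse = torch.logsumexp(scores, dim=-1)  # (..., seq_q)
        out = scores.softmax(dim=-1) @ v
        if _save_ctx:
            ctx.save_for_backward(q, k, v, out, lse)
            ctx.scale, ctx.is_causal = scale, is_causal
        return out


    def toy_backward_op(ctx, grad_out):
        # Stand-in for the wrapped backward: rebuilds the probabilities from the
        # saved LSE (p = exp(scores - lse)) and applies the standard attention
        # gradient identities. A fused kernel fills preallocated buffers instead.
        q, k, v, out, lse = ctx.saved_tensors
        scores = (q @ k.transpose(-2, -1)) * ctx.scale
        if ctx.is_causal:
            mask = torch.ones(q.shape[-2], k.shape[-2], dtype=torch.bool, device=q.device).tril()
            scores = scores.masked_fill(~mask, float("-inf"))
        probs = torch.exp(scores - lse.unsqueeze(-1))  # masked entries -> exp(-inf) = 0
        grad_v = probs.transpose(-2, -1) @ grad_out
        grad_probs = grad_out @ v.transpose(-2, -1)
        delta = (grad_out * out).sum(-1, keepdim=True)  # rowsum(dP * P) per query row
        grad_scores = probs * (grad_probs - delta)
        grad_q = (grad_scores @ k) * ctx.scale
        grad_k = (grad_scores.transpose(-2, -1) @ q) * ctx.scale
        return grad_q, grad_k, grad_v


    class ToyFusedAttention(torch.autograd.Function):
        @staticmethod
        def forward(ctx, q, k, v, scale, is_causal):
            return toy_forward_op(ctx, q, k, v, scale=scale, is_causal=is_causal, _save_ctx=True)

        @staticmethod
        def backward(ctx, grad_out):
            # No grads for the non-tensor scale/is_causal arguments.
            return (*toy_backward_op(ctx, grad_out), None, None)


    if __name__ == "__main__":
        torch.manual_seed(0)
        q, k, v = (torch.randn(2, 4, 16, 8, dtype=torch.float64, requires_grad=True) for _ in range(3))
        ToyFusedAttention.apply(q, k, v, q.shape[-1] ** -0.5, True).sum().backward()
        # Cross-check against autograd through plain softmax attention.
        q2, k2, v2 = (t.detach().clone().requires_grad_(True) for t in (q, k, v))
        torch.nn.functional.scaled_dot_product_attention(q2, k2, v2, is_causal=True).sum().backward()
        assert torch.allclose(q.grad, q2.grad) and torch.allclose(k.grad, k2.grad) and torch.allclose(v.grad, v2.grad)

The toy stand-ins return grads directly rather than filling preallocated
buffers, and they skip the `grad_*[..., : grad_out.shape[-1]]` trim in the
patch, which (as an assumption about the kernel's behavior) undoes any
internal padding of the head dimension.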