Compare commits

...

1 Commit

Author SHA1 Message Date
sayakpaul
0f8a83fa69 support ltx-2 type masking in flash_3_hub_varlen 2026-03-13 13:59:48 +05:30

View File

@@ -2559,7 +2559,9 @@ def _flash_attention_3_hub(
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
raise ValueError(
"`attn_mask` is not supported for flash-attn 3. Please use the `_flash_3_varlen_hub` backend instead."
)
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
if _parallel_config is None:
@@ -2641,6 +2643,8 @@ def _flash_attention_3_varlen_hub(
_, seq_len_kv, _, _ = key.shape
if attn_mask is not None:
if attn_mask.dtype != torch.bool:
attn_mask = attn_mask > -1
attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
(_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
@@ -2660,7 +2664,7 @@ def _flash_attention_3_varlen_hub(
value_packed = torch.cat(value_valid, dim=0)
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_VARLEN_HUB].kernel_fn
out, lse, *_ = func(
result = func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2671,6 +2675,11 @@ def _flash_attention_3_varlen_hub(
softmax_scale=scale,
causal=is_causal,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out