Compare commits

2 Commits

Author  SHA1        Message   Date
Aryan   66ce9ccb03  refactor  2025-07-22 01:12:07 +02:00
Aryan   bb443f99dc  update    2025-07-21 23:56:25 +02:00

2 changed files with 31 additions and 70 deletions

View File

@@ -314,13 +314,11 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
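
Editor's note: the hunk above swaps an allocate-zeros-then-fill construction of the cumulative sequence lengths for a single pad-of-cumsum expression. A minimal standalone sketch (hypothetical seqlens values, not from the diff) showing the two constructions agree:

import torch

seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)  # hypothetical per-batch lengths

# Old construction: allocate zeros, write the cumulative sums after index 0.
cu_old = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
cu_old[1:] = torch.cumsum(seqlens, dim=0)

# New construction: cumsum once, then left-pad a single zero.
cu_new = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

assert torch.equal(cu_old, cu_new)  # both are [0, 3, 8, 10]
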
@@ -331,13 +329,11 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen(
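
Editor's note: in the masked variant, the per-batch KV lengths come from row sums of the boolean mask rather than a constant fill. A small illustration with a hypothetical 2x4 mask:

import torch

attn_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)  # (batch, seq_len_kv)
seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)  # ragged KV lengths -> [3, 2]
cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))  # -> [0, 3, 5]
max_seqlen_k = int(seqlens_k.max().item())  # -> 3
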
@@ -496,30 +492,18 @@ def _flash_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
-        )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))

     out = flash_attn_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
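
Editor's note: with the per-batch key/value packing loop gone, all three varlen call sites flatten query, key, and value across the batch dimension and let the cu_seqlens tensors describe sequence boundaries. A hedged sketch of that packed layout, assuming the flash-attn package, a CUDA device, and the (batch, seq_len, heads, head_dim) shape convention used above:

import torch
from flash_attn import flash_attn_varlen_func  # assumes flash-attn is installed

batch_size, seq_len, heads, head_dim = 2, 4, 8, 64
query = torch.randn(batch_size, seq_len, heads, head_dim, dtype=torch.float16, device="cuda")
key, value = torch.randn_like(query), torch.randn_like(query)

# Full-length sequences: every batch entry contributes seq_len tokens.
seqlens = torch.full((batch_size,), seq_len, dtype=torch.int32, device="cuda")
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

# flatten(0, 1) packs the batch into one (total_tokens, heads, head_dim) tensor;
# cu_seqlens marks where each sequence starts inside that packed tensor.
query, key, value = (x.flatten(0, 1) for x in (query, key, value))
out = flash_attn_varlen_func(
    q=query,
    k=key,
    v=value,
    cu_seqlens_q=cu_seqlens,
    cu_seqlens_k=cu_seqlens,
    max_seqlen_q=seq_len,
    max_seqlen_k=seq_len,
)  # out: (batch_size * seq_len, heads, head_dim)
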
@@ -601,30 +585,18 @@ def _flash_varlen_attention_3(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
-        )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))

     out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -958,30 +930,18 @@ def _sage_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)

     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
-        )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+        )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
-        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
+
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))

     out = sageattn_varlen(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,

View File

@@ -263,6 +263,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
         return hidden_states


 @maybe_allow_in_graph
 class FluxAttention(torch.nn.Module, AttentionModuleMixin):
+    _default_processor_cls = FluxAttnProcessor
     _available_processors = [
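
Editor's note: the added _default_processor_cls attribute follows a class-level-default pattern: each attention subclass declares which processor to instantiate when none is supplied. A hypothetical simplification of that pattern (sketch names only; not the actual AttentionModuleMixin implementation):

class ProcessorMixinSketch:
    _default_processor_cls = None  # subclasses declare their default processor class

    def set_default_processor(self):
        # Instantiate whatever class the subclass declared as its default.
        self.processor = self._default_processor_cls()


class FluxAttnProcessorSketch:
    def __call__(self, attn, hidden_states, **kwargs):
        return hidden_states  # placeholder


class FluxAttentionSketch(ProcessorMixinSketch):
    _default_processor_cls = FluxAttnProcessorSketch


attn = FluxAttentionSketch()
attn.set_default_processor()
assert isinstance(attn.processor, FluxAttnProcessorSketch)
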