Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-07 04:54:47 +08:00
Compare commits: custom-blo ... torch-comp
2 Commits

| Author | SHA1 | Date |
| --- | --- | --- |
|  | 66ce9ccb03 |  |
|  | bb443f99dc |  |
@@ -314,13 +314,11 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
 
 
 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
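Both `_prepare_*` helpers now build the cumulative sequence lengths by left-padding a cumulative sum instead of filling a pre-allocated buffer. A minimal standalone sketch (toy lengths, not taken from the diff) showing the two constructions agree:

```python
import torch

# Toy per-sequence lengths; any int32 vector behaves the same way.
seqlens = torch.tensor([3, 5, 2], dtype=torch.int32)

# Old construction: pre-allocate zeros and fill the tail in place.
cu_old = torch.zeros(seqlens.numel() + 1, dtype=torch.int32)
cu_old[1:] = torch.cumsum(seqlens, dim=0)

# New construction: cumulative sum, then pad a single zero on the left.
cu_new = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

assert torch.equal(cu_old, cu_new)  # both are tensor([0, 3, 8, 10], dtype=torch.int32)
```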
@@ -331,13 +329,11 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
 
 
 def _prepare_for_flash_attn_or_sage_varlen(
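In the masked variant, `seqlens_k` comes from the attention mask rather than a constant, but it feeds the same pad-of-cumsum construction. A small illustrative sketch (made-up mask values, assuming an already-normalized boolean key-padding mask) of the quantities involved:

```python
import torch

# One row per batch entry, True = valid key position.
attn_mask = torch.tensor([[True, True, True, False],
                          [True, False, False, False]])

seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)  # tensor([3, 1], dtype=torch.int32)
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0)
)  # tensor([0, 3, 4], dtype=torch.int32)
max_seqlen_k = seqlens_k.max().item()  # 3
```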
@@ -496,30 +492,18 @@ def _flash_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = flash_attn_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
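The rewritten call site no longer slices out the valid keys/values per batch entry; it flattens all three tensors into the packed `(total_tokens, heads, head_dim)` layout that the varlen kernels index via `cu_seqlens_*`. A rough shape sketch with made-up sizes (plain PyTorch only, no flash-attn call):

```python
import torch

batch_size, seq_len, heads, head_dim = 2, 4, 8, 64
query = torch.randn(batch_size, seq_len, heads, head_dim)
key = torch.randn(batch_size, seq_len, heads, head_dim)
value = torch.randn(batch_size, seq_len, heads, head_dim)

# Packed layout: all sequences concatenated along the token dimension.
query, key, value = (x.flatten(0, 1) for x in (query, key, value))  # (B*S, H, D)

# Cumulative offsets tell the kernel where each sequence starts and ends.
seqlens = torch.full((batch_size,), seq_len, dtype=torch.int32)
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

print(query.shape)  # torch.Size([8, 8, 64])
print(cu_seqlens)   # tensor([0, 4, 8], dtype=torch.int32)
```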
@@ -601,30 +585,18 @@ def _flash_varlen_attention_3(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -958,30 +930,18 @@ def _sage_varlen_attention(
         attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
+        cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
+        cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-    cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
-    cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
-
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = sageattn_varlen(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -263,6 +263,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
         return hidden_states
 
 
+@maybe_allow_in_graph
 class FluxAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = FluxAttnProcessor
     _available_processors = [