mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-23 21:04:56 +08:00

Compare commits: remove-unn...torch-comp
2 Commits

| Author | SHA1 | Date |
|---|---|---|
| | 66ce9ccb03 | |
| | bb443f99dc | |
@@ -314,13 +314,11 @@ def _prepare_for_flash_attn_or_sage_varlen_without_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = torch.full((batch_size,), seq_len_kv, dtype=torch.int32, device=device)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen_with_mask(
@@ -331,13 +329,11 @@ def _prepare_for_flash_attn_or_sage_varlen_with_mask(
 ):
     seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)
     seqlens_k = attn_mask.sum(dim=1, dtype=torch.int32)
-    cu_seqlens_q = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_k = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
-    cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0)
-    cu_seqlens_k[1:] = torch.cumsum(seqlens_k, dim=0)
+    cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))
+    cu_seqlens_k = torch.nn.functional.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))
     max_seqlen_q = seqlens_q.max().item()
     max_seqlen_k = seqlens_k.max().item()
-    return (seqlens_q, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)
+    return (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k)


 def _prepare_for_flash_attn_or_sage_varlen(
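The two `_prepare_*` hunks above swap the allocate-zeros-then-write-in-place construction of the cumulative sequence lengths for a single `pad(cumsum(...))` expression, and stop returning the per-batch `(seqlens_q, seqlens_k)` tuple that callers no longer use. A minimal standalone sketch of the equivalence (not part of the diff; batch size and sequence length are made up for illustration):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len_q = 3, 5
device = "cpu"

seqlens_q = torch.full((batch_size,), seq_len_q, dtype=torch.int32, device=device)

# Old form: allocate zeros, then write the running totals in place.
cu_old = torch.zeros(batch_size + 1, dtype=torch.int32, device=device)
cu_old[1:] = torch.cumsum(seqlens_q, dim=0)

# New form: cumulative sum, then left-pad a single leading zero.
cu_new = F.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))

assert torch.equal(cu_old, cu_new)  # both are tensor([0, 5, 10, 15], dtype=torch.int32)
```

Dropping the in-place slice write also tends to be friendlier to `torch.compile` graph capture, which would fit the `torch-comp` branch name; the diff itself does not state the motivation.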
@@ -496,30 +492,18 @@ def _flash_varlen_attention(
     attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
         cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
         cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = flash_attn_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
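At the varlen call sites the per-batch `key_valid`/`value_valid` packing loop is gone; `query`, `key`, and `value` are flattened wholesale and handed to the kernel together with the cumulative sequence lengths. A rough pure-torch sketch of the packed layout this convention describes, using made-up shapes (illustration only, not the diffusers code):

```python
import torch

batch_size, seq_len, num_heads, head_dim = 2, 4, 8, 64
query = torch.randn(batch_size, seq_len, num_heads, head_dim)

# Collapse batch and sequence dims into one packed "tokens" dim: (B * S, H, D).
query_packed = query.flatten(0, 1)

# Cumulative sequence lengths mark where each batch element starts and ends
# in the packed dim; for fixed-length sequences this is just [0, S, 2S, ...].
seqlens_q = torch.full((batch_size,), seq_len, dtype=torch.int32)
cu_seqlens_q = torch.nn.functional.pad(torch.cumsum(seqlens_q, dim=0, dtype=torch.int32), (1, 0))

# Slicing the packed tensor at those offsets recovers each batch element.
offsets = cu_seqlens_q.tolist()
for b in range(batch_size):
    assert torch.equal(query_packed[offsets[b] : offsets[b + 1]], query[b])
```

When an attention mask is given, `seqlens_k` comes from the mask's row sums in `_prepare_for_flash_attn_or_sage_varlen_with_mask` rather than a constant, but the offset convention is the same.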
@@ -601,30 +585,18 @@ def _flash_varlen_attention_3(
     attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
         cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
         cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out, lse, *_ = flash_attn_3_varlen_func(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -958,30 +930,18 @@ def _sage_varlen_attention(
     attn_mask = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
 
     if any(x is None for x in (cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k)):
-        (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-            _prepare_for_flash_attn_or_sage_varlen(
-                batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
-            )
+        (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = _prepare_for_flash_attn_or_sage_varlen(
+            batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
         )
     else:
-        seqlens_k = torch.full((batch_size,), max_seqlen_k, dtype=torch.int32, device=query.device)
         cu_seqlens_q = cu_seqlens_q.to(dtype=torch.int32, device=query.device)
         cu_seqlens_k = cu_seqlens_k.to(dtype=torch.int32, device=query.device)
 
-    key_valid, value_valid = [], []
-    for b in range(batch_size):
-        valid_len = seqlens_k[b]
-        key_valid.append(key[b, :valid_len])
-        value_valid.append(value[b, :valid_len])
-
-    query_packed = query.flatten(0, 1)
-    key_packed = torch.cat(key_valid, dim=0)
-    value_packed = torch.cat(value_valid, dim=0)
-
+    query, key, value = (x.flatten(0, 1) for x in (query, key, value))
     out = sageattn_varlen(
-        q=query_packed,
-        k=key_packed,
-        v=value_packed,
+        q=query,
+        k=key,
+        v=value,
         cu_seqlens_q=cu_seqlens_q,
         cu_seqlens_k=cu_seqlens_k,
         max_seqlen_q=max_seqlen_q,
@@ -263,6 +263,7 @@ class FluxIPAdapterAttnProcessor(torch.nn.Module):
         return hidden_states
 
 
+@maybe_allow_in_graph
 class FluxAttention(torch.nn.Module, AttentionModuleMixin):
     _default_processor_cls = FluxAttnProcessor
     _available_processors = [
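The only change in this last hunk is decorating `FluxAttention` with `maybe_allow_in_graph`, the decorator diffusers uses elsewhere on transformer blocks so that TorchDynamo can treat the whole module as graph-allowed instead of tracing into it. A rough sketch of what such a helper typically looks like, written here purely as an assumption for illustration (the real helper lives in diffusers' torch utilities and may differ):

```python
import torch


def maybe_allow_in_graph(cls):
    # Illustrative assumption: register the class with TorchDynamo when the
    # API is present, otherwise return the class unchanged.
    if hasattr(torch, "_dynamo") and hasattr(torch._dynamo, "allow_in_graph"):
        return torch._dynamo.allow_in_graph(cls)
    return cls
```

Together with the removal of in-place buffer writes and the Python packing loops in the earlier hunks, this is consistent with making the attention path easier to run under `torch.compile`.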