This commit is contained in:
sayakpaul
2024-07-15 08:21:30 +05:30
parent 3a139f4329
commit 886b85e45d

View File

@@ -66,6 +66,7 @@ class FA3AttnProcessor:
key = key.view(batch_size, -1, attn.heads, head_dim).contiguous()
value = value.view(batch_size, -1, attn.heads, head_dim).contiguous()
# nasty hack to make the head number and head dim compatible with FA3.
if attn.heads ==1 and head_dim == 512:
factor = 8
new_head_dim = head_dim // factor