Fix CrossAttention._sliced_attention (#563)

* Fix CrossAttention._sliced_attention: take `dim` from the projected query (`query.shape[-1]`) instead of from `hidden_states`, since the two can differ when the attention's inner dimension is not equal to the input width

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Author: Yih-Dar
Date: 2022-09-19 18:07:32 +02:00
Committed by: GitHub
Parent: 8d36d5adb1
Commit: 84616b5de5

@@ -249,13 +249,15 @@ class CrossAttention(nn.Module):
         return tensor
 
     def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+        batch_size, sequence_length, _ = hidden_states.shape
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)
 
+        dim = query.shape[-1]
+
         query = self.reshape_heads_to_batch_dim(query)
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
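
Why this matters: `to_q` projects `hidden_states` from the input width to `inner_dim = dim_head * heads`, and `_sliced_attention` uses `dim` to size its output buffer, so reading `dim` off the input breaks whenever the two widths differ. Below is a minimal sketch of that failure mode; `TinyCrossAttention` and its dimensions are hypothetical stand-ins chosen for illustration, not the diffusers implementation.

```python
import torch
import torch.nn as nn

class TinyCrossAttention(nn.Module):
    # Hypothetical, stripped-down stand-in: only the query projection,
    # which is all that is needed to show the dim mismatch.
    def __init__(self, query_dim, heads, dim_head):
        super().__init__()
        self.heads = heads
        self.to_q = nn.Linear(query_dim, dim_head * heads, bias=False)

attn = TinyCrossAttention(query_dim=320, heads=8, dim_head=64)  # inner_dim = 512
hidden_states = torch.randn(2, 77, 320)
query = attn.to_q(hidden_states)

_, _, wrong_dim = hidden_states.shape  # 320: the pre-fix value of `dim`
dim = query.shape[-1]                  # 512: the post-fix value

# A sliced-attention loop allocates its per-head output with `dim`;
# using `wrong_dim` would size the buffer 320 // 8 = 40 wide while the
# per-head attention results are 64 wide, so copying into the slice fails.
batch_size, sequence_length, _ = hidden_states.shape
out = torch.zeros(batch_size * attn.heads, sequence_length, dim // attn.heads)
print(wrong_dim, dim, out.shape)  # 320 512 torch.Size([16, 77, 64])
```

Note that the fix is a no-op whenever `inner_dim` equals the input width, so it only changes behavior in the configurations that previously produced a shape mismatch.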