Mirror of https://github.com/huggingface/diffusers.git
Fix CrossAttention._sliced_attention (#563)
* Fix CrossAttention._sliced_attention

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
@@ -249,13 +249,15 @@ class CrossAttention(nn.Module):
         return tensor

     def forward(self, hidden_states, context=None, mask=None):
-        batch_size, sequence_length, dim = hidden_states.shape
+        batch_size, sequence_length, _ = hidden_states.shape

         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
         key = self.to_k(context)
         value = self.to_v(context)

+        dim = query.shape[-1]
+
         query = self.reshape_heads_to_batch_dim(query)
         key = self.reshape_heads_to_batch_dim(key)
         value = self.reshape_heads_to_batch_dim(value)
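The hunk above changes where `dim` comes from. `hidden_states.shape[-1]` is the module's `query_dim`, while the sliced attention path needs the projected inner dimension (`heads * dim_head`), which is only known after `to_q` runs; the two differ whenever `heads * dim_head != query_dim`, and the inner dimension is what sizes the per-head output buffer. The following is a minimal, self-contained sketch of the corrected flow, not the library's actual implementation: the class name `TinyCrossAttention`, the inline `slice_size` loop, the omission of mask handling, and all dimensions are illustrative assumptions.

# Minimal sketch (assumed simplification, not diffusers' real CrossAttention)
# showing why `dim` must be read from the projected query: with
# query_dim=16 and inner_dim=32, the pre-fix `dim` would mis-size the
# output buffer used by the slicing loop.
import torch
import torch.nn as nn


class TinyCrossAttention(nn.Module):
    def __init__(self, query_dim, heads=4, dim_head=8):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def reshape_heads_to_batch_dim(self, tensor):
        # (b, n, h*d) -> (b*h, n, d)
        b, n, d = tensor.shape
        h = self.heads
        return tensor.reshape(b, n, h, d // h).permute(0, 2, 1, 3).reshape(b * h, n, d // h)

    def reshape_batch_dim_to_heads(self, tensor):
        # (b*h, n, d) -> (b, n, h*d)
        bh, n, d = tensor.shape
        h = self.heads
        return tensor.reshape(bh // h, h, n, d).permute(0, 2, 1, 3).reshape(bh // h, n, d * h)

    def forward(self, hidden_states, context=None, slice_size=1):
        batch_size, sequence_length, _ = hidden_states.shape
        query = self.to_q(hidden_states)
        context = context if context is not None else hidden_states
        key = self.to_k(context)
        value = self.to_v(context)

        # The fix: dim is the projected inner dimension (heads * dim_head),
        # known only after to_q, not hidden_states' last dimension.
        dim = query.shape[-1]

        query = self.reshape_heads_to_batch_dim(query)
        key = self.reshape_heads_to_batch_dim(key)
        value = self.reshape_heads_to_batch_dim(value)

        # Sliced attention: the output buffer holds dim // heads per head,
        # so a wrong `dim` breaks this allocation.
        out = torch.zeros(query.shape[0], sequence_length, dim // self.heads)
        for i in range(0, query.shape[0], slice_size):
            s = slice(i, i + slice_size)
            attn = (query[s] @ key[s].transpose(1, 2)) * self.scale
            out[s] = attn.softmax(dim=-1) @ value[s]

        return self.to_out(self.reshape_batch_dim_to_heads(out))


# query_dim=16 != inner_dim=32: the old `dim` taken from hidden_states (16)
# would mis-size `out`; dim = query.shape[-1] (32) is correct.
attn = TinyCrossAttention(query_dim=16, heads=4, dim_head=8)
print(attn(torch.randn(2, 5, 16)).shape)  # torch.Size([2, 5, 16])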