Compare commits

...

2 Commits

Author SHA1 Message Date
patil-suraj
ea238e821b up 2024-03-18 11:47:47 +01:00
patil-suraj
b6d1d670fc up 2024-03-18 11:34:17 +01:00

View File

@@ -767,7 +767,18 @@ class AttnProcessor:
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# encoder_hidden_states = hidden_states
batch, seq, dim = hidden_states.shape
height = width = seq**0.5
# reshape to (batch, height, width, dim)
encoder_hidden_states = hidden_states.view(batch, height, width, dim)
# reshape to (batch, dim, height, width)
encoder_hidden_states = encoder_hidden_states.permute(0, 3, 1, 2)
encoder_hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=4)
# reshape to (batch, dim, seq)
encoder_hidden_states = encoder_hidden_states.view(batch, dim, -1)
# reshape to (batch, seq, dim)
encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
@@ -1259,7 +1270,18 @@ class AttnProcessor2_0:
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
# encoder_hidden_states = hidden_states
batch, seq, dim = hidden_states.shape
height = width = seq**0.5
# reshape to (batch, height, width, dim)
encoder_hidden_states = hidden_states.view(batch, height, width, dim)
# reshape to (batch, dim, height, width)
encoder_hidden_states = encoder_hidden_states.permute(0, 3, 1, 2)
encoder_hidden_states = torch.nn.functional.avg_pool2d(hidden_states, kernel_size=4)
# reshape to (batch, dim, seq)
encoder_hidden_states = encoder_hidden_states.view(batch, dim, -1)
# reshape to (batch, seq, dim)
encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)