Compare commits

...

16 Commits

Author SHA1 Message Date
thomasw21  c11d4b42a7  Context has its own sequence length  2022-11-22 12:26:59 +01:00
thomasw21  d5af4fd153  Woops  2022-11-22 12:21:59 +01:00
thomasw21  1f135ac219  Woops  2022-11-22 12:20:26 +01:00
thomasw21  0c49f4cf30  Woops  2022-11-22 12:19:15 +01:00
thomasw21  5d4145cfa2  Remove transpose for baddbmm  2022-11-22 12:17:11 +01:00
thomasw21  31d26872c1  Revert "Making hidden_state contiguous before applying multiple linear layers"  2022-11-22 12:07:54 +01:00
           This reverts commit 1cd09cccf3.
thomasw21  1cd09cccf3  Making hidden_state contiguous before applying multiple linear layers  2022-11-22 11:55:03 +01:00
thomasw21  fa4d738cbb  Revert "Save one more copy" as it's much slower on A100  2022-11-22 11:53:58 +01:00
           This reverts commit 136f84283c.
thomasw21  136f84283c  Save one more copy  2022-11-22 11:49:15 +01:00
thomasw21  42ba85998f  scatter_ argument is not called src, but rather value  2022-11-22 01:11:18 +01:00
thomasw21  e1623e2081  Woops  2022-11-22 01:02:24 +01:00
thomasw21  fdef40ba03  Woops  2022-11-22 00:57:19 +01:00
thomasw21  fe691feb5a  Remove unused import  2022-11-22 00:52:53 +01:00
thomasw21  f2ed5d8b44  black  2022-11-22 00:48:50 +01:00
thomasw21  e43244f33a  Fix transpose issue  2022-11-22 00:47:22 +01:00
thomasw21  3c45926a0e  WIP: some optimizations  2022-11-22 00:37:17 +01:00
6 changed files with 75 additions and 60 deletions

View File

@@ -213,7 +213,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             logits = logits.permute(0, 2, 1)

             # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+            output = F.log_softmax(logits, dim=1, dtype=torch.double).float()

         if not return_dict:
             return (output,)
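The change above relies on `log_softmax`'s `dtype` argument: the input is cast inside the op, so the explicit `.double()` temporary in user code disappears and the result is unchanged. A minimal standalone check, with made-up shapes rather than the model's:

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 10, 32)  # (batch, classes, latent pixels) -- illustrative sizes

    # Before: materialize a float64 copy of logits, then take the log-softmax.
    out_old = F.log_softmax(logits.double(), dim=1).float()

    # After: let log_softmax handle the upcast via its dtype argument.
    out_new = F.log_softmax(logits, dim=1, dtype=torch.double).float()

    assert torch.allclose(out_old, out_new)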
@@ -288,55 +288,60 @@ class AttentionBlock(nn.Module):
         # get scores
         if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            #       as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+            query_states = (
+                self.transpose_for_scores(query_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
+            key_states = (
+                self.transpose_for_scores(key_proj)
+                .transpose(3, 2)
+                .contiguous()
+                .view(batch * self.num_heads, self.num_head_size, height * width)
+            )
+            value_states = (
+                self.transpose_for_scores(value_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
+        else:
+            query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_states.shape[0],
+                query_states.shape[1],
+                key_states.shape[2],
+                dtype=query_states.dtype,
+                device=query_states.device,
+            ),
+            query_states,
+            key_states,
+            beta=0,
+            alpha=scale,
+        )
+        attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)
         # compute attention output
+        hidden_states = torch.bmm(attention_probs, value_states)
         if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            #       as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            hidden_states = (
+                hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+            )
             new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
             hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
         hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
         # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        hidden_states = hidden_states + residual
+        if self.rescale_output_factor != 1.0:
+            hidden_states = hidden_states / self.rescale_output_factor
         return hidden_states
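The multi-head branch above folds the head dimension into the batch dimension so that one `baddbmm` and one `bmm` replace the 4D `matmul`s: the key is pre-transposed during the reshape, the `scale` is applied via `alpha`, and the `torch.empty` buffer is ignored because `beta=0`. A rough standalone sketch of the same trick; `num_heads`, `head_size` and the shapes are made up for illustration, not the module's attributes:

    import torch

    batch, num_heads, seq, head_size = 2, 4, 16, 8
    scale = head_size ** -0.5

    q = torch.randn(batch, num_heads, seq, head_size)
    k = torch.randn(batch, num_heads, seq, head_size)
    v = torch.randn(batch, num_heads, seq, head_size)

    # Fold heads into the batch dimension so plain 3D batched matmuls apply.
    q3 = q.reshape(batch * num_heads, seq, head_size)
    k3 = k.transpose(-1, -2).reshape(batch * num_heads, head_size, seq)  # key pre-transposed once
    v3 = v.reshape(batch * num_heads, seq, head_size)

    # beta=0 ignores the uninitialized buffer; alpha folds the scaling into the matmul.
    scores = torch.baddbmm(
        torch.empty(q3.shape[0], q3.shape[1], k3.shape[2], dtype=q3.dtype, device=q3.device),
        q3,
        k3,
        beta=0,
        alpha=scale,
    )
    probs = torch.softmax(scores, dim=-1)
    out = torch.bmm(probs, v3)  # (batch * num_heads, seq, head_size)

    # Reference: the 4D formulation with an explicit scale multiply.
    ref_scores = (torch.matmul(q, k.transpose(-1, -2)) * scale).reshape(batch * num_heads, seq, seq)
    assert torch.allclose(scores, ref_scores, atol=1e-5)

Reassembling `out` into `(batch, seq, num_heads * head_size)` is what the `view(...).permute(0, 2, 1, 3).contiguous()` block after the `bmm` does.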
@@ -492,14 +497,13 @@ class CrossAttention(nn.Module):
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
         return tensor

     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
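`reshape_heads_to_batch_dim` now only splits the channel dimension and returns a 4D view; the flattening to `(batch_size * head_size, ...)` moves to the call sites (next hunk). Swapping `reshape` for `view` also acts as a cheap guarantee that no hidden copy happens: `view` refuses when the memory layout is incompatible, while `reshape` silently copies. A small illustration, not tied to the class:

    import torch

    x = torch.randn(2, 16, 64)          # (batch, seq_len, dim), contiguous

    y = x.view(2, 16, 8, 8)             # pure metadata change, shares storage
    assert y.data_ptr() == x.data_ptr()

    t = x.transpose(1, 2)               # non-contiguous strided view, shape (2, 64, 16)
    flat = t.reshape(2, 64 * 16)        # reshape silently falls back to a copy here
    assert flat.data_ptr() != t.data_ptr()

    try:
        t.view(2, 64 * 16)              # view refuses instead of hiding the copy
    except RuntimeError as err:
        print("view raised:", err)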
@@ -508,23 +512,33 @@ class CrossAttention(nn.Module):
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
+        context_sequence_length = context.shape[1]
         key = self.to_k(context)
         value = self.to_v(context)
         dim = query.shape[-1]
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        query = (
+            self.reshape_heads_to_batch_dim(query)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+        )
+        value = (
+            self.reshape_heads_to_batch_dim(value)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
+        )
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, context_sequence_length)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
@@ -538,9 +552,9 @@ class CrossAttention(nn.Module):
     def _attention(self, query, key, value):
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
             query,
-            key.transpose(-1, -2),
+            key,
             beta=0,
             alpha=self.scale,
         )
@@ -563,9 +577,9 @@ class CrossAttention(nn.Module):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
             attn_slice = torch.baddbmm(
-                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
-                key[start_idx:end_idx].transpose(-1, -2),
+                key[start_idx:end_idx],
                 beta=0,
                 alpha=self.scale,
             )
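Both `baddbmm` call sites above now receive the key already laid out as `(batch_size * heads, dim // heads, context_sequence_length)`, because `forward` builds that layout once with `permute(0, 2, 3, 1)` (the xformers branch keeps the untransposed `(batch_size * heads, context_sequence_length, dim // heads)` layout it expects). A standalone check that the up-front permute is equivalent to transposing at every call; shapes and names are illustrative:

    import torch

    batch_size, seq_len, heads, dim = 2, 16, 4, 64

    key = torch.randn(batch_size, seq_len, dim)
    k4 = key.view(batch_size, seq_len, heads, dim // heads)

    # Old layout: (batch * heads, seq, dim_per_head), transposed at every baddbmm call.
    old = k4.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, dim // heads).transpose(-1, -2)

    # New layout: permute straight to (batch * heads, dim_per_head, seq) once, up front.
    new = k4.permute(0, 2, 3, 1).reshape(batch_size * heads, dim // heads, seq_len)

    assert torch.equal(old, new)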

View File

@@ -49,11 +49,14 @@ def get_timestep_embedding(
     emb = scale * emb

     # concat sine and cosine embeddings
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+    sin = torch.sin(emb)
+    cos = torch.cos(emb)

     # flip sine and cosine embeddings
     if flip_sin_to_cos:
-        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+        emb = torch.cat([cos, sin], dim=-1)
+    else:
+        emb = torch.cat([sin, cos], dim=-1)

     # zero pad
     if embedding_dim % 2 == 1:
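The rewrite above computes `sin` and `cos` once and concatenates them directly in the requested order, instead of building the `[sin, cos]` tensor and then re-slicing and re-concatenating it when `flip_sin_to_cos` is set. A standalone check of the equivalence, with `half_dim` and the scaling simplified:

    import torch

    half_dim = 8
    emb = torch.randn(4, half_dim)  # already scaled timestep * frequency terms

    # Old path: concatenate, then roll the halves when flipping.
    old = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
    old_flipped = torch.cat([old[:, half_dim:], old[:, :half_dim]], dim=-1)

    # New path: one sin, one cos, concatenated in the right order up front.
    sin, cos = torch.sin(emb), torch.cos(emb)
    new_plain = torch.cat([sin, cos], dim=-1)
    new_flipped = torch.cat([cos, sin], dim=-1)

    assert torch.equal(old, new_plain)
    assert torch.equal(old_flipped, new_flipped)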
@@ -126,7 +129,7 @@ class GaussianFourierProjection(nn.Module):
         if self.log:
             x = torch.log(x)

-        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+        x_proj = x[:, None] * self.weight[None, :] * (2 * np.pi)

         if self.flip_sin_to_cos:
             out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
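Parenthesizing the constant folds `2 * np.pi` into a single Python float before it ever touches the tensor, so the broadcasted projection is followed by one scalar multiply instead of two. A small sketch with made-up sizes:

    import numpy as np
    import torch

    x = torch.randn(16)
    weight = torch.randn(32)

    # Two scalar multiplies over the projected tensor.
    x_proj_old = x[:, None] * weight[None, :] * 2 * np.pi

    # One scalar multiply: Python computes 2 * pi first.
    x_proj_new = x[:, None] * weight[None, :] * (2 * np.pi)

    assert torch.allclose(x_proj_old, x_proj_new)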

View File

@@ -476,7 +476,9 @@ class ResnetBlock2D(nn.Module):
         if self.conv_shortcut is not None:
             input_tensor = self.conv_shortcut(input_tensor)

-        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+        output_tensor = input_tensor + hidden_states
+        if self.output_scale_factor != 1.0:
+            output_tensor = output_tensor / self.output_scale_factor

         return output_tensor
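The guard on `output_scale_factor != 1.0` skips a whole elementwise division in the common case where the factor is 1; dividing by 1.0 does not change the values but still costs a pass over the tensor (the AttentionBlock hunk above applies the same guard to `rescale_output_factor`). Sketch with made-up shapes:

    import torch

    output_scale_factor = 1.0
    input_tensor = torch.randn(1, 64, 32, 32)
    hidden_states = torch.randn(1, 64, 32, 32)

    # Unconditional division always launches an extra elementwise op.
    old = (input_tensor + hidden_states) / output_scale_factor

    # Only divide when it actually changes the values.
    new = input_tensor + hidden_states
    if output_scale_factor != 1.0:
        new = new / output_scale_factor

    assert torch.equal(old, new)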

View File

@@ -1054,10 +1054,7 @@ class AttnUpBlock2D(nn.Module):
             self.upsamplers = None

     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
-        for resnet, attn in zip(self.resnets, self.attentions):
-            # pop res hidden states
-            res_hidden_states = res_hidden_states_tuple[-1]
-            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+        for resnet, attn, res_hidden_states in zip(self.resnets, self.attentions, reversed(res_hidden_states_tuple)):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

             hidden_states = resnet(hidden_states, temb)
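Zipping over `reversed(res_hidden_states_tuple)` consumes the skip connections last-to-first without rebuilding the tuple on every iteration, which is what the repeated `[-1]` / `[:-1]` slicing did. A quick standalone check that the two loops visit the same triples:

    res_hidden_states_tuple = ("res0", "res1", "res2")
    resnets = ["r0", "r1", "r2"]
    attentions = ["a0", "a1", "a2"]

    # Old: pop from the end by re-slicing the tuple each iteration.
    old_order = []
    remaining = res_hidden_states_tuple
    for resnet, attn in zip(resnets, attentions):
        res_hidden_states = remaining[-1]
        remaining = remaining[:-1]
        old_order.append((resnet, attn, res_hidden_states))

    # New: a single zip over the reversed tuple, no intermediate tuples.
    new_order = list(zip(resnets, attentions, reversed(res_hidden_states_tuple)))

    assert old_order == new_order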

View File

@@ -17,7 +17,6 @@ from typing import Optional, Tuple, Union

 import numpy as np
 import torch
-import torch.nn.functional as F

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
@@ -53,9 +52,9 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
         `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
             Log onehot vectors
     """
-    x_onehot = F.one_hot(x, num_classes)
-    x_onehot = x_onehot.permute(0, 2, 1)
-    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
+    batch_size, vector_length = x.shape
+    log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device)
+    log_x.scatter_(index=x[:, None, :], value=0.0, dim=1)
     return log_x
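The rewritten `index_to_log_onehot` writes the log-one-hot tensor directly with `torch.full` plus `scatter_` (whose scalar overload takes `value=`, not `src=`), instead of materializing a one-hot tensor, permuting it, clamping and taking its log. Note that reproducing the old function's output exactly requires filling with the log of the clamp floor, `log(1e-30)`; the sketch below does that, so treat it as an illustration of the scatter approach rather than a copy of the branch's code:

    import math
    import torch
    import torch.nn.functional as F

    num_classes, batch_size, vector_length = 5, 2, 7
    x = torch.randint(num_classes, (batch_size, vector_length))

    # Old: one-hot, permute, clamp, log -- several full-size temporaries.
    x_onehot = F.one_hot(x, num_classes).permute(0, 2, 1)
    log_x_old = torch.log(x_onehot.float().clamp(min=1e-30))

    # New: allocate the result once and scatter zeros (log 1) into the hot positions.
    log_x_new = torch.full(
        (batch_size, num_classes, vector_length), fill_value=math.log(1e-30), dtype=torch.float, device=x.device
    )
    log_x_new.scatter_(dim=1, index=x[:, None, :], value=0.0)

    assert torch.allclose(log_x_old, log_x_new)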

View File

@@ -1831,7 +1831,7 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest):
         def model(sample, t, *args):
             batch_size, num_latent_pixels = sample.shape
             logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
-            return_value = F.log_softmax(logits.double(), dim=1).float()
+            return_value = F.log_softmax(logits, dim=1, dtype=torch.double).float()
             return return_value

         return model
return model