Mirror of https://github.com/huggingface/diffusers.git, synced 2025-12-12 23:44:30 +08:00

Compare commits: integratio... → thomas/sma... (16 commits)
| Author | SHA1 | Date |
|---|---|---|
| | c11d4b42a7 | |
| | d5af4fd153 | |
| | 1f135ac219 | |
| | 0c49f4cf30 | |
| | 5d4145cfa2 | |
| | 31d26872c1 | |
| | 1cd09cccf3 | |
| | fa4d738cbb | |
| | 136f84283c | |
| | 42ba85998f | |
| | e1623e2081 | |
| | fdef40ba03 | |
| | fe691feb5a | |
| | f2ed5d8b44 | |
| | e43244f33a | |
| | 3c45926a0e | |

```diff
@@ -213,7 +213,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
         logits = logits.permute(0, 2, 1)
 
         # log(p(x_0))
-        output = F.log_softmax(logits.double(), dim=1).float()
+        output = F.log_softmax(logits, dim=1, dtype=torch.double).float()
 
         if not return_dict:
             return (output,)
```
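
The change above replaces an up-front `logits.double()` copy with the `dtype` argument of `F.log_softmax`, so the upcast happens as part of the softmax call rather than as a separate full-size double-precision tensor created beforehand. A minimal, self-contained check of the equivalence, using toy shapes rather than the model's real tensors:

```python
import torch
import torch.nn.functional as F

# Toy logits standing in for the (batch, num_classes, num_latent_pixels) tensor.
logits = torch.randn(2, 5, 7)

# Old form: materialize a float64 copy, then run log_softmax on it.
old = F.log_softmax(logits.double(), dim=1).float()

# New form: let log_softmax cast to float64 itself via `dtype=`.
new = F.log_softmax(logits, dim=1, dtype=torch.double).float()

assert torch.allclose(old, new)
```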

```diff
@@ -288,55 +288,60 @@ class AttentionBlock(nn.Module):
 
         # get scores
         if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
-
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+            query_states = (
+                self.transpose_for_scores(query_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
+            key_states = (
+                self.transpose_for_scores(key_proj)
+                .transpose(3, 2)
+                .contiguous()
+                .view(batch * self.num_heads, self.num_head_size, height * width)
+            )
+            value_states = (
+                self.transpose_for_scores(value_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
+        else:
+            query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
+
+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_states.shape[0],
+                query_states.shape[1],
+                key_states.shape[2],
+                dtype=query_states.dtype,
+                device=query_states.device,
+            ),
+            query_states,
+            key_states,
+            beta=0,
+            alpha=scale,
+        )
+
+        attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)
 
         # compute attention output
+        hidden_states = torch.bmm(attention_probs, value_states)
         if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            hidden_states = (
+                hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+            )
             new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
             hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
         hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
 
         # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        hidden_states = hidden_states + residual
+        if self.rescale_output_factor != 1.0:
+            hidden_states = hidden_states / self.rescale_output_factor
         return hidden_states
```
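
The multi-head branch previously ran a 4D `torch.matmul` and scaled afterwards; the rewrite folds the heads into the batch dimension so a single `torch.baddbmm` with `beta=0, alpha=scale` produces the scaled scores for both branches. A rough sketch of why the two layouts agree, using stand-in tensors and illustrative names rather than the module's `transpose_for_scores`:

```python
import torch

batch, heads, seq, head_dim = 2, 4, 16, 8
scale = head_dim ** -0.5

q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)

# 4D path: batched matmul over the last two dims, scaled afterwards.
scores_4d = torch.matmul(q, k.transpose(-1, -2)) * scale

# 3D path: fold heads into the batch dimension and let baddbmm apply the scale.
q3 = q.reshape(batch * heads, seq, head_dim)
k3 = k.transpose(-1, -2).reshape(batch * heads, head_dim, seq)
scores_3d = torch.baddbmm(
    torch.empty(batch * heads, seq, seq, dtype=q.dtype, device=q.device),
    q3,
    k3,
    beta=0,
    alpha=scale,
)

assert torch.allclose(scores_4d.reshape(batch * heads, seq, seq), scores_3d, atol=1e-6)
```

With `beta=0` the uninitialized `torch.empty` input is ignored by `baddbmm`, so it only serves to pin the output shape, dtype, and device.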

```diff
@@ -492,14 +497,13 @@ class CrossAttention(nn.Module):
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
         return tensor
 
     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
```
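
The helper now calls `Tensor.view` instead of `Tensor.reshape`. On a contiguous tensor the two return the same values; `view` additionally guarantees no copy is made (it raises if the strides are incompatible), while `reshape` silently falls back to copying when it must. A small illustration with toy shapes:

```python
import torch

heads = 4
tensor = torch.randn(2, 16, 64)  # (batch, seq_len, dim), contiguous
batch_size, seq_len, dim = tensor.shape

as_view = tensor.view(batch_size, seq_len, heads, dim // heads)
as_reshape = tensor.reshape(batch_size, seq_len, heads, dim // heads)

# Same contents, and the view shares storage with the original tensor.
assert torch.equal(as_view, as_reshape)
assert as_view.data_ptr() == tensor.data_ptr()
```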

```diff
@@ -508,23 +512,33 @@ class CrossAttention(nn.Module):
 
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
+        context_sequence_length = context.shape[1]
         key = self.to_k(context)
         value = self.to_v(context)
 
         dim = query.shape[-1]
 
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        query = (
+            self.reshape_heads_to_batch_dim(query)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+        )
+        value = (
+            self.reshape_heads_to_batch_dim(value)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
+        )
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, context_sequence_length)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
```

```diff
@@ -538,9 +552,9 @@ class CrossAttention(nn.Module):
 
     def _attention(self, query, key, value):
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
             query,
-            key.transpose(-1, -2),
+            key,
             beta=0,
             alpha=self.scale,
         )
```
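
`_attention` now receives a key that the caller has already laid out as `(batch * heads, head_dim, sequence_length)`, so `baddbmm` consumes it directly instead of transposing inside every call (and, in the sliced path below, inside every slice). A hedged sketch of the equivalence with stand-in tensors; the `scores` helper and the shapes are illustrative only:

```python
import torch

scale = 0.125
query = torch.randn(8, 16, 40)  # (batch * heads, seq_len, head_dim)
key = torch.randn(8, 16, 40)    # same layout as the query

def scores(q, k_for_bmm):
    return torch.baddbmm(
        torch.empty(q.shape[0], q.shape[1], k_for_bmm.shape[2], dtype=q.dtype, device=q.device),
        q,
        k_for_bmm,
        beta=0,
        alpha=scale,
    )

# Old: transpose the key on every call.
old = scores(query, key.transpose(-1, -2))

# New: hand over a key that is already (batch * heads, head_dim, seq_len).
key_pre_transposed = key.transpose(-1, -2).contiguous()
new = scores(query, key_pre_transposed)

assert torch.allclose(old, new, atol=1e-6)
```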

```diff
@@ -563,9 +577,9 @@ class CrossAttention(nn.Module):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
             attn_slice = torch.baddbmm(
-                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
-                key[start_idx:end_idx].transpose(-1, -2),
+                key[start_idx:end_idx],
                 beta=0,
                 alpha=self.scale,
             )
```

```diff
@@ -49,11 +49,14 @@ def get_timestep_embedding(
     emb = scale * emb
 
     # concat sine and cosine embeddings
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+    sin = torch.sin(emb)
+    cos = torch.cos(emb)
 
     # flip sine and cosine embeddings
     if flip_sin_to_cos:
-        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+        emb = torch.cat([cos, sin], dim=-1)
+    else:
+        emb = torch.cat([sin, cos], dim=-1)
 
     # zero pad
     if embedding_dim % 2 == 1:
```
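
The old code concatenated `[sin, cos]` and then, when `flip_sin_to_cos` was set, re-sliced the result at `half_dim` to swap the two halves; the new code builds `sin` and `cos` once and concatenates them in the requested order. A quick standalone check of the equivalence, where `emb` stands in for the scaled timestep-by-frequency matrix:

```python
import torch

half_dim = 8
emb = torch.randn(4, half_dim)  # stand-in for the scaled timestep/frequency product

# Old: concatenate, then swap the two halves by slicing.
old = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
old = torch.cat([old[:, half_dim:], old[:, :half_dim]], dim=-1)

# New: compute sin/cos once and concatenate in the flipped order directly.
sin, cos = torch.sin(emb), torch.cos(emb)
new = torch.cat([cos, sin], dim=-1)

assert torch.equal(old, new)
```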

```diff
@@ -126,7 +129,7 @@ class GaussianFourierProjection(nn.Module):
         if self.log:
             x = torch.log(x)
 
-        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+        x_proj = x[:, None] * self.weight[None, :] * (2 * np.pi)
 
         if self.flip_sin_to_cos:
             out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
```
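
Because `*` is left-associative, `x[:, None] * self.weight[None, :] * 2 * np.pi` launches two extra elementwise multiplies of the projected tensor by scalars, while `* (2 * np.pi)` folds the constants into a single scalar multiply. The values are unchanged up to floating-point rounding; a toy check with stand-in tensors:

```python
import numpy as np
import torch

x = torch.randn(16)       # stand-in for the (log-)timesteps
weight = torch.randn(32)  # stand-in for the fixed Fourier weights

old = x[:, None] * weight[None, :] * 2 * np.pi    # two tensor-scalar multiplies
new = x[:, None] * weight[None, :] * (2 * np.pi)  # constants folded into one

assert torch.allclose(old, new)
```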

```diff
@@ -476,7 +476,9 @@ class ResnetBlock2D(nn.Module):
         if self.conv_shortcut is not None:
             input_tensor = self.conv_shortcut(input_tensor)
 
-        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+        output_tensor = input_tensor + hidden_states
+        if self.output_scale_factor != 1.0:
+            output_tensor = output_tensor / self.output_scale_factor
 
         return output_tensor
```
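
With the default `output_scale_factor=1.0`, the old code still divided the full residual sum by 1.0 on every forward pass; the new code only divides when the factor differs from 1.0, which leaves the values untouched and skips one elementwise kernel in the common case. A minimal sketch of the guard, using plain tensors and a hypothetical `combine` helper rather than the real block:

```python
import torch

def combine(input_tensor, hidden_states, output_scale_factor):
    # Guarded variant mirroring the new ResnetBlock2D tail.
    output_tensor = input_tensor + hidden_states
    if output_scale_factor != 1.0:
        output_tensor = output_tensor / output_scale_factor
    return output_tensor

a, b = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)

# Identical to the unconditional division; no division is launched when the factor is 1.0.
assert torch.equal(combine(a, b, 1.0), (a + b) / 1.0)
assert torch.equal(combine(a, b, 2.0), (a + b) / 2.0)
```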

```diff
@@ -1054,10 +1054,7 @@ class AttnUpBlock2D(nn.Module):
             self.upsamplers = None
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
-        for resnet, attn in zip(self.resnets, self.attentions):
-            # pop res hidden states
-            res_hidden_states = res_hidden_states_tuple[-1]
-            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+        for resnet, attn, res_hidden_states in zip(self.resnets, self.attentions, reversed(res_hidden_states_tuple)):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             hidden_states = resnet(hidden_states, temb)
```
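
Instead of popping the last entry off `res_hidden_states_tuple` on every iteration (rebuilding the tuple each time), the loop now zips the resnets and attentions directly with `reversed(res_hidden_states_tuple)`. The skip connections are still consumed from last to first. A small check that the two iteration orders match, with plain lists standing in for the modules and tensors:

```python
resnets = ["resnet0", "resnet1", "resnet2"]
attentions = ["attn0", "attn1", "attn2"]
res_hidden_states_tuple = ("skip0", "skip1", "skip2")

# Old style: pop the last skip connection each iteration.
old_order = []
remaining = res_hidden_states_tuple
for resnet, attn in zip(resnets, attentions):
    res_hidden_states = remaining[-1]
    remaining = remaining[:-1]
    old_order.append((resnet, attn, res_hidden_states))

# New style: zip against the reversed tuple.
new_order = list(zip(resnets, attentions, reversed(res_hidden_states_tuple)))

assert old_order == new_order
```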

```diff
@@ -17,7 +17,6 @@ from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
```

```diff
@@ -53,9 +52,9 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTen
         `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
             Log onehot vectors
     """
-    x_onehot = F.one_hot(x, num_classes)
-    x_onehot = x_onehot.permute(0, 2, 1)
-    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
+    batch_size, vector_length = x.shape
+    log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device)
+    log_x.scatter_(index=x[:, None, :], value=0.0, dim=1)
    return log_x
```
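
The rewritten `index_to_log_onehot` no longer materializes an intermediate one-hot tensor: it allocates the output once with a fill value and uses in-place `scatter_` to write `0.0` (i.e. `log(1)`) at each labeled class index along the class dimension. A toy illustration of the scatter step, checking only that the zeros land exactly where the one-hot encoding has ones (it makes no claim about the off-positions matching the old clamped-log values):

```python
import torch
import torch.nn.functional as F

num_classes = 5
x = torch.tensor([[0, 3, 3, 1], [4, 2, 0, 0]])  # (batch_size, vector_length)
batch_size, vector_length = x.shape

log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float)
log_x.scatter_(index=x[:, None, :], value=0.0, dim=1)

# The zeroed entries coincide exactly with the ones of the one-hot encoding.
one_hot = F.one_hot(x, num_classes).permute(0, 2, 1).bool()
assert torch.equal(log_x == 0.0, one_hot)
```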

```diff
@@ -1831,7 +1831,7 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest):
         def model(sample, t, *args):
             batch_size, num_latent_pixels = sample.shape
             logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
-            return_value = F.log_softmax(logits.double(), dim=1).float()
+            return_value = F.log_softmax(logits, dim=1, dtype=torch.double).float()
             return return_value
 
         return model
```