Mirror of https://github.com/huggingface/diffusers.git (synced 2025-12-07 04:54:47 +08:00)

Compare commits: sd3-t5...thomas/sma (16 commits)
| Author | SHA1 | Date |
|---|---|---|
| | c11d4b42a7 | |
| | d5af4fd153 | |
| | 1f135ac219 | |
| | 0c49f4cf30 | |
| | 5d4145cfa2 | |
| | 31d26872c1 | |
| | 1cd09cccf3 | |
| | fa4d738cbb | |
| | 136f84283c | |
| | 42ba85998f | |
| | e1623e2081 | |
| | fdef40ba03 | |
| | fe691feb5a | |
| | f2ed5d8b44 | |
| | e43244f33a | |
| | 3c45926a0e | |
@@ -213,7 +213,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
             logits = logits.permute(0, 2, 1)
 
             # log(p(x_0))
-            output = F.log_softmax(logits.double(), dim=1).float()
+            output = F.log_softmax(logits, dim=1, dtype=torch.double).float()
 
         if not return_dict:
             return (output,)
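The hunk above illustrates a pattern repeated throughout this compare: instead of upcasting the whole tensor with `.double()` and then running the op, the target precision is passed to `F.log_softmax` via its `dtype` argument, so the upcast happens inside the op and the result is unchanged. A minimal sketch of the equivalence (the tensor shape here is made up for illustration):

```python
import torch
import torch.nn.functional as F

# Hypothetical (batch, classes, pixels) logits, for illustration only.
logits = torch.randn(2, 8, 16)

# Pattern being replaced: cast the whole tensor, then take the log-softmax.
out_cast = F.log_softmax(logits.double(), dim=1).float()

# Replacement pattern: let log_softmax do the upcast via its dtype argument.
out_dtype = F.log_softmax(logits, dim=1, dtype=torch.double).float()

assert torch.allclose(out_cast, out_dtype)
```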
@@ -288,55 +288,60 @@ class AttentionBlock(nn.Module):
 
         # get scores
         if self.num_heads > 1:
-            query_states = self.transpose_for_scores(query_proj)
-            key_states = self.transpose_for_scores(key_proj)
-            value_states = self.transpose_for_scores(value_proj)
-
-            # TODO: is there a way to perform batched matmul (e.g. baddbmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * scale
-        else:
-            query_states, key_states, value_states = query_proj, key_proj, value_proj
-
-            attention_scores = torch.baddbmm(
-                torch.empty(
-                    query_states.shape[0],
-                    query_states.shape[1],
-                    key_states.shape[1],
-                    dtype=query_states.dtype,
-                    device=query_states.device,
-                ),
-                query_states,
-                key_states.transpose(-1, -2),
-                beta=0,
-                alpha=scale,
-            )
+            query_states = (
+                self.transpose_for_scores(query_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
+            key_states = (
+                self.transpose_for_scores(key_proj)
+                .transpose(3, 2)
+                .contiguous()
+                .view(batch * self.num_heads, self.num_head_size, height * width)
+            )
+            value_states = (
+                self.transpose_for_scores(value_proj)
+                .contiguous()
+                .view(batch * self.num_heads, height * width, self.num_head_size)
+            )
+        else:
+            query_states, key_states, value_states = query_proj, key_proj.transpose(-1, -2), value_proj
+
+        attention_scores = torch.baddbmm(
+            torch.empty(
+                query_states.shape[0],
+                query_states.shape[1],
+                key_states.shape[2],
+                dtype=query_states.dtype,
+                device=query_states.device,
+            ),
+            query_states,
+            key_states,
+            beta=0,
+            alpha=scale,
+        )
 
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+        attention_probs = torch.softmax(attention_scores, dim=-1, dtype=torch.float).type(attention_scores.dtype)
 
         # compute attention output
+        hidden_states = torch.bmm(attention_probs, value_states)
         if self.num_heads > 1:
-            # TODO: is there a way to perform batched matmul (e.g. bmm) on 4D tensors?
-            # or reformulate this into a 3D problem?
-            # TODO: measure whether on MPS device it would be faster to do this matmul via einsum
-            # as some matmuls can be 1.94x slower than an equivalent einsum on MPS
-            # https://gist.github.com/Birch-san/cba16789ec27bb20996a4b4831b13ce0
-            hidden_states = torch.matmul(attention_probs, value_states)
-            hidden_states = hidden_states.permute(0, 2, 1, 3).contiguous()
+            hidden_states = (
+                hidden_states.view(batch, self.num_heads, height * width, self.num_head_size)
+                .permute(0, 2, 1, 3)
+                .contiguous()
+            )
             new_hidden_states_shape = hidden_states.size()[:-2] + (self.channels,)
             hidden_states = hidden_states.view(new_hidden_states_shape)
-        else:
-            hidden_states = torch.bmm(attention_probs, value_states)
 
         # compute next hidden_states
         hidden_states = self.proj_attn(hidden_states)
         hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
 
         # res connect and rescale
-        hidden_states = (hidden_states + residual) / self.rescale_output_factor
+        hidden_states = hidden_states + residual
+        if self.rescale_output_factor != 1.0:
+            hidden_states = hidden_states / self.rescale_output_factor
         return hidden_states
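In the rewritten `AttentionBlock` forward, both the multi-head and single-head paths now feed a single `torch.baddbmm` call: keys are reshaped or transposed ahead of time so no transpose is needed at the matmul, and with `beta=0` the uninitialized `torch.empty` argument is ignored, so the call computes a fused, pre-scaled `query @ key` product. A small sketch of that fusion, with made-up sizes:

```python
import torch

# Hypothetical sizes, for illustration only.
batch_heads, seq_len, head_dim = 4, 64, 32
scale = head_dim ** -0.5

query = torch.randn(batch_heads, seq_len, head_dim)
key_t = torch.randn(batch_heads, head_dim, seq_len)  # keys already laid out transposed

# beta=0 means the first (uninitialized) argument is ignored, so this is
# just alpha * (query @ key_t) computed in one fused kernel.
scores = torch.baddbmm(
    torch.empty(batch_heads, seq_len, seq_len, dtype=query.dtype, device=query.device),
    query,
    key_t,
    beta=0,
    alpha=scale,
)

# Equivalent to the unfused matmul-then-scale it replaces.
reference = torch.bmm(query, key_t) * scale
assert torch.allclose(scores, reference, atol=1e-6)
```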
@@ -492,14 +497,13 @@ class CrossAttention(nn.Module):
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
-        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+        tensor = tensor.view(batch_size, seq_len, head_size, dim // head_size)
         return tensor
 
     def reshape_batch_dim_to_heads(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.heads
-        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+        tensor = tensor.view(batch_size // head_size, head_size, seq_len, dim)
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
@@ -508,23 +512,33 @@ class CrossAttention(nn.Module):
         query = self.to_q(hidden_states)
         context = context if context is not None else hidden_states
+        context_sequence_length = context.shape[1]
         key = self.to_k(context)
         value = self.to_v(context)
 
         dim = query.shape[-1]
 
-        query = self.reshape_heads_to_batch_dim(query)
-        key = self.reshape_heads_to_batch_dim(key)
-        value = self.reshape_heads_to_batch_dim(value)
+        query = (
+            self.reshape_heads_to_batch_dim(query)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * self.heads, sequence_length, dim // self.heads)
+        )
+        value = (
+            self.reshape_heads_to_batch_dim(value)
+            .permute(0, 2, 1, 3)
+            .reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
+        )
 
         # TODO(PVP) - mask is currently never used. Remember to re-implement when used
 
         # attention, what we cannot get enough of
         if self._use_memory_efficient_attention_xformers:
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 1, 3).reshape(batch_size * self.heads, context_sequence_length, dim // self.heads)
             hidden_states = self._memory_efficient_attention_xformers(query, key, value)
             # Some versions of xformers return output in fp32, cast it back to the dtype of the input
             hidden_states = hidden_states.to(query.dtype)
         else:
+            key = self.reshape_heads_to_batch_dim(key).permute(0, 2, 3, 1).reshape(batch_size * self.heads, dim // self.heads, context_sequence_length)
             if self._slice_size is None or query.shape[0] // self._slice_size == 1:
                 hidden_states = self._attention(query, key, value)
             else:
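With `reshape_heads_to_batch_dim` now returning a 4D `(batch, seq, heads, head_dim)` view, the forward pass above flattens each projection at its call site, and the key gets a different target layout per branch: the xformers path keeps `(batch * heads, context_len, head_dim)`, while the baddbmm path pre-transposes to `(batch * heads, head_dim, context_len)` so `_attention` can pass `key` straight through without a transpose. A sketch of the two layouts, with hypothetical dimensions:

```python
import torch

# Hypothetical dimensions, for illustration only.
batch_size, heads, context_len, head_dim = 2, 8, 77, 40

# 4D layout as returned by the new reshape_heads_to_batch_dim: (batch, seq, heads, head_dim).
key = torch.randn(batch_size, context_len, heads, head_dim)

# xformers branch: (batch * heads, context_len, head_dim).
key_xformers = key.permute(0, 2, 1, 3).reshape(batch_size * heads, context_len, head_dim)

# baddbmm branch: keys pre-transposed to (batch * heads, head_dim, context_len).
key_baddbmm = key.permute(0, 2, 3, 1).reshape(batch_size * heads, head_dim, context_len)

assert key_xformers.shape == (batch_size * heads, context_len, head_dim)
assert key_baddbmm.shape == (batch_size * heads, head_dim, context_len)
```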
@@ -538,9 +552,9 @@ class CrossAttention(nn.Module):
     def _attention(self, query, key, value):
         attention_scores = torch.baddbmm(
-            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+            torch.empty(query.shape[0], query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
             query,
-            key.transpose(-1, -2),
+            key,
             beta=0,
             alpha=self.scale,
         )
@@ -563,9 +577,9 @@ class CrossAttention(nn.Module):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
             attn_slice = torch.baddbmm(
-                torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+                torch.empty(slice_size, query.shape[1], key.shape[2], dtype=query.dtype, device=query.device),
                 query[start_idx:end_idx],
-                key[start_idx:end_idx].transpose(-1, -2),
+                key[start_idx:end_idx],
                 beta=0,
                 alpha=self.scale,
             )
@@ -49,11 +49,14 @@ def get_timestep_embedding(
     emb = scale * emb
 
     # concat sine and cosine embeddings
-    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)
+    sin = torch.sin(emb)
+    cos = torch.cos(emb)
 
     # flip sine and cosine embeddings
     if flip_sin_to_cos:
-        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+        emb = torch.cat([cos, sin], dim=-1)
+    else:
+        emb = torch.cat([sin, cos], dim=-1)
 
     # zero pad
     if embedding_dim % 2 == 1:
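The timestep-embedding change computes `sin` and `cos` once and picks the concatenation order directly, rather than building `[sin, cos]` and then swapping the halves by slicing when `flip_sin_to_cos` is set. A tiny sketch of the equivalence (shapes are illustrative):

```python
import torch

half_dim = 4                    # illustrative
emb = torch.randn(3, half_dim)  # illustrative (batch, half_dim) frequency grid
flip_sin_to_cos = True

sin, cos = torch.sin(emb), torch.cos(emb)

# Previous pattern: concatenate [sin, cos], then swap the halves by slicing.
old = torch.cat([sin, cos], dim=-1)
if flip_sin_to_cos:
    old = torch.cat([old[:, half_dim:], old[:, :half_dim]], dim=-1)

# New pattern: choose the order up front.
new = torch.cat([cos, sin], dim=-1) if flip_sin_to_cos else torch.cat([sin, cos], dim=-1)

assert torch.equal(old, new)
```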
@@ -126,7 +129,7 @@ class GaussianFourierProjection(nn.Module):
         if self.log:
             x = torch.log(x)
 
-        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi
+        x_proj = x[:, None] * self.weight[None, :] * (2 * np.pi)
 
         if self.flip_sin_to_cos:
             out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
@@ -476,7 +476,9 @@ class ResnetBlock2D(nn.Module):
         if self.conv_shortcut is not None:
             input_tensor = self.conv_shortcut(input_tensor)
 
-        output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+        output_tensor = input_tensor + hidden_states
+        if self.output_scale_factor != 1.0:
+            output_tensor = output_tensor / self.output_scale_factor
 
         return output_tensor
@@ -1054,10 +1054,7 @@ class AttnUpBlock2D(nn.Module):
             self.upsamplers = None
 
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
-        for resnet, attn in zip(self.resnets, self.attentions):
-            # pop res hidden states
-            res_hidden_states = res_hidden_states_tuple[-1]
-            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+        for resnet, attn, res_hidden_states in zip(self.resnets, self.attentions, reversed(res_hidden_states_tuple)):
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             hidden_states = resnet(hidden_states, temb)
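The `AttnUpBlock2D` change folds the manual "pop the last skip connection" bookkeeping into the loop header: zipping against `reversed(res_hidden_states_tuple)` walks the residuals in the same last-to-first order without rebuilding the tuple on every iteration. A quick sketch of the equivalence with placeholder values:

```python
# Placeholder stand-ins for the resnets, attentions, and skip activations.
resnets = ["r0", "r1", "r2"]
attentions = ["a0", "a1", "a2"]
res_hidden_states_tuple = ("h0", "h1", "h2")

# Previous pattern: pop the last skip connection on every iteration.
popped = []
remaining = res_hidden_states_tuple
for _resnet, _attn in zip(resnets, attentions):
    popped.append(remaining[-1])
    remaining = remaining[:-1]

# New pattern: iterate the skips in reverse directly.
reversed_zip = [h for _r, _a, h in zip(resnets, attentions, reversed(res_hidden_states_tuple))]

assert popped == reversed_zip == ["h2", "h1", "h0"]
```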
@@ -17,7 +17,6 @@ from typing import Optional, Tuple, Union
 
 import numpy as np
 import torch
-import torch.nn.functional as F
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..utils import BaseOutput
@@ -53,9 +52,9 @@ def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
         `torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
             Log onehot vectors
     """
-    x_onehot = F.one_hot(x, num_classes)
-    x_onehot = x_onehot.permute(0, 2, 1)
-    log_x = torch.log(x_onehot.float().clamp(min=1e-30))
+    batch_size, vector_length = x.shape
+    log_x = torch.full((batch_size, num_classes, vector_length), fill_value=1e-30, dtype=torch.float, device=x.device)
+    log_x.scatter_(index=x[:, None, :], value=0.0, dim=1)
     return log_x
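`index_to_log_onehot` now allocates the output once with `torch.full` and writes the "hot" entries in place with `Tensor.scatter_`, instead of materializing `F.one_hot`, permuting it, and taking its log. A sketch of the scatter construction itself, using plain 0/1 values rather than the scheduler's fill constants:

```python
import torch
import torch.nn.functional as F

num_classes = 5
x = torch.tensor([[0, 3, 1], [2, 2, 4]])  # hypothetical (batch, vector_length) indices

# Allocate the (batch, num_classes, vector_length) tensor once, then write the
# "hot" value at each class index along dim=1, in place.
batch_size, vector_length = x.shape
onehot = torch.full((batch_size, num_classes, vector_length), fill_value=0.0)
onehot.scatter_(index=x[:, None, :], value=1.0, dim=1)

# Same result and layout as the one_hot + permute route being replaced.
reference = F.one_hot(x, num_classes).permute(0, 2, 1).float()
assert torch.equal(onehot, reference)
```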
@@ -1831,7 +1831,7 @@ class VQDiffusionSchedulerTest(SchedulerCommonTest):
         def model(sample, t, *args):
             batch_size, num_latent_pixels = sample.shape
             logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
-            return_value = F.log_softmax(logits.double(), dim=1).float()
+            return_value = F.log_softmax(logits, dim=1, dtype=torch.double).float()
             return return_value
 
         return model