Compare commits

...

55 Commits

Author SHA1 Message Date
Dhruv Nair
d80f4772ec Merge branch 'main' into mochi-quality 2024-12-17 16:55:11 +05:30
Dhruv Nair
50c5607e96 Update src/diffusers/models/transformers/transformer_mochi.py
Co-authored-by: Aryan <aryan@huggingface.co>
2024-12-17 16:16:22 +05:30
Dhruv Nair
b75db11b9d update 2024-12-16 13:59:15 +01:00
Dhruv Nair
952f6e91b0 Merge branch 'main' into mochi-quality 2024-12-16 13:58:27 +01:00
Dhruv Nair
cbbc54b050 update 2024-12-16 09:22:16 +01:00
Dhruv Nair
142169196c Merge branch 'mochi-quality' of https://github.com/huggingface/diffusers into mochi-quality 2024-12-16 09:12:03 +01:00
Dhruv Nair
2a6b82d047 update 2024-12-16 09:11:47 +01:00
Dhruv Nair
4c800e3193 Merge branch 'main' into mochi-quality 2024-12-16 08:22:03 +01:00
Aryan
ccabe5e1cc Merge branch 'main' into mochi-quality 2024-12-14 17:46:00 +05:30
Dhruv Nair
09fe7ecd26 Merge branch 'main' into mochi-quality 2024-12-08 18:13:29 +01:00
Dhruv Nair
cc7b91d27b update 2024-12-07 17:54:03 +01:00
Dhruv Nair
11ce6b8791 update 2024-12-07 08:48:14 +01:00
Dhruv Nair
3c70b54117 update 2024-12-07 04:47:59 +01:00
Dhruv Nair
bbc58926cc update 2024-12-07 04:45:42 +01:00
Dhruv Nair
c39886ac13 update 2024-11-30 11:06:08 +01:00
Dhruv Nair
7626a34362 update 2024-11-30 10:52:18 +01:00
Dhruv Nair
ae57913fbb update 2024-11-29 16:58:43 +01:00
Dhruv Nair
dc96890d7b update 2024-11-29 13:49:44 +01:00
Dhruv Nair
a29891567e Merge branch 'mochi-quality' of https://github.com/huggingface/diffusers into mochi-quality 2024-11-29 13:45:37 +01:00
Dhruv Nair
77f9d1905a update 2024-11-29 13:00:46 +01:00
Dhruv Nair
53dbc37ea6 update 2024-11-29 09:57:57 +01:00
Dhruv Nair
ba9c1850e8 update 2024-11-29 08:54:56 +01:00
Sayak Paul
b904325627 Merge branch 'main' into mochi-quality 2024-11-28 20:56:02 +05:30
Dhruv Nair
7854061ebd update 2024-11-27 07:59:46 +01:00
Dhruv Nair
2881f2f986 update 2024-11-27 07:58:19 +01:00
Dhruv Nair
7854bde901 update 2024-11-27 07:55:35 +01:00
Dhruv Nair
d759516b2d update 2024-11-27 07:51:05 +01:00
Dhruv Nair
9c5eb368c4 update 2024-11-27 07:30:18 +01:00
Dhruv Nair
6e2011aa7d update 2024-11-27 06:56:06 +01:00
Dhruv Nair
0e8f20db46 update 2024-11-27 06:27:38 +01:00
Dhruv Nair
c17cef75be update 2024-11-27 06:25:52 +01:00
Dhruv Nair
e6fe9f1a09 update 2024-11-27 06:10:24 +01:00
Dhruv Nair
0fdef41d66 update 2024-11-27 04:48:05 +01:00
Dhruv Nair
61001c8f8f update 2024-11-27 03:51:46 +01:00
Dhruv Nair
fb4e175356 update 2024-11-27 03:47:21 +01:00
Dhruv Nair
b7464e5828 update 2024-11-27 01:55:03 +01:00
Dhruv Nair
8a5d03b903 update 2024-11-26 18:58:12 +01:00
Dhruv Nair
f3fefaecad update 2024-11-26 12:30:16 +01:00
Dhruv Nair
59c9f5d9fa update 2024-11-26 10:44:52 +01:00
Dhruv Nair
883f5c8ef4 update 2024-11-26 10:30:34 +01:00
Dhruv Nair
0b09231c76 update 2024-11-26 10:02:54 +01:00
Dhruv Nair
900feadbc9 update 2024-11-26 09:12:17 +01:00
Dhruv Nair
2cfca5e0d2 update 2024-11-26 07:07:01 +01:00
Dhruv Nair
8b9d5b63ae update 2024-11-25 16:40:09 +01:00
Dhruv Nair
d99234feac update 2024-11-25 15:35:24 +01:00
Dhruv Nair
dded24364c update 2024-11-25 08:50:27 +01:00
Dhruv Nair
3ffa711db1 update 2024-11-25 08:34:33 +01:00
Dhruv Nair
66a5f59ca1 update 2024-11-25 08:25:54 +01:00
Dhruv Nair
1782d0241a update 2024-11-25 08:11:11 +01:00
Dhruv Nair
fcc59d01a9 update 2024-11-23 17:15:18 +01:00
Dhruv Nair
21b09979dc update 2024-11-22 13:21:32 +01:00
Dhruv Nair
79380ca719 update 2024-11-20 19:41:08 +01:00
Dhruv Nair
10275feacd update 2024-11-20 13:57:41 +01:00
Dhruv Nair
30dd9f6845 update 2024-11-18 17:50:51 +01:00
Dhruv Nair
27f81bd54f update 2024-11-18 17:30:24 +01:00
7 changed files with 337 additions and 159 deletions

View File

@@ -906,6 +906,177 @@ class SanaMultiscaleLinearAttention(nn.Module):
return self.processor(self, hidden_states)
class MochiAttention(nn.Module):
def __init__(
self,
query_dim: int,
added_kv_proj_dim: int,
processor: "MochiAttnProcessor2_0",
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias: bool = False,
added_proj_bias: bool = True,
out_dim: Optional[int] = None,
out_context_dim: Optional[int] = None,
out_bias: bool = True,
context_pre_only: bool = False,
eps: float = 1e-5,
):
super().__init__()
from .normalization import MochiRMSNorm
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim else query_dim
self.context_pre_only = context_pre_only
self.heads = out_dim // dim_head if out_dim is not None else heads
self.norm_q = MochiRMSNorm(dim_head, eps, True)
self.norm_k = MochiRMSNorm(dim_head, eps, True)
self.norm_added_q = MochiRMSNorm(dim_head, eps, True)
self.norm_added_k = MochiRMSNorm(dim_head, eps, True)
self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.to_v = nn.Linear(query_dim, self.inner_dim, bias=bias)
self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
if self.context_pre_only is not None:
self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(nn.Dropout(dropout))
if not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
self.processor = processor
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
**kwargs,
):
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
**kwargs,
)
class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: "MochiAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
total_length = sequence_length + encoder_sequence_length
batch_size, heads, _, dim = query.shape
attn_outputs = []
for idx in range(batch_size):
mask = attention_mask[idx][None, :]
valid_prompt_token_indices = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
valid_encoder_query = encoder_query[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_encoder_key = encoder_key[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_encoder_value = encoder_value[idx : idx + 1, :, valid_prompt_token_indices, :]
valid_query = torch.cat([query[idx : idx + 1], valid_encoder_query], dim=2)
valid_key = torch.cat([key[idx : idx + 1], valid_encoder_key], dim=2)
valid_value = torch.cat([value[idx : idx + 1], valid_encoder_value], dim=2)
attn_output = F.scaled_dot_product_attention(
valid_query, valid_key, valid_value, dropout_p=0.0, is_causal=False
)
valid_sequence_length = attn_output.size(2)
attn_output = F.pad(attn_output, (0, 0, 0, total_length - valid_sequence_length))
attn_outputs.append(attn_output)
hidden_states = torch.cat(attn_outputs, dim=0)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class AttnProcessor:
r"""
Default processor for performing attention-related computations.
@@ -3868,94 +4039,6 @@ class LuminaAttnProcessor2_0:
return hidden_states
class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if getattr(attn, "to_add_out", None) is not None:
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
@@ -5668,13 +5751,13 @@ AttentionProcessor = Union[
AttnProcessorNPU,
AttnProcessor2_0,
MochiVaeAttnProcessor2_0,
MochiAttnProcessor2_0,
StableAudioAttnProcessor2_0,
HunyuanAttnProcessor2_0,
FusedHunyuanAttnProcessor2_0,
PAGHunyuanAttnProcessor2_0,
PAGCFGHunyuanAttnProcessor2_0,
LuminaAttnProcessor2_0,
MochiAttnProcessor2_0,
FusedAttnProcessor2_0,
CustomDiffusionXFormersAttnProcessor,
CustomDiffusionAttnProcessor2_0,

View File

@@ -542,7 +542,6 @@ class PatchEmbed(nn.Module):
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC

View File

@@ -234,33 +234,6 @@ class LuminaRMSNormZero(nn.Module):
return x, gate_msa, scale_mlp, gate_mlp
class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
return hidden_states, gate_msa, scale_mlp, gate_mlp
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -549,6 +522,36 @@ class RMSNorm(nn.Module):
return hidden_states
# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
class MochiRMSNorm(nn.Module):
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
super().__init__()
self.eps = eps
if isinstance(dim, numbers.Integral):
dim = (dim,)
self.dim = torch.Size(dim)
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(dim))
else:
self.weight = None
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
if self.weight is not None:
hidden_states = hidden_states * self.weight
hidden_states = hidden_states.to(input_dtype)
return hidden_states
class GlobalResponseNorm(nn.Module):
# Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
def __init__(self, dim):

View File

@@ -23,16 +23,96 @@ from ...loaders import PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, MochiAttnProcessor2_0
from ..attention_processor import MochiAttention, MochiAttnProcessor2_0
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
from ..normalization import AdaLayerNormContinuous, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class MochiModulatedRMSNorm(nn.Module):
def __init__(self, eps: float):
super().__init__()
self.eps = eps
self.norm = RMSNorm(0, eps, False)
def forward(self, hidden_states, scale=None):
hidden_states_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
hidden_states = self.norm(hidden_states)
if scale is not None:
hidden_states = hidden_states * scale
hidden_states = hidden_states.to(hidden_states_dtype)
return hidden_states
class MochiLayerNormContinuous(nn.Module):
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
eps=1e-5,
bias=True,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
self.norm = MochiModulatedRMSNorm(eps=eps)
def forward(
self,
x: torch.Tensor,
conditioning_embedding: torch.Tensor,
) -> torch.Tensor:
input_dtype = x.dtype
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
scale = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
x = self.norm(x, (1 + scale.unsqueeze(1).to(torch.float32)))
return x.to(input_dtype)
class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(0, eps, False)
def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
hidden_states_dtype = hidden_states.dtype
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states.to(torch.float32)) * (1 + scale_msa[:, None].to(torch.float32))
hidden_states = hidden_states.to(hidden_states_dtype)
return hidden_states, gate_msa, scale_mlp, gate_mlp
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
r"""
@@ -77,38 +157,32 @@ class MochiTransformerBlock(nn.Module):
if not context_pre_only:
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
else:
self.norm1_context = LuminaLayerNormContinuous(
self.norm1_context = MochiLayerNormContinuous(
embedding_dim=pooled_projection_dim,
conditioning_embedding_dim=dim,
eps=eps,
elementwise_affine=False,
norm_type="rms_norm",
out_dim=None,
)
self.attn1 = Attention(
self.attn1 = MochiAttention(
query_dim=dim,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=False,
qk_norm=qk_norm,
added_kv_proj_dim=pooled_projection_dim,
added_proj_bias=False,
out_dim=dim,
out_context_dim=pooled_projection_dim,
context_pre_only=context_pre_only,
processor=MochiAttnProcessor2_0(),
eps=eps,
elementwise_affine=True,
eps=1e-5,
)
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm2 = MochiModulatedRMSNorm(eps=eps)
self.norm2_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm3 = MochiModulatedRMSNorm(eps)
self.norm3_context = MochiModulatedRMSNorm(eps=eps) if not self.context_pre_only else None
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
self.ff_context = None
@@ -120,14 +194,15 @@ class MochiTransformerBlock(nn.Module):
bias=False,
)
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm4 = MochiModulatedRMSNorm(eps=eps)
self.norm4_context = MochiModulatedRMSNorm(eps=eps)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
encoder_attention_mask: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
@@ -143,22 +218,25 @@ class MochiTransformerBlock(nn.Module):
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=encoder_attention_mask,
)
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).to(torch.float32)))
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
if not self.context_pre_only:
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
context_attn_hidden_states
) * torch.tanh(enc_gate_msa).unsqueeze(1)
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
)
norm_encoder_hidden_states = self.norm3_context(
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).to(torch.float32))
)
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
enc_gate_mlp
).unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
)
return hidden_states, encoder_hidden_states
@@ -203,7 +281,10 @@ class MochiRoPE(nn.Module):
return positions
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
with torch.autocast(freqs.device.type, torch.float32):
# Always run ROPE freqs computation in FP32
freqs = torch.einsum("nd,dhf->nhf", pos.to(torch.float32), freqs.to(torch.float32))
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
@@ -309,7 +390,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
)
self.norm_out = AdaLayerNormContinuous(
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
inner_dim,
inner_dim,
elementwise_affine=False,
eps=1e-6,
norm_type="layer_norm",
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
@@ -350,7 +435,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
post_patch_width = width // p
temb, encoder_hidden_states = self.time_embed(
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
timestep,
encoder_hidden_states,
encoder_attention_mask,
hidden_dtype=hidden_states.dtype,
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
@@ -381,6 +469,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
hidden_states,
encoder_hidden_states,
temb,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
@@ -389,9 +478,9 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
encoder_attention_mask=encoder_attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

View File

@@ -198,7 +198,6 @@ class LTXPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLoraLoaderMixi
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 128
)
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,

View File

@@ -221,7 +221,6 @@ class LTXImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTXVideoLo
self.default_width = 704
self.default_frames = 121
# Copied from diffusers.pipelines.mochi.pipeline_mochi.MochiPipeline._get_t5_prompt_embeds with 256->128
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,

View File

@@ -210,7 +210,6 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
self.default_height = 480
self.default_width = 848
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
@@ -233,9 +232,13 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
if prompt == "" or prompt[-1] == "":
text_input_ids = torch.zeros_like(text_input_ids, device=device)
prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
@@ -246,7 +249,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -451,7 +454,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = randn_tensor(shape, generator=generator, device=device, dtype=torch.float32)
latents = latents.to(dtype)
return latents
@property
@@ -483,7 +487,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 19,
num_inference_steps: int = 28,
num_inference_steps: int = 64,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
num_videos_per_prompt: Optional[int] = 1,
@@ -605,7 +609,6 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
@@ -624,10 +627,6 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
@@ -642,6 +641,10 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
latents,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
threshold_noise = 0.025
@@ -676,6 +679,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
# Mochi CFG + Sampling runs in FP32
noise_pred = noise_pred.to(torch.float32)
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -683,7 +688,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
latents = self.scheduler.step(noise_pred, t, latents.to(torch.float32), return_dict=False)[0]
latents = latents.to(latents_dtype)
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():