Compare commits

...

18 Commits

Author SHA1 Message Date
Aryan
54b6b3f8be update 2024-10-25 08:59:04 +02:00
Aryan
e92ee28aa4 update 2024-10-25 08:58:46 +02:00
Aryan
40cf52fc0a add transformer docs 2024-10-24 22:00:49 +02:00
Aryan
76b697f2c4 add model test 2024-10-24 21:56:51 +02:00
Aryan
5797093a47 Merge branch 'main' into mochi-t2v 2024-10-24 21:48:51 +02:00
Aryan
0a54230729 fix remainnig bugs; transformer outputs match 2024-10-24 21:09:40 +02:00
Aryan
46f95d5cdb make style 2024-10-24 13:49:12 +02:00
Aryan
2fd2ec4025 fixes 2024-10-24 13:48:22 +02:00
Aryan
85c8734cdc fix 2024-10-24 05:47:03 +02:00
Aryan
98a4554ac6 update 2024-10-24 05:04:43 +02:00
Aryan
1e9bc91b5c fix 2024-10-24 04:02:37 +02:00
Aryan
be5bbe53e1 update 2024-10-24 03:48:31 +02:00
Aryan
c2a155714b add conversion script 2024-10-24 03:48:10 +02:00
Aryan
0e9e281ad1 fix 2024-10-24 02:40:31 +02:00
Aryan
05ebd6cd82 make style 2024-10-24 01:27:51 +02:00
Aryan
da48940b56 update transformer 2024-10-24 01:27:25 +02:00
Aryan
64275b0e66 udpate 2024-10-24 00:17:16 +02:00
Aryan
e488d09df1 update 2024-10-23 10:26:08 +02:00
15 changed files with 1969 additions and 9 deletions

View File

@@ -266,6 +266,8 @@
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer

View File

@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# MochiTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import MochiTransformer3DModel
transformer = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
## MochiTransformer3DModel
[[autodoc]] MochiTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput

View File

@@ -0,0 +1,187 @@
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
# from transformers import T5EncoderModel, T5Tokenizer
from diffusers import MochiTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available
# Context-manager factory selected by accelerate availability. `is_accelerate_available` has to be *called*:
# the bare function object is always truthy, which made the `nullcontext` fallback unreachable.
CTX = init_empty_weights if is_accelerate_available() else nullcontext

TOKENIZER_MAX_LENGTH = 256

# Command-line interface of the conversion script. Arguments are parsed at module level and handed to `main`.
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
# parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default=None)

args = parser.parse_args()
# This is specific to `AdaLayerNormContinuous`:
# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale
def swap_scale_shift(weight, dim):
    """
    Swap the two halves of `weight` along `dim`.

    The tensor is split into two equal chunks `(shift, scale)` along `dim` and re-assembled as `(scale, shift)`.

    Args:
        weight: Tensor holding the concatenated shift and scale parameters.
        dim: Dimension along which the two halves are stored. It is now honoured for both the split and the
            concatenation (it used to be ignored in favour of a hard-coded 0).

    Returns:
        A new tensor of the same shape with the two halves exchanged.
    """
    shift, scale = weight.chunk(2, dim=dim)
    return torch.cat([scale, shift], dim=dim)
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path, num_layers=48):
    """
    Convert an original Mochi transformer checkpoint into the key layout used by `MochiTransformer3DModel`.

    Args:
        ckpt_path: Path to the original checkpoint; it is read with `load_file` onto the CPU.
        num_layers: Number of transformer blocks to convert. Defaults to 48, the value that used to be
            hard-coded in this function.

    Returns:
        A new dict mapping diffusers parameter names to the tensors taken from the original checkpoint.

    The last block (`num_layers - 1`) is treated differently from the others: its `mod_y` parameters are
    stored under `norm1_context.linear_1` instead of `norm1_context.linear`, and its `attn.proj_y` / `mlp_y`
    entries are not converted. Whatever is left in the original dict after conversion is printed so that
    unconverted keys are visible.
    """
    original_state_dict = load_file(ckpt_path, device="cpu")
    new_state_dict = {}

    def move(new_key, old_key):
        # Transfer one tensor and drop it from the source dict so leftovers can be reported at the end.
        new_state_dict[new_key] = original_state_dict.pop(old_key)

    # Convert patch_embed
    move("patch_embed.proj.weight", "x_embedder.proj.weight")
    move("patch_embed.proj.bias", "x_embedder.proj.bias")

    # Convert time_embed
    move("time_embed.timestep_embedder.linear_1.weight", "t_embedder.mlp.0.weight")
    move("time_embed.timestep_embedder.linear_1.bias", "t_embedder.mlp.0.bias")
    move("time_embed.timestep_embedder.linear_2.weight", "t_embedder.mlp.2.weight")
    move("time_embed.timestep_embedder.linear_2.bias", "t_embedder.mlp.2.bias")
    move("time_embed.pooler.to_kv.weight", "t5_y_embedder.to_kv.weight")
    move("time_embed.pooler.to_kv.bias", "t5_y_embedder.to_kv.bias")
    move("time_embed.pooler.to_q.weight", "t5_y_embedder.to_q.weight")
    move("time_embed.pooler.to_q.bias", "t5_y_embedder.to_q.bias")
    move("time_embed.pooler.to_out.weight", "t5_y_embedder.to_out.weight")
    move("time_embed.pooler.to_out.bias", "t5_y_embedder.to_out.bias")
    move("time_embed.caption_proj.weight", "t5_yproj.weight")
    move("time_embed.caption_proj.bias", "t5_yproj.bias")

    # Convert transformer blocks
    for i in range(num_layers):
        block_prefix = f"transformer_blocks.{i}."
        old_prefix = f"blocks.{i}."
        is_last_block = i == num_layers - 1

        # norm1
        move(block_prefix + "norm1.linear.weight", old_prefix + "mod_x.weight")
        move(block_prefix + "norm1.linear.bias", old_prefix + "mod_x.bias")
        # The last block keeps its context modulation under `linear_1` rather than `linear`.
        context_linear = "linear_1" if is_last_block else "linear"
        move(block_prefix + f"norm1_context.{context_linear}.weight", old_prefix + "mod_y.weight")
        move(block_prefix + f"norm1_context.{context_linear}.bias", old_prefix + "mod_y.bias")

        # Visual attention: the fused qkv weight is split into three equal parts along the output dimension.
        q, k, v = original_state_dict.pop(old_prefix + "attn.qkv_x.weight").chunk(3, dim=0)
        new_state_dict[block_prefix + "attn1.to_q.weight"] = q
        new_state_dict[block_prefix + "attn1.to_k.weight"] = k
        new_state_dict[block_prefix + "attn1.to_v.weight"] = v
        move(block_prefix + "attn1.norm_q.weight", old_prefix + "attn.q_norm_x.weight")
        move(block_prefix + "attn1.norm_k.weight", old_prefix + "attn.k_norm_x.weight")
        move(block_prefix + "attn1.to_out.0.weight", old_prefix + "attn.proj_x.weight")
        move(block_prefix + "attn1.to_out.0.bias", old_prefix + "attn.proj_x.bias")

        # Context attention
        q, k, v = original_state_dict.pop(old_prefix + "attn.qkv_y.weight").chunk(3, dim=0)
        new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
        new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
        new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
        move(block_prefix + "attn1.norm_added_q.weight", old_prefix + "attn.q_norm_y.weight")
        move(block_prefix + "attn1.norm_added_k.weight", old_prefix + "attn.k_norm_y.weight")
        if not is_last_block:
            move(block_prefix + "attn1.to_add_out.weight", old_prefix + "attn.proj_y.weight")
            move(block_prefix + "attn1.to_add_out.bias", old_prefix + "attn.proj_y.bias")

        # MLP
        move(block_prefix + "ff.net.0.proj.weight", old_prefix + "mlp_x.w1.weight")
        move(block_prefix + "ff.net.2.weight", old_prefix + "mlp_x.w2.weight")
        if not is_last_block:
            move(block_prefix + "ff_context.net.0.proj.weight", old_prefix + "mlp_y.w1.weight")
            move(block_prefix + "ff_context.net.2.weight", old_prefix + "mlp_y.w2.weight")

    # Output layers
    new_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.weight"), dim=0)
    new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
    move("proj_out.weight", "final_layer.linear.weight")
    move("proj_out.bias", "final_layer.linear.bias")

    move("pos_frequencies", "pos_frequencies")

    print("Remaining Keys:", original_state_dict.keys())

    return new_state_dict
# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
# original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
# return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
def main(args):
    """
    Convert the checkpoint named on the command line and save the resulting model.

    Args:
        args: Parsed command-line namespace. This function reads `dtype`, `transformer_checkpoint_path` and
            `output_path`.

    Raises:
        ValueError: If `args.dtype` is given and is not one of "fp16", "bf16" or "fp32".
    """
    # A single if/elif chain: an omitted `--dtype` (None) must not fall through to the "unsupported" branch.
    if args.dtype is None:
        dtype = None
    elif args.dtype == "fp16":
        dtype = torch.float16
    elif args.dtype == "bf16":
        dtype = torch.bfloat16
    elif args.dtype == "fp32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {args.dtype}")

    transformer = None
    # vae = None

    if args.transformer_checkpoint_path is not None:
        converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
            args.transformer_checkpoint_path
        )
        transformer = MochiTransformer3DModel()
        transformer.load_state_dict(converted_transformer_state_dict, strict=True)
        # Cast only on request; without `--dtype` the original checkpoint data type is preserved.
        if dtype is not None:
            transformer = transformer.to(dtype=dtype)

    # text_encoder_id = "google/t5-v1_1-xxl"
    # tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
    # text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

    # # Apparently, the conversion does not work anymore without this :shrug:
    # for param in text_encoder.parameters():
    #     param.data = param.data.contiguous()

    # Nothing to save when no transformer checkpoint was supplied.
    if transformer is not None:
        # Write to the user-supplied location instead of a developer-specific hard-coded directory.
        # NOTE(review): `--push_to_hub` is parsed but not acted on here -- confirm the intended behaviour.
        transformer.save_pretrained(args.output_path, subfolder="transformer")


if __name__ == "__main__":
    main(args)

View File

@@ -100,6 +100,7 @@ else:
"Kandinsky3UNet",
"LatteTransformer3DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -579,6 +580,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3UNet,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,

View File

@@ -56,6 +56,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -106,6 +107,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DModel,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,

View File

@@ -134,14 +134,18 @@ class SwiGLU(nn.Module):
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True, flip_gate: bool = False):
    """
    Build the projection and activation used by the gated unit.

    Args:
        dim_in: Number of input features of the projection.
        dim_out: Number of features in each of the two halves produced by the projection.
        bias: Whether the projection uses a bias.
        flip_gate: Stored on the module; `forward` uses it to decide which half acts as the gate.
    """
    super().__init__()
    self.flip_gate = flip_gate
    # One projection yields both halves consumed by `forward`, hence twice `dim_out` output features.
    self.proj = nn.Linear(in_features=dim_in, out_features=2 * dim_out, bias=bias)
    self.activation = nn.SiLU()
def forward(self, hidden_states):
    """
    Project the input, split the result in two along the last dimension and gate one half with the other.

    By default the first half is the value and the second half is passed through the activation as the gate;
    when `self.flip_gate` is set the two roles are exchanged.
    """
    first_half, second_half = self.proj(hidden_states).chunk(2, dim=-1)
    if self.flip_gate:
        value, gate = second_half, first_half
    else:
        value, gate = first_half, second_half
    return value * self.activation(gate)

View File

@@ -1206,6 +1206,7 @@ class FeedForward(nn.Module):
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
flip_gate: bool = False,
):
super().__init__()
if inner_dim is None:
@@ -1221,7 +1222,7 @@ class FeedForward(nn.Module):
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
act_fn = SwiGLU(dim, inner_dim, bias=bias, flip_gate=flip_gate)
self.net = nn.ModuleList([])
# project in

View File

@@ -120,6 +120,7 @@ class Attention(nn.Module):
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
@@ -142,6 +143,7 @@ class Attention(nn.Module):
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
@@ -241,7 +243,7 @@ class Attention(nn.Module):
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
@@ -1792,6 +1794,7 @@ class FluxAttnProcessor2_0:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
@@ -3078,6 +3081,94 @@ class LuminaAttnProcessor2_0:
return hidden_states
class MochiAttnProcessor2_0:
    """
    Attention processor used in Mochi.

    Queries, keys and values are computed separately for `hidden_states` and `encoder_hidden_states`, the two
    sets are concatenated along the sequence dimension, attended over jointly with
    `F.scaled_dot_product_attention`, and split back into the two streams.
    """

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run joint attention over both streams.

        Args:
            attn: Module providing the projections, optional q/k norms and output layers.
            hidden_states: First stream; the rotary embedding is applied to its queries and keys only.
            encoder_hidden_states: Second stream, projected with the `add_*_proj` layers.
            attention_mask: Accepted for interface compatibility; it is not used by this processor.
            image_rotary_emb: Optional pair unpacked as `(freqs_cos, freqs_sin)`.

        Returns:
            The pair `(hidden_states, encoder_hidden_states)` after attention. Note that this is a tuple even
            though the signature is annotated with a single tensor.
        """
        num_heads = attn.heads

        def split_heads(projected):
            # (batch, seq, heads * head_dim) -> (batch, seq, heads, head_dim)
            return projected.unflatten(2, (num_heads, -1))

        query = split_heads(attn.to_q(hidden_states))
        key = split_heads(attn.to_k(hidden_states))
        value = split_heads(attn.to_v(hidden_states))

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        encoder_query = split_heads(attn.add_q_proj(encoder_hidden_states))
        encoder_key = split_heads(attn.add_k_proj(encoder_hidden_states))
        encoder_value = split_heads(attn.add_v_proj(encoder_hidden_states))

        if attn.norm_added_q is not None:
            encoder_query = attn.norm_added_q(encoder_query)
        if attn.norm_added_k is not None:
            encoder_key = attn.norm_added_k(encoder_key)

        if image_rotary_emb is not None:

            def apply_rotary_emb(x, freqs_cos, freqs_sin):
                # Rotate each (even, odd) feature pair; the arithmetic is done in float32 and cast back.
                evens = x[..., 0::2].float()
                odds = x[..., 1::2].float()
                rotated_even = (evens * freqs_cos - odds * freqs_sin).to(x.dtype)
                rotated_odd = (evens * freqs_sin + odds * freqs_cos).to(x.dtype)
                # Re-interleave the rotated pairs into the original feature layout.
                return torch.stack([rotated_even, rotated_odd], dim=-1).flatten(-2)

            query = apply_rotary_emb(query, *image_rotary_emb)
            key = apply_rotary_emb(key, *image_rotary_emb)

        # Move the head dimension in front of the sequence dimension for attention.
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)
        encoder_query = encoder_query.transpose(1, 2)
        encoder_key = encoder_key.transpose(1, 2)
        encoder_value = encoder_value.transpose(1, 2)

        # Remember the length of each stream so the joint output can be split again afterwards.
        sequence_length = query.size(2)
        encoder_sequence_length = encoder_query.size(2)

        query = torch.cat([query, encoder_query], dim=2)
        key = torch.cat([key, encoder_key], dim=2)
        value = torch.cat([value, encoder_value], dim=2)

        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.to(query.dtype)

        hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
            (sequence_length, encoder_sequence_length), dim=1
        )

        # linear proj followed by dropout
        hidden_states = attn.to_out[1](attn.to_out[0](hidden_states))

        # The second stream is only projected when the attention module defines an output layer for it.
        if hasattr(attn, "to_add_out"):
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        return hidden_states, encoder_hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses

View File

@@ -1302,6 +1302,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
return conditioning
class MochiCombinedTimestepCaptionEmbedding(nn.Module):
    r"""
    Combine a timestep embedding with pooled caption features and project the caption tokens.

    Parameters:
        embedding_dim (`int`): Output size of the timestep embedder and of the caption pooler.
        pooled_projection_dim (`int`): Output size of the per-token caption projection.
        text_embed_dim (`int`): Feature size of the incoming caption tokens.
        time_embed_dim (`int`, defaults to 256): Number of channels of the timestep projection.
        num_attention_heads (`int`, defaults to 8): Number of heads used by the caption pooler.
    """

    def __init__(
        self,
        embedding_dim: int,
        pooled_projection_dim: int,
        text_embed_dim: int,
        time_embed_dim: int = 256,
        num_attention_heads: int = 8,
    ) -> None:
        super().__init__()

        self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
        self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
        self.pooler = MochiAttentionPool(
            num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
        )
        self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)

    def forward(
        self,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        hidden_dtype: Optional[torch.dtype] = None,
    ):
        """
        Returns:
            A pair `(conditioning, caption_proj)`: the sum of the timestep embedding and the pooled caption
            features, and the projected caption tokens.
        """
        # The projected timestep is cast to `hidden_dtype` before it enters the embedder.
        timestep_features = self.time_proj(timestep).to(dtype=hidden_dtype)
        timestep_embedding = self.timestep_embedder(timestep_features)

        pooled_caption = self.pooler(encoder_hidden_states, encoder_attention_mask)
        projected_caption = self.caption_proj(encoder_hidden_states)

        return timestep_embedding + pooled_caption, projected_caption
class TextTimeEmbedding(nn.Module):
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
super().__init__()
@@ -1430,6 +1465,88 @@ class AttentionPooling(nn.Module):
return a[:, 0, :] # cls_token
class MochiAttentionPool(nn.Module):
    r"""
    Pool a sequence of tokens into a single vector with one attention query.

    The query is computed from the masked mean of the tokens; that mean token is also prepended to the
    sequence so it takes part as a key/value.

    Parameters:
        num_attention_heads (`int`): Number of attention heads.
        embed_dim (`int`): Feature size of the input tokens.
        output_dim (`int`, *optional*): Feature size of the pooled output; falls back to `embed_dim`.
    """

    def __init__(
        self,
        num_attention_heads: int,
        embed_dim: int,
        output_dim: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.output_dim = output_dim or embed_dim
        self.num_attention_heads = num_attention_heads

        self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
        self.to_q = nn.Linear(embed_dim, embed_dim)
        self.to_out = nn.Linear(embed_dim, self.output_dim)

    @staticmethod
    def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
        """
        Average the tokens of `x` selected by `mask`.

        NOTE: We assume x does not require gradients.

        Args:
            x: (B, L, D) tensor of tokens.
            mask: (B, L) boolean tensor indicating which tokens are not padding.
            keepdim: Whether to keep the reduced sequence dimension.

        Returns:
            pooled: (B, D) tensor of pooled tokens, or (B, 1, D) when `keepdim` is set.
        """
        assert x.size(1) == mask.size(1)  # Expected mask to have same length as tokens.
        assert x.size(0) == mask.size(0)  # Expected mask to have same batch size as tokens.

        weights = mask[:, :, None].to(dtype=x.dtype)
        # Normalise so the selected tokens sum to one; the clamp avoids dividing by zero for empty masks.
        weights = weights / weights.sum(dim=1, keepdim=True).clamp(min=1)
        return (x * weights).sum(dim=1, keepdim=keepdim)

    def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
        r"""
        Args:
            x (`torch.Tensor`):
                Tensor of shape `(B, S, D)` of input tokens.
            mask (`torch.Tensor`):
                Boolean tensor of shape `(B, S)` indicating which tokens are not padding.

        Returns:
            `torch.Tensor`:
                `(B, D)` tensor of pooled tokens.
        """
        feature_dim = x.size(2)
        num_heads = self.num_attention_heads
        head_dim = feature_dim // num_heads

        # Attention mask over the keys, with one always-visible slot prepended for the mean token.
        attn_mask = mask[:, None, None, :].bool()  # (B, 1, 1, L)
        attn_mask = F.pad(attn_mask, (1, 0), value=True)  # (B, 1, 1, 1+L)

        # The masked mean of the tokens is put in front of the sequence; it also provides the only query.
        mean_token = self.pool_tokens(x, mask, keepdim=True)  # (B, 1, D)
        tokens = torch.cat([mean_token, x], dim=1)  # (B, 1+L, D)

        keys_values = self.to_kv(tokens)  # (B, 1+L, 2 * D)
        query = self.to_q(tokens[:, 0])  # (B, D)

        # Split keys/values into heads.
        keys_values = keys_values.unflatten(2, (2, num_heads, head_dim))  # (B, 1+L, 2, H, head_dim)
        keys_values = keys_values.transpose(1, 3)  # (B, H, 2, 1+L, head_dim)
        key, value = keys_values.unbind(2)  # (B, H, 1+L, head_dim) each

        query = query.unflatten(1, (num_heads, head_dim)).unsqueeze(2)  # (B, H, 1, head_dim)

        pooled = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attn_mask, dropout_p=0.0
        )  # (B, H, 1, head_dim)

        # Merge the heads back together and apply the output projection.
        pooled = pooled.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
        return self.to_out(pooled)
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:

View File

@@ -237,6 +237,33 @@ class LuminaRMSNormZero(nn.Module):
return x, gate_msa, scale_mlp, gate_mlp
class MochiRMSNormZero(nn.Module):
    r"""
    Adaptive RMS Norm used in Mochi.

    Parameters:
        embedding_dim (`int`): The size of each embedding vector.
        hidden_dim (`int`): Output size of the modulation projection; it is split into four equal chunks.
        eps (`float`, defaults to 1e-5): Epsilon passed to the RMS norm.
        elementwise_affine (`bool`, defaults to `False`): Passed to the RMS norm.
    """

    def __init__(
        self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
    ) -> None:
        super().__init__()

        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, hidden_dim)
        self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)

    def forward(
        self, hidden_states: torch.Tensor, emb: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Normalise `hidden_states` and scale them with the first modulation chunk derived from `emb`.

        Returns:
            `(hidden_states, gate_msa, scale_mlp, gate_mlp)`; the last three are the remaining modulation
            chunks, returned unchanged for the caller to apply.
        """
        modulation = self.linear(self.silu(emb))
        scale_msa, gate_msa, scale_mlp, gate_mlp = modulation.chunk(4, dim=1)
        # The scale gets a singleton dimension after the batch so it broadcasts over the sequence.
        normed = self.norm(hidden_states)
        return normed * (1 + scale_msa[:, None]), gate_msa, scale_mlp, gate_mlp
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -358,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module):
out_dim: Optional[int] = None,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
if norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
# linear_2
self.linear_2 = None
if out_dim is not None:
self.linear_2 = nn.Linear(
embedding_dim,
out_dim,
bias=bias,
)
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
def forward(
self,

View File

@@ -16,5 +16,6 @@ if is_torch_available():
from .transformer_2d import Transformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel

View File

@@ -0,0 +1,323 @@
# Copyright 2024 The Genmo team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, MochiAttnProcessor2_0
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
    r"""
    Transformer block used in Mochi. It updates a main stream (`hidden_states`) and, unless
    `context_pre_only` is set, a context stream (`encoder_hidden_states`) through one shared attention layer
    followed by a separate feed-forward network per stream.

    Parameters:
        dim (`int`): Feature size of the main stream.
        num_attention_heads (`int`): Number of attention heads.
        attention_head_dim (`int`): Feature size of each attention head.
        pooled_projection_dim (`int`): Feature size of the context stream.
        qk_norm (`str`, defaults to `"rms_norm"`): Query/key normalisation passed to the attention layer.
        activation_fn (`str`, defaults to `"swiglu"`): Activation of the feed-forward networks.
        context_pre_only (`bool`, defaults to `True`):
            When set, the context stream only feeds the attention layer and is returned unchanged.
        eps (`float`, defaults to 1e-6): Epsilon used by the normalisation layers.
    """

    def __init__(
        self,
        dim: int,
        num_attention_heads: int,
        attention_head_dim: int,
        pooled_projection_dim: int,
        qk_norm: str = "rms_norm",
        activation_fn: str = "swiglu",
        context_pre_only: bool = True,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()

        self.context_pre_only = context_pre_only
        self.ff_inner_dim = (4 * dim * 2) // 3
        self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3

        self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)

        if context_pre_only:
            # The context stream is not updated by this block, so only a conditioned norm is needed for it.
            self.norm1_context = LuminaLayerNormContinuous(
                embedding_dim=pooled_projection_dim,
                conditioning_embedding_dim=dim,
                eps=eps,
                elementwise_affine=False,
                norm_type="rms_norm",
                out_dim=None,
            )
        else:
            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)

        self.attn1 = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            heads=num_attention_heads,
            dim_head=attention_head_dim,
            bias=False,
            qk_norm=qk_norm,
            added_kv_proj_dim=pooled_projection_dim,
            added_proj_bias=False,
            out_dim=dim,
            out_context_dim=pooled_projection_dim,
            context_pre_only=context_pre_only,
            processor=MochiAttnProcessor2_0(),
            eps=eps,
            elementwise_affine=True,
        )

        # Parameter-free RMS norms applied around the attention output and the feed-forward networks.
        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
        self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

        self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
        self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

        self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True)
        # The context feed-forward only exists when the context stream is updated by this block.
        self.ff_context = (
            None
            if context_pre_only
            else FeedForward(
                pooled_projection_dim,
                inner_dim=self.ff_context_inner_dim,
                activation_fn=activation_fn,
                bias=False,
                flip_gate=True,
            )
        )

        self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
        self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: Main stream.
            encoder_hidden_states: Context stream.
            temb: Conditioning embedding driving the adaptive norms.
            image_rotary_emb: Forwarded unchanged to the attention layer.

        Returns:
            The updated `(hidden_states, encoder_hidden_states)`; the second entry is the unmodified input
            when `context_pre_only` is set.
        """
        update_context = not self.context_pre_only

        norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
        if update_context:
            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
                encoder_hidden_states, temb
            )
        else:
            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)

        attn_hidden_states, context_attn_hidden_states = self.attn1(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
        )

        # Main stream: tanh-gated residual around attention, then around the feed-forward network.
        hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
        ff_input = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
        ff_output = self.ff(ff_input)
        hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)

        if not update_context:
            return hidden_states, encoder_hidden_states

        # Context stream: same structure with its own gates, norms and feed-forward network.
        encoder_hidden_states = encoder_hidden_states + self.norm2_context(
            context_attn_hidden_states
        ) * torch.tanh(enc_gate_msa).unsqueeze(1)
        context_ff_input = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
        context_ff_output = self.ff_context(context_ff_input)
        encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
            enc_gate_mlp
        ).unsqueeze(1)

        return hidden_states, encoder_hidden_states
class MochiRoPE(nn.Module):
    r"""
    Compute the cosine/sine tables of Mochi's rotary position embedding for a `(frames, height, width)` grid.

    Parameters:
        base_height (`int`, defaults to 192): Together with `base_width` defines the target area that the
            spatial positions are rescaled to.
        base_width (`int`, defaults to 192): See `base_height`.
        theta (`float`, defaults to 10000.0): Accepted for interface compatibility.
    """

    def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
        super().__init__()

        # NOTE(review): `theta` is not used anywhere in this class; frequencies are supplied to `forward`.
        self.target_area = base_height * base_width

    def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
        """Return the midpoints of `num` equal-width bins spanning `[start, stop]`."""
        edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
        return (edges[:-1] + edges[1:]) / 2

    def _get_positions(
        self,
        num_frames: int,
        height: int,
        width: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        """Return a `(num_frames * height * width, 3)` tensor of `(t, h, w)` positions."""
        # Rescale the spatial axes so the grid covers the same area as the base resolution.
        scale = (self.target_area / (height * width)) ** 0.5

        t = torch.arange(num_frames, device=device, dtype=dtype)
        h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
        w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)

        grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")

        return torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)

    def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Contract positions `(n, d)` with frequencies `(d, h, f)` and return `(cos, sin)` of the angles."""
        angles = torch.einsum("nd,dhf->nhf", pos, freqs)
        return torch.cos(angles), torch.sin(angles)

    def forward(
        self,
        pos_frequencies: torch.Tensor,
        num_frames: int,
        height: int,
        width: int,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            pos_frequencies: Frequencies contracted with the positions as `"nd,dhf->nhf"`.
            num_frames: Number of positions along the temporal axis.
            height: Number of positions along the height axis.
            width: Number of positions along the width axis.
            device: Device on which the positions are created.
            dtype: Data type in which the positions are created.

        Returns:
            The pair `(rope_cos, rope_sin)`.
        """
        pos = self._get_positions(num_frames, height, width, device, dtype)
        rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
        return rope_cos, rope_sin
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin):
    r"""
    A Transformer model for video-like data, used in Mochi.

    Parameters:
        patch_size (`int`, defaults to 2): Spatial patch size used by the patch embedding.
        num_attention_heads (`int`, defaults to 24): Number of attention heads in each block.
        attention_head_dim (`int`, defaults to 128): Feature size of each attention head.
        num_layers (`int`, defaults to 48): Number of transformer blocks.
        pooled_projection_dim (`int`, defaults to 1536): Feature size of the projected caption tokens.
        in_channels (`int`, defaults to 12): Number of input channels.
        out_channels (`int`, *optional*): Number of output channels; falls back to `in_channels`.
        qk_norm (`str`, defaults to `"rms_norm"`): Query/key normalisation used in the blocks.
        text_embed_dim (`int`, defaults to 4096): Feature size of the incoming caption tokens.
        time_embed_dim (`int`, defaults to 256): Number of channels of the timestep projection.
        activation_fn (`str`, defaults to `"swiglu"`): Activation of the feed-forward networks.
        max_sequence_length (`int`, defaults to 256): Stored in the config; not used inside this class.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        num_attention_heads: int = 24,
        attention_head_dim: int = 128,
        num_layers: int = 48,
        pooled_projection_dim: int = 1536,
        in_channels: int = 12,
        out_channels: Optional[int] = None,
        qk_norm: str = "rms_norm",
        text_embed_dim: int = 4096,
        time_embed_dim: int = 256,
        activation_fn: str = "swiglu",
        max_sequence_length: int = 256,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        self.patch_embed = PatchEmbed(
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dim=inner_dim,
            pos_embed_type=None,
        )

        # NOTE(review): the caption pooler always uses 8 heads, independent of `num_attention_heads`.
        self.time_embed = MochiCombinedTimestepCaptionEmbedding(
            embedding_dim=inner_dim,
            pooled_projection_dim=pooled_projection_dim,
            text_embed_dim=text_embed_dim,
            time_embed_dim=time_embed_dim,
            num_attention_heads=8,
        )

        # Initialise to zeros rather than `torch.empty`: uninitialised memory may contain arbitrary values
        # (including NaN) whenever the model is built without loading a checkpoint afterwards.
        self.pos_frequencies = nn.Parameter(torch.full((3, num_attention_heads, attention_head_dim // 2), 0.0))
        self.rope = MochiRoPE()

        # Only the last block is created with `context_pre_only=True`.
        self.transformer_blocks = nn.ModuleList(
            [
                MochiTransformerBlock(
                    dim=inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                    pooled_projection_dim=pooled_projection_dim,
                    qk_norm=qk_norm,
                    activation_fn=activation_fn,
                    context_pre_only=i == num_layers - 1,
                )
                for i in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(
            inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
        )
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

        self.gradient_checkpointing = False

    def _set_gradient_checkpointing(self, module, value=False):
        # Toggle checkpointing on any module that exposes the flag.
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_attention_mask: torch.Tensor,
        return_dict: bool = True,
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Input of shape `(batch, channels, frames, height, width)`.
            encoder_hidden_states: Caption tokens passed to the time/caption embedding.
            timestep: Timesteps passed to the time/caption embedding.
            encoder_attention_mask: Mask for the caption tokens, used when pooling them.
            return_dict: Whether to wrap the result in a `Transformer2DModelOutput`.

        Returns:
            A `Transformer2DModelOutput` whose `sample` has shape `(batch, -1, frames, height, width)`, or a
            one-element tuple holding that tensor when `return_dict` is `False`.
        """
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p = self.config.patch_size

        post_patch_height = height // p
        post_patch_width = width // p

        temb, encoder_hidden_states = self.time_embed(
            timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
        )

        # Patchify every frame independently, then merge the frame and patch axes into one sequence.
        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
        hidden_states = self.patch_embed(hidden_states)
        hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)

        image_rotary_emb = self.rope(
            self.pos_frequencies,
            num_frames,
            post_patch_height,
            post_patch_width,
            device=hidden_states.device,
            dtype=torch.float32,
        )

        for block in self.transformer_blocks:
            if self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states, encoder_hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                )

        hidden_states = self.norm_out(hidden_states, temb)
        hidden_states = self.proj_out(hidden_states)

        # Undo the patchification: fold each token's p x p patch back into the spatial grid.
        hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
        hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)

        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)

File diff suppressed because it is too large Load Diff

View File

@@ -347,6 +347,21 @@ class LuminaNextDiT2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
# Placeholder that carries the same name as the real `MochiTransformer3DModel`.
# Construction and both classmethod entry points only call `requires_backends` with the "torch" backend.
# NOTE(review): presumably that call raises an informative error when torch is not installed — confirm
# against the `requires_backends` / `DummyObject` implementations.
class MochiTransformer3DModel(metaclass=DummyObject):
    # Backends this placeholder stands in for.
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        # Arguments are accepted but ignored; only the backend check runs.
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        # Same backend check as the constructor, applied to the class.
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # Same backend check as the constructor, applied to the class.
        requires_backends(cls, ["torch"])
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -0,0 +1,80 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import MochiTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
# Runs once at import time, before any test in this module executes.
# NOTE(review): presumably switches torch to deterministic behaviour so test outputs are reproducible —
# confirm against `diffusers.utils.testing_utils.enable_full_determinism`.
enable_full_determinism()
class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
    """Configuration of the shared `ModelTesterMixin` test-suite for `MochiTransformer3DModel`."""

    model_class = MochiTransformer3DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        """Return a small random set of forward-pass keyword arguments placed on `torch_device`."""
        batch_size, sequence_length, embedding_dim = 2, 16, 16
        num_channels, num_frames, height, width = 4, 2, 16, 16

        latent_shape = (batch_size, num_channels, num_frames, height, width)
        text_shape = (batch_size, sequence_length, embedding_dim)

        # Random draws happen in the order latents -> text embeddings -> timesteps; keep that order so
        # seeded runs produce the same values.
        return {
            "hidden_states": torch.randn(latent_shape).to(torch_device),
            "encoder_hidden_states": torch.randn(text_shape).to(torch_device),
            "timestep": torch.randint(0, 1000, size=(batch_size,)).to(torch_device),
            "encoder_attention_mask": torch.ones((batch_size, sequence_length)).bool().to(torch_device),
        }

    @property
    def input_shape(self):
        """Per-sample input shape; the values match the non-batch dimensions used in `dummy_input`."""
        return 4, 2, 16, 16

    @property
    def output_shape(self):
        """Per-sample output shape; identical to `input_shape`."""
        return 4, 2, 16, 16

    def prepare_init_args_and_inputs_for_common(self):
        """Return `(init_dict, inputs_dict)`: a tiny model configuration and matching dummy inputs."""
        init_dict = dict(
            patch_size=2,
            num_attention_heads=2,
            attention_head_dim=8,
            num_layers=2,
            pooled_projection_dim=16,
            in_channels=4,
            out_channels=None,
            qk_norm="rms_norm",
            text_embed_dim=16,
            time_embed_dim=4,
            activation_fn="swiglu",
            max_sequence_length=16,
        )
        return init_dict, self.dummy_input