mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 21:44:27 +08:00
Compare commits
18 Commits
modular-di
...
mochi-t2v
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
54b6b3f8be | ||
|
|
e92ee28aa4 | ||
|
|
40cf52fc0a | ||
|
|
76b697f2c4 | ||
|
|
5797093a47 | ||
|
|
0a54230729 | ||
|
|
46f95d5cdb | ||
|
|
2fd2ec4025 | ||
|
|
85c8734cdc | ||
|
|
98a4554ac6 | ||
|
|
1e9bc91b5c | ||
|
|
be5bbe53e1 | ||
|
|
c2a155714b | ||
|
|
0e9e281ad1 | ||
|
|
05ebd6cd82 | ||
|
|
da48940b56 | ||
|
|
64275b0e66 | ||
|
|
e488d09df1 |
@@ -266,6 +266,8 @@
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/lumina_nextdit2d
|
||||
title: LuminaNextDiT2DModel
|
||||
- local: api/models/mochi_transformer3d
|
||||
title: MochiTransformer3DModel
|
||||
- local: api/models/pixart_transformer2d
|
||||
title: PixArtTransformer2DModel
|
||||
- local: api/models/prior_transformer
|
||||
|
||||
30
docs/source/en/api/models/mochi_transformer3d.md
Normal file
30
docs/source/en/api/models/mochi_transformer3d.md
Normal file
@@ -0,0 +1,30 @@
|
||||
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# MochiTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import MochiTransformer3DModel
|
||||
|
||||
vae = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
|
||||
```
|
||||
|
||||
## MochiTransformer3DModel
|
||||
|
||||
[[autodoc]] MochiTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
187
scripts/convert_mochi_to_diffusers.py
Normal file
187
scripts/convert_mochi_to_diffusers.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import argparse
|
||||
from contextlib import nullcontext
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from safetensors.torch import load_file
|
||||
|
||||
# from transformers import T5EncoderModel, T5Tokenizer
|
||||
from diffusers import MochiTransformer3DModel
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
TOKENIZER_MAX_LENGTH = 256
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
|
||||
# parser.add_argument("--vae_checkpoint_path", default=None, type=str)
|
||||
parser.add_argument("--output_path", required=True, type=str)
|
||||
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
|
||||
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
|
||||
parser.add_argument("--dtype", type=str, default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
# This is specific to `AdaLayerNormContinuous`:
|
||||
# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale
|
||||
def swap_scale_shift(weight, dim):
|
||||
shift, scale = weight.chunk(2, dim=0)
|
||||
new_weight = torch.cat([scale, shift], dim=0)
|
||||
return new_weight
|
||||
|
||||
|
||||
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
|
||||
original_state_dict = load_file(ckpt_path, device="cpu")
|
||||
new_state_dict = {}
|
||||
|
||||
# Convert patch_embed
|
||||
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
|
||||
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
|
||||
|
||||
# Convert time_embed
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
|
||||
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
|
||||
new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
|
||||
new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
|
||||
new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
|
||||
new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
|
||||
new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
|
||||
new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
|
||||
new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
|
||||
new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")
|
||||
|
||||
# Convert transformer blocks
|
||||
num_layers = 48
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
old_prefix = f"blocks.{i}."
|
||||
|
||||
# norm1
|
||||
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
|
||||
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
|
||||
old_prefix + "mod_y.bias"
|
||||
)
|
||||
else:
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mod_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
|
||||
old_prefix + "mod_y.bias"
|
||||
)
|
||||
|
||||
# Visual attention
|
||||
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
|
||||
old_prefix + "attn.q_norm_x.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
|
||||
old_prefix + "attn.k_norm_x.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
|
||||
old_prefix + "attn.proj_x.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")
|
||||
|
||||
# Context attention
|
||||
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
|
||||
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
|
||||
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
|
||||
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
|
||||
old_prefix + "attn.q_norm_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
|
||||
old_prefix + "attn.k_norm_y.weight"
|
||||
)
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
|
||||
old_prefix + "attn.proj_y.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
|
||||
old_prefix + "attn.proj_y.bias"
|
||||
)
|
||||
|
||||
# MLP
|
||||
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
|
||||
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
|
||||
if i < num_layers - 1:
|
||||
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mlp_y.w1.weight"
|
||||
)
|
||||
new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
|
||||
old_prefix + "mlp_y.w2.weight"
|
||||
)
|
||||
|
||||
# Output layers
|
||||
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.weight"), dim=0)
|
||||
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
|
||||
new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
|
||||
new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
|
||||
|
||||
new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")
|
||||
|
||||
print("Remaining Keys:", original_state_dict.keys())
|
||||
|
||||
return new_state_dict
|
||||
|
||||
|
||||
# def convert_mochi_vae_checkpoint_to_diffusers(ckpt_path, vae_config):
|
||||
# original_state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
|
||||
# return convert_ldm_vae_checkpoint(original_state_dict, vae_config)
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.dtype is None:
|
||||
dtype = None
|
||||
if args.dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
elif args.dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
elif args.dtype == "fp32":
|
||||
dtype = torch.float32
|
||||
else:
|
||||
raise ValueError(f"Unsupported dtype: {args.dtype}")
|
||||
|
||||
transformer = None
|
||||
# vae = None
|
||||
|
||||
if args.transformer_checkpoint_path is not None:
|
||||
converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
|
||||
args.transformer_checkpoint_path
|
||||
)
|
||||
transformer = MochiTransformer3DModel()
|
||||
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
|
||||
if dtype is not None:
|
||||
# Original checkpoint data type will be preserved
|
||||
transformer = transformer.to(dtype=dtype)
|
||||
|
||||
# text_encoder_id = "google/t5-v1_1-xxl"
|
||||
# tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
|
||||
# text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
|
||||
|
||||
# # Apparently, the conversion does not work anymore without this :shrug:
|
||||
# for param in text_encoder.parameters():
|
||||
# param.data = param.data.contiguous()
|
||||
|
||||
transformer.save_pretrained("/raid/aryan/mochi-diffusers", subfolder="transformer")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main(args)
|
||||
@@ -100,6 +100,7 @@ else:
|
||||
"Kandinsky3UNet",
|
||||
"LatteTransformer3DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
"MochiTransformer3DModel",
|
||||
"ModelMixin",
|
||||
"MotionAdapter",
|
||||
"MultiAdapter",
|
||||
@@ -579,6 +580,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky3UNet,
|
||||
LatteTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
ModelMixin,
|
||||
MotionAdapter,
|
||||
MultiAdapter,
|
||||
|
||||
@@ -56,6 +56,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
|
||||
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
|
||||
@@ -106,6 +107,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanDiT2DModel,
|
||||
LatteTransformer3DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
MochiTransformer3DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
SD3Transformer2DModel,
|
||||
|
||||
@@ -134,14 +134,18 @@ class SwiGLU(nn.Module):
|
||||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
|
||||
def __init__(self, dim_in: int, dim_out: int, bias: bool = True, flip_gate: bool = False):
|
||||
super().__init__()
|
||||
self.flip_gate = flip_gate
|
||||
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
|
||||
self.activation = nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states, gate = hidden_states.chunk(2, dim=-1)
|
||||
if self.flip_gate:
|
||||
hidden_states, gate = gate, hidden_states
|
||||
return hidden_states * self.activation(gate)
|
||||
|
||||
|
||||
|
||||
@@ -1206,6 +1206,7 @@ class FeedForward(nn.Module):
|
||||
final_dropout: bool = False,
|
||||
inner_dim=None,
|
||||
bias: bool = True,
|
||||
flip_gate: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
if inner_dim is None:
|
||||
@@ -1221,7 +1222,7 @@ class FeedForward(nn.Module):
|
||||
elif activation_fn == "geglu-approximate":
|
||||
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
|
||||
elif activation_fn == "swiglu":
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias)
|
||||
act_fn = SwiGLU(dim, inner_dim, bias=bias, flip_gate=flip_gate)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
|
||||
@@ -120,6 +120,7 @@ class Attention(nn.Module):
|
||||
_from_deprecated_attn_block: bool = False,
|
||||
processor: Optional["AttnProcessor"] = None,
|
||||
out_dim: int = None,
|
||||
out_context_dim: int = None,
|
||||
context_pre_only=None,
|
||||
pre_only=False,
|
||||
elementwise_affine: bool = True,
|
||||
@@ -142,6 +143,7 @@ class Attention(nn.Module):
|
||||
self.dropout = dropout
|
||||
self.fused_projections = False
|
||||
self.out_dim = out_dim if out_dim is not None else query_dim
|
||||
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
self.pre_only = pre_only
|
||||
|
||||
@@ -241,7 +243,7 @@ class Attention(nn.Module):
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
|
||||
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
|
||||
|
||||
if qk_norm is not None and added_kv_proj_dim is not None:
|
||||
if qk_norm == "fp32_layer_norm":
|
||||
@@ -1792,6 +1794,7 @@ class FluxAttnProcessor2_0:
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
@@ -3078,6 +3081,94 @@ class LuminaAttnProcessor2_0:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MochiAttnProcessor2_0:
|
||||
"""Attention processor used in Mochi."""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
encoder_query = attn.add_q_proj(encoder_hidden_states)
|
||||
encoder_key = attn.add_k_proj(encoder_hidden_states)
|
||||
encoder_value = attn.add_v_proj(encoder_hidden_states)
|
||||
|
||||
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
|
||||
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
|
||||
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if attn.norm_added_q is not None:
|
||||
encoder_query = attn.norm_added_q(encoder_query)
|
||||
if attn.norm_added_k is not None:
|
||||
encoder_key = attn.norm_added_k(encoder_key)
|
||||
|
||||
if image_rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(x, freqs_cos, freqs_sin):
|
||||
x_even = x[..., 0::2].float()
|
||||
x_odd = x[..., 1::2].float()
|
||||
|
||||
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
|
||||
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
|
||||
|
||||
return torch.stack([cos, sin], dim=-1).flatten(-2)
|
||||
|
||||
query = apply_rotary_emb(query, *image_rotary_emb)
|
||||
key = apply_rotary_emb(key, *image_rotary_emb)
|
||||
|
||||
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
||||
encoder_query, encoder_key, encoder_value = (
|
||||
encoder_query.transpose(1, 2),
|
||||
encoder_key.transpose(1, 2),
|
||||
encoder_value.transpose(1, 2),
|
||||
)
|
||||
|
||||
sequence_length = query.size(2)
|
||||
encoder_sequence_length = encoder_query.size(2)
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=2)
|
||||
key = torch.cat([key, encoder_key], dim=2)
|
||||
value = torch.cat([value, encoder_value], dim=2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
||||
(sequence_length, encoder_sequence_length), dim=1
|
||||
)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if hasattr(attn, "to_add_out"):
|
||||
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class FusedAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses
|
||||
|
||||
@@ -1302,6 +1302,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
|
||||
return conditioning
|
||||
|
||||
|
||||
class MochiCombinedTimestepCaptionEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
pooled_projection_dim: int,
|
||||
text_embed_dim: int,
|
||||
time_embed_dim: int = 256,
|
||||
num_attention_heads: int = 8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
|
||||
self.pooler = MochiAttentionPool(
|
||||
num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
|
||||
)
|
||||
self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
hidden_dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
time_proj = self.time_proj(timestep)
|
||||
time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
|
||||
|
||||
pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
|
||||
caption_proj = self.caption_proj(encoder_hidden_states)
|
||||
|
||||
conditioning = time_emb + pooled_projections
|
||||
return conditioning, caption_proj
|
||||
|
||||
|
||||
class TextTimeEmbedding(nn.Module):
|
||||
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
|
||||
super().__init__()
|
||||
@@ -1430,6 +1465,88 @@ class AttentionPooling(nn.Module):
|
||||
return a[:, 0, :] # cls_token
|
||||
|
||||
|
||||
class MochiAttentionPool(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int,
|
||||
embed_dim: int,
|
||||
output_dim: Optional[int] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.output_dim = output_dim or embed_dim
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
|
||||
self.to_q = nn.Linear(embed_dim, embed_dim)
|
||||
self.to_out = nn.Linear(embed_dim, self.output_dim)
|
||||
|
||||
@staticmethod
|
||||
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
|
||||
"""
|
||||
Pool tokens in x using mask.
|
||||
|
||||
NOTE: We assume x does not require gradients.
|
||||
|
||||
Args:
|
||||
x: (B, L, D) tensor of tokens.
|
||||
mask: (B, L) boolean tensor indicating which tokens are not padding.
|
||||
|
||||
Returns:
|
||||
pooled: (B, D) tensor of pooled tokens.
|
||||
"""
|
||||
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
|
||||
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
|
||||
mask = mask[:, :, None].to(dtype=x.dtype)
|
||||
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
|
||||
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
|
||||
return pooled
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
|
||||
r"""
|
||||
Args:
|
||||
x (`torch.Tensor`):
|
||||
Tensor of shape `(B, S, D)` of input tokens.
|
||||
mask (`torch.Tensor`):
|
||||
Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
`(B, D)` tensor of pooled tokens.
|
||||
"""
|
||||
D = x.size(2)
|
||||
|
||||
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
|
||||
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
|
||||
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
|
||||
|
||||
# Average non-padding token features. These will be used as the query.
|
||||
x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D)
|
||||
|
||||
# Concat pooled features to input sequence.
|
||||
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
|
||||
|
||||
# Compute queries, keys, values. Only the mean token is used to create a query.
|
||||
kv = self.to_kv(x) # (B, L+1, 2 * D)
|
||||
q = self.to_q(x[:, 0]) # (B, D)
|
||||
|
||||
# Extract heads.
|
||||
head_dim = D // self.num_attention_heads
|
||||
kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
|
||||
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
|
||||
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
|
||||
q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim)
|
||||
q = q.unsqueeze(2) # (B, H, 1, head_dim)
|
||||
|
||||
# Compute attention.
|
||||
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim)
|
||||
|
||||
# Concatenate heads and run output.
|
||||
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
|
||||
x = self.to_out(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_fourier_embeds_from_boundingbox(embed_dim, box):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -237,6 +237,33 @@ class LuminaRMSNormZero(nn.Module):
|
||||
return x, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class MochiRMSNormZero(nn.Module):
|
||||
r"""
|
||||
Adaptive RMS Norm used in Mochi.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, hidden_dim)
|
||||
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, emb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
|
||||
|
||||
return hidden_states, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
class AdaLayerNormSingle(nn.Module):
|
||||
r"""
|
||||
Norm layer adaptive layer norm single (adaLN-single).
|
||||
@@ -358,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module):
|
||||
out_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# AdaLN
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
if norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
# linear_2
|
||||
|
||||
self.linear_2 = None
|
||||
if out_dim is not None:
|
||||
self.linear_2 = nn.Linear(
|
||||
embedding_dim,
|
||||
out_dim,
|
||||
bias=bias,
|
||||
)
|
||||
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -16,5 +16,6 @@ if is_torch_available():
|
||||
from .transformer_2d import Transformer2DModel
|
||||
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_temporal import TransformerTemporalModel
|
||||
|
||||
323
src/diffusers/models/transformers/transformer_mochi.py
Normal file
323
src/diffusers/models/transformers/transformer_mochi.py
Normal file
@@ -0,0 +1,323 @@
|
||||
# Copyright 2024 The Genmo team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention, MochiAttnProcessor2_0
|
||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
pooled_projection_dim: int,
|
||||
qk_norm: str = "rms_norm",
|
||||
activation_fn: str = "swiglu",
|
||||
context_pre_only: bool = True,
|
||||
eps: float = 1e-6,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.context_pre_only = context_pre_only
|
||||
self.ff_inner_dim = (4 * dim * 2) // 3
|
||||
self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3
|
||||
|
||||
self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
if not context_pre_only:
|
||||
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
else:
|
||||
self.norm1_context = LuminaLayerNormContinuous(
|
||||
embedding_dim=pooled_projection_dim,
|
||||
conditioning_embedding_dim=dim,
|
||||
eps=eps,
|
||||
elementwise_affine=False,
|
||||
norm_type="rms_norm",
|
||||
out_dim=None,
|
||||
)
|
||||
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
bias=False,
|
||||
qk_norm=qk_norm,
|
||||
added_kv_proj_dim=pooled_projection_dim,
|
||||
added_proj_bias=False,
|
||||
out_dim=dim,
|
||||
out_context_dim=pooled_projection_dim,
|
||||
context_pre_only=context_pre_only,
|
||||
processor=MochiAttnProcessor2_0(),
|
||||
eps=eps,
|
||||
elementwise_affine=True,
|
||||
)
|
||||
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True)
|
||||
self.ff_context = None
|
||||
if not context_pre_only:
|
||||
self.ff_context = FeedForward(
|
||||
pooled_projection_dim, inner_dim=self.ff_context_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True
|
||||
)
|
||||
|
||||
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
|
||||
if not self.context_pre_only:
|
||||
norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
|
||||
encoder_hidden_states, temb
|
||||
)
|
||||
else:
|
||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
||||
|
||||
attn_hidden_states, context_attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
|
||||
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
|
||||
|
||||
if not self.context_pre_only:
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||
context_attn_hidden_states
|
||||
) * torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(enc_gate_mlp).unsqueeze(1)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class MochiRoPE(nn.Module):
|
||||
def __init__(self, base_height: int = 192, base_width: int = 192, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.target_area = base_height * base_width
|
||||
|
||||
def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
|
||||
edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
|
||||
return (edges[:-1] + edges[1:]) / 2
|
||||
|
||||
def _get_positions(
|
||||
self,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> torch.Tensor:
|
||||
scale = (self.target_area / (height * width)) ** 0.5
|
||||
|
||||
t = torch.arange(num_frames, device=device, dtype=dtype)
|
||||
h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
|
||||
w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
|
||||
|
||||
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
|
||||
|
||||
positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
|
||||
return positions
|
||||
|
||||
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
freqs = torch.einsum("nd,dhf->nhf", pos, freqs)
|
||||
freqs_cos = torch.cos(freqs)
|
||||
freqs_sin = torch.sin(freqs)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pos_frequencies: torch.Tensor,
|
||||
num_frames: int,
|
||||
height: int,
|
||||
width: int,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
pos = self._get_positions(num_frames, height, width, device, dtype)
|
||||
rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
|
||||
return rope_cos, rope_sin
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
num_attention_heads: int = 24,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 48,
|
||||
pooled_projection_dim: int = 1536,
|
||||
in_channels: int = 12,
|
||||
out_channels: Optional[int] = None,
|
||||
qk_norm: str = "rms_norm",
|
||||
text_embed_dim: int = 4096,
|
||||
time_embed_dim: int = 256,
|
||||
activation_fn: str = "swiglu",
|
||||
max_sequence_length: int = 256,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.patch_embed = PatchEmbed(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
pos_embed_type=None,
|
||||
)
|
||||
|
||||
self.time_embed = MochiCombinedTimestepCaptionEmbedding(
|
||||
embedding_dim=inner_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
text_embed_dim=text_embed_dim,
|
||||
time_embed_dim=time_embed_dim,
|
||||
num_attention_heads=8,
|
||||
)
|
||||
|
||||
self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
|
||||
self.rope = MochiRoPE()
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
MochiTransformerBlock(
|
||||
dim=inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
qk_norm=qk_norm,
|
||||
activation_fn=activation_fn,
|
||||
context_pre_only=i == num_layers - 1,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(
|
||||
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if hasattr(module, "gradient_checkpointing"):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
p = self.config.patch_size
|
||||
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
temb, encoder_hidden_states = self.time_embed(
|
||||
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.patch_embed(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
|
||||
|
||||
image_rotary_emb = self.rope(
|
||||
self.pos_frequencies,
|
||||
num_frames,
|
||||
post_patch_height,
|
||||
post_patch_width,
|
||||
device=hidden_states.device,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
if self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
||||
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
**ckpt_kwargs,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
|
||||
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
|
||||
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
1077
src/diffusers/models/transformers/transformer_mochi_original.py
Normal file
1077
src/diffusers/models/transformers/transformer_mochi_original.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -347,6 +347,21 @@ class LuminaNextDiT2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class MochiTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class ModelMixin(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
80
tests/models/transformers/test_models_transformer_mochi.py
Normal file
80
tests/models/transformers/test_models_transformer_mochi.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import MochiTransformer3DModel
|
||||
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = MochiTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 4
|
||||
num_frames = 2
|
||||
height = 16
|
||||
width = 16
|
||||
embedding_dim = 16
|
||||
sequence_length = 16
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
|
||||
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 2, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 2, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"num_layers": 2,
|
||||
"pooled_projection_dim": 16,
|
||||
"in_channels": 4,
|
||||
"out_channels": None,
|
||||
"qk_norm": "rms_norm",
|
||||
"text_embed_dim": 16,
|
||||
"time_embed_dim": 4,
|
||||
"activation_fn": "swiglu",
|
||||
"max_sequence_length": 16,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
Reference in New Issue
Block a user