Compare commits

...

42 Commits

Author SHA1 Message Date
yiyixuxu
1069d210e1 test 2024-10-26 06:15:44 +02:00
Aryan
303b47cc53 Merge branch 'main' into mochi 2024-10-26 00:46:25 +02:00
Aryan
59258445aa fix pipeline num_frames 2024-10-26 00:45:47 +02:00
Aryan
5d093fed7f invert sigmas in scheduler; fix pipeline 2024-10-25 13:44:50 +02:00
Aryan
cae2801bfb make style 2024-10-25 11:58:12 +02:00
Aryan
72741ec503 pipeline fixes 2024-10-25 11:57:37 +02:00
Aryan
237e079e5e fix docs 2024-10-25 11:23:16 +02:00
Aryan
fb2ede03eb Update src/diffusers/pipelines/mochi/pipeline_mochi.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-10-25 14:01:31 +05:30
Aryan
8e11f34a74 Update src/diffusers/pipelines/mochi/pipeline_mochi.py
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2024-10-25 14:01:22 +05:30
Aryan
5f43c6a41f docs 2024-10-25 10:27:28 +02:00
Aryan
2798ed4c7a update conversion script 2024-10-25 09:59:35 +02:00
Aryan
d41198cbdf update inits 2024-10-25 09:16:41 +02:00
Aryan
c916ae59ea make style 2024-10-25 09:11:14 +02:00
Aryan
ba9f13fa90 remove original implementation 2024-10-25 09:10:55 +02:00
Aryan
723329d1f2 Merge branch 'mochi-t2v-pipeline' into mochi 2024-10-25 09:10:17 +02:00
Aryan
3e569cb644 Merge branch 'mochi-vae' into mochi 2024-10-25 09:06:02 +02:00
Aryan
a7372bd9b3 mochi transformer 2024-10-25 09:00:50 +02:00
Dhruv Nair
6552653f11 update 2024-10-25 08:31:59 +02:00
yiyixuxu
7c55ef5000 up 2024-10-25 04:24:51 +02:00
yiyixuxu
0b76fea5dd up 2024-10-25 04:14:23 +02:00
yiyixuxu
e9e92d043d up 2024-10-25 02:02:46 +02:00
yiyixuxu
0a6189eb95 add 2024-10-25 01:02:46 +02:00
Dhruv Nair
969c3aba88 update 2024-10-24 22:50:54 +02:00
Dhruv Nair
8700d64d62 update 2024-10-24 22:44:54 +02:00
yiyixuxu
85a9825449 init 2024-10-24 18:12:44 +02:00
Dhruv Nair
ebcbad2f38 update 2024-10-24 18:07:43 +02:00
Dhruv Nair
44987ad98c update 2024-10-24 16:31:10 +02:00
Dhruv Nair
c12ce7d32b Merge branch 'mochi-t2v' into mochi-t2v-pipeline 2024-10-24 14:27:03 +02:00
Dhruv Nair
275041d21e update 2024-10-24 14:26:23 +02:00
Aryan
46f95d5cdb make style 2024-10-24 13:49:12 +02:00
Aryan
2fd2ec4025 fixes 2024-10-24 13:48:22 +02:00
Dhruv Nair
ccc1b36b09 update 2024-10-24 09:40:10 +02:00
Aryan
85c8734cdc fix 2024-10-24 05:47:03 +02:00
Aryan
98a4554ac6 update 2024-10-24 05:04:43 +02:00
Aryan
1e9bc91b5c fix 2024-10-24 04:02:37 +02:00
Aryan
be5bbe53e1 update 2024-10-24 03:48:31 +02:00
Aryan
c2a155714b add conversion script 2024-10-24 03:48:10 +02:00
Aryan
0e9e281ad1 fix 2024-10-24 02:40:31 +02:00
Aryan
05ebd6cd82 make style 2024-10-24 01:27:51 +02:00
Aryan
da48940b56 update transformer 2024-10-24 01:27:25 +02:00
Aryan
64275b0e66 udpate 2024-10-24 00:17:16 +02:00
Aryan
e488d09df1 update 2024-10-23 10:26:08 +02:00
24 changed files with 2671 additions and 11 deletions

View File

@@ -268,6 +268,8 @@
title: LatteTransformer3DModel
- local: api/models/lumina_nextdit2d
title: LuminaNextDiT2DModel
- local: api/models/mochi_transformer3d
title: MochiTransformer3DModel
- local: api/models/pixart_transformer2d
title: PixArtTransformer2DModel
- local: api/models/prior_transformer
@@ -302,6 +304,8 @@
title: AutoencoderKL
- local: api/models/autoencoderkl_cogvideox
title: AutoencoderKLCogVideoX
- local: api/models/autoencoderkl_mochi
title: AutoencoderKLMochi
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/consistency_decoder_vae
@@ -394,6 +398,8 @@
title: Lumina-T2X
- local: api/pipelines/marigold
title: Marigold
- local: api/pipelines/mochi
title: Mochi
- local: api/pipelines/panorama
title: MultiDiffusion
- local: api/pipelines/musicldm

View File

@@ -0,0 +1,32 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLMochi
The 3D variational autoencoder (VAE) model with KL loss used in [Mochi](https://github.com/genmoai/models) was introduced in [Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Tsinghua University & ZhipuAI.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLMochi
vae = AutoencoderKLMochi.from_pretrained("genmo/mochi-1-preview", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLMochi
[[autodoc]] AutoencoderKLMochi
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -0,0 +1,30 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# MochiTransformer3DModel
A Diffusion Transformer model for 3D video-like data was introduced in [Mochi-1 Preview](https://huggingface.co/genmo/mochi-1-preview) by Genmo.
The model can be loaded with the following code snippet.
```python
from diffusers import MochiTransformer3DModel
vae = MochiTransformer3DModel.from_pretrained("genmo/mochi-1-preview", subfolder="transformer", torch_dtype=torch.float16).to("cuda")
```
## MochiTransformer3DModel
[[autodoc]] MochiTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput

View File

@@ -0,0 +1,36 @@
<!-- Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# Mochi
[Mochi 1 Preview](https://huggingface.co/genmo/mochi-1-preview) from Genmo.
*Mochi 1 preview is an open state-of-the-art video generation model with high-fidelity motion and strong prompt adherence in preliminary evaluation. This model dramatically closes the gap between closed and open video generation systems. The model is released under a permissive Apache 2.0 license.*
<Tip>
Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers.md) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading.md#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
</Tip>
## MochiPipeline
[[autodoc]] MochiPipeline
- all
- __call__
## MochiPipelineOutput
[[autodoc]] pipelines.mochi.pipeline_output.MochiPipelineOutput

View File

@@ -0,0 +1,299 @@
import argparse
from contextlib import nullcontext
import torch
from accelerate import init_empty_weights
from safetensors.torch import load_file
from transformers import T5EncoderModel, T5Tokenizer
from diffusers import AutoencoderKLMochi, FlowMatchEulerDiscreteScheduler, MochiPipeline, MochiTransformer3DModel
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
TOKENIZER_MAX_LENGTH = 256
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_checkpoint_path", default=None, type=str)
parser.add_argument("--vae_checkpoint_path", default=None, type=str)
parser.add_argument("--output_path", required=True, type=str)
parser.add_argument("--push_to_hub", action="store_true", default=False, help="Whether to push to HF Hub after saving")
parser.add_argument("--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory")
parser.add_argument("--dtype", type=str, default=None)
args = parser.parse_args()
# This is specific to `AdaLayerNormContinuous`:
# Diffusers implementation split the linear projection into the scale, shift while Mochi split it into shift, scale
def swap_scale_shift(weight, dim):
shift, scale = weight.chunk(2, dim=0)
new_weight = torch.cat([scale, shift], dim=0)
return new_weight
def convert_mochi_transformer_checkpoint_to_diffusers(ckpt_path):
original_state_dict = load_file(ckpt_path, device="cpu")
new_state_dict = {}
# Convert patch_embed
new_state_dict["patch_embed.proj.weight"] = original_state_dict.pop("x_embedder.proj.weight")
new_state_dict["patch_embed.proj.bias"] = original_state_dict.pop("x_embedder.proj.bias")
# Convert time_embed
new_state_dict["time_embed.timestep_embedder.linear_1.weight"] = original_state_dict.pop("t_embedder.mlp.0.weight")
new_state_dict["time_embed.timestep_embedder.linear_1.bias"] = original_state_dict.pop("t_embedder.mlp.0.bias")
new_state_dict["time_embed.timestep_embedder.linear_2.weight"] = original_state_dict.pop("t_embedder.mlp.2.weight")
new_state_dict["time_embed.timestep_embedder.linear_2.bias"] = original_state_dict.pop("t_embedder.mlp.2.bias")
new_state_dict["time_embed.pooler.to_kv.weight"] = original_state_dict.pop("t5_y_embedder.to_kv.weight")
new_state_dict["time_embed.pooler.to_kv.bias"] = original_state_dict.pop("t5_y_embedder.to_kv.bias")
new_state_dict["time_embed.pooler.to_q.weight"] = original_state_dict.pop("t5_y_embedder.to_q.weight")
new_state_dict["time_embed.pooler.to_q.bias"] = original_state_dict.pop("t5_y_embedder.to_q.bias")
new_state_dict["time_embed.pooler.to_out.weight"] = original_state_dict.pop("t5_y_embedder.to_out.weight")
new_state_dict["time_embed.pooler.to_out.bias"] = original_state_dict.pop("t5_y_embedder.to_out.bias")
new_state_dict["time_embed.caption_proj.weight"] = original_state_dict.pop("t5_yproj.weight")
new_state_dict["time_embed.caption_proj.bias"] = original_state_dict.pop("t5_yproj.bias")
# Convert transformer blocks
num_layers = 48
for i in range(num_layers):
block_prefix = f"transformer_blocks.{i}."
old_prefix = f"blocks.{i}."
# norm1
new_state_dict[block_prefix + "norm1.linear.weight"] = original_state_dict.pop(old_prefix + "mod_x.weight")
new_state_dict[block_prefix + "norm1.linear.bias"] = original_state_dict.pop(old_prefix + "mod_x.bias")
if i < num_layers - 1:
new_state_dict[block_prefix + "norm1_context.linear.weight"] = original_state_dict.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear.bias"] = original_state_dict.pop(
old_prefix + "mod_y.bias"
)
else:
new_state_dict[block_prefix + "norm1_context.linear_1.weight"] = original_state_dict.pop(
old_prefix + "mod_y.weight"
)
new_state_dict[block_prefix + "norm1_context.linear_1.bias"] = original_state_dict.pop(
old_prefix + "mod_y.bias"
)
# Visual attention
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_x.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.to_q.weight"] = q
new_state_dict[block_prefix + "attn1.to_k.weight"] = k
new_state_dict[block_prefix + "attn1.to_v.weight"] = v
new_state_dict[block_prefix + "attn1.norm_q.weight"] = original_state_dict.pop(
old_prefix + "attn.q_norm_x.weight"
)
new_state_dict[block_prefix + "attn1.norm_k.weight"] = original_state_dict.pop(
old_prefix + "attn.k_norm_x.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.weight"] = original_state_dict.pop(
old_prefix + "attn.proj_x.weight"
)
new_state_dict[block_prefix + "attn1.to_out.0.bias"] = original_state_dict.pop(old_prefix + "attn.proj_x.bias")
# Context attention
qkv_weight = original_state_dict.pop(old_prefix + "attn.qkv_y.weight")
q, k, v = qkv_weight.chunk(3, dim=0)
new_state_dict[block_prefix + "attn1.add_q_proj.weight"] = q
new_state_dict[block_prefix + "attn1.add_k_proj.weight"] = k
new_state_dict[block_prefix + "attn1.add_v_proj.weight"] = v
new_state_dict[block_prefix + "attn1.norm_added_q.weight"] = original_state_dict.pop(
old_prefix + "attn.q_norm_y.weight"
)
new_state_dict[block_prefix + "attn1.norm_added_k.weight"] = original_state_dict.pop(
old_prefix + "attn.k_norm_y.weight"
)
if i < num_layers - 1:
new_state_dict[block_prefix + "attn1.to_add_out.weight"] = original_state_dict.pop(
old_prefix + "attn.proj_y.weight"
)
new_state_dict[block_prefix + "attn1.to_add_out.bias"] = original_state_dict.pop(
old_prefix + "attn.proj_y.bias"
)
# MLP
new_state_dict[block_prefix + "ff.net.0.proj.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w1.weight")
new_state_dict[block_prefix + "ff.net.2.weight"] = original_state_dict.pop(old_prefix + "mlp_x.w2.weight")
if i < num_layers - 1:
new_state_dict[block_prefix + "ff_context.net.0.proj.weight"] = original_state_dict.pop(
old_prefix + "mlp_y.w1.weight"
)
new_state_dict[block_prefix + "ff_context.net.2.weight"] = original_state_dict.pop(
old_prefix + "mlp_y.w2.weight"
)
# Output layers
new_state_dict["norm_out.linear.weight"] = swap_scale_shift(
original_state_dict.pop("final_layer.mod.weight"), dim=0
)
new_state_dict["norm_out.linear.bias"] = swap_scale_shift(original_state_dict.pop("final_layer.mod.bias"), dim=0)
new_state_dict["proj_out.weight"] = original_state_dict.pop("final_layer.linear.weight")
new_state_dict["proj_out.bias"] = original_state_dict.pop("final_layer.linear.bias")
new_state_dict["pos_frequencies"] = original_state_dict.pop("pos_frequencies")
print("Remaining Keys:", original_state_dict.keys())
return new_state_dict
def convert_mochi_decoder_state_dict_to_diffusers(ckpt_path):
original_state_dict = load_file(ckpt_path, device="cpu")
new_state_dict = {}
prefix = "decoder."
# Convert conv_in
new_state_dict[f"{prefix}conv_in.weight"] = original_state_dict["blocks.0.0.weight"]
new_state_dict[f"{prefix}conv_in.bias"] = original_state_dict["blocks.0.0.bias"]
# Convert block_in (MochiMidBlock3D)
for i in range(3): # layers_per_block[-1] = 3
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
f"blocks.0.{i+1}.stack.0.weight"
]
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
f"blocks.0.{i+1}.stack.0.bias"
]
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = original_state_dict[
f"blocks.0.{i+1}.stack.2.weight"
]
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = original_state_dict[
f"blocks.0.{i+1}.stack.2.bias"
]
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
f"blocks.0.{i+1}.stack.3.weight"
]
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
f"blocks.0.{i+1}.stack.3.bias"
]
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = original_state_dict[
f"blocks.0.{i+1}.stack.5.weight"
]
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = original_state_dict[
f"blocks.0.{i+1}.stack.5.bias"
]
# Convert up_blocks (MochiUpBlock3D)
up_block_layers = [6, 4, 3] # layers_per_block[-2], layers_per_block[-3], layers_per_block[-4]
for block in range(3):
for i in range(up_block_layers[block]):
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.0.weight"
]
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.0.bias"
]
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.2.weight"
]
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.2.bias"
]
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.3.weight"
]
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.3.bias"
]
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.5.weight"
]
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = original_state_dict[
f"blocks.{block+1}.blocks.{i}.stack.5.bias"
]
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = original_state_dict[f"blocks.{block+1}.proj.weight"]
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = original_state_dict[f"blocks.{block+1}.proj.bias"]
# Convert block_out (MochiMidBlock3D)
for i in range(3): # layers_per_block[0] = 3
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = original_state_dict[
f"blocks.4.{i}.stack.0.weight"
]
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = original_state_dict[
f"blocks.4.{i}.stack.0.bias"
]
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = original_state_dict[
f"blocks.4.{i}.stack.2.weight"
]
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = original_state_dict[
f"blocks.4.{i}.stack.2.bias"
]
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = original_state_dict[
f"blocks.4.{i}.stack.3.weight"
]
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = original_state_dict[
f"blocks.4.{i}.stack.3.bias"
]
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = original_state_dict[
f"blocks.4.{i}.stack.5.weight"
]
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = original_state_dict[
f"blocks.4.{i}.stack.5.bias"
]
# Convert conv_out (Conv1x1)
new_state_dict[f"{prefix}conv_out.weight"] = original_state_dict["output_proj.weight"]
new_state_dict[f"{prefix}conv_out.bias"] = original_state_dict["output_proj.bias"]
return new_state_dict
def main(args):
if args.dtype is None:
dtype = None
if args.dtype == "fp16":
dtype = torch.float16
elif args.dtype == "bf16":
dtype = torch.bfloat16
elif args.dtype == "fp32":
dtype = torch.float32
else:
raise ValueError(f"Unsupported dtype: {args.dtype}")
transformer = None
vae = None
if args.transformer_checkpoint_path is not None:
converted_transformer_state_dict = convert_mochi_transformer_checkpoint_to_diffusers(
args.transformer_checkpoint_path
)
transformer = MochiTransformer3DModel()
transformer.load_state_dict(converted_transformer_state_dict, strict=True)
if dtype is not None:
transformer = transformer.to(dtype=dtype)
if args.vae_checkpoint_path is not None:
vae = AutoencoderKLMochi(latent_channels=12, out_channels=3)
converted_vae_state_dict = convert_mochi_decoder_state_dict_to_diffusers(args.vae_checkpoint_path)
vae.load_state_dict(converted_vae_state_dict, strict=True)
if dtype is not None:
vae = vae.to(dtype=dtype)
text_encoder_id = "google/t5-v1_1-xxl"
tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)
# Apparently, the conversion does not work anymore without this :shrug:
for param in text_encoder.parameters():
param.data = param.data.contiguous()
pipe = MochiPipeline(
scheduler=FlowMatchEulerDiscreteScheduler(),
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB", push_to_hub=args.push_to_hub)
if __name__ == "__main__":
main(args)

View File

@@ -81,6 +81,7 @@ else:
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLCogVideoX",
"AutoencoderKLMochi",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
@@ -100,6 +101,7 @@ else:
"Kandinsky3UNet",
"LatteTransformer3DModel",
"LuminaNextDiT2DModel",
"MochiTransformer3DModel",
"ModelMixin",
"MotionAdapter",
"MultiAdapter",
@@ -308,6 +310,7 @@ else:
"LuminaText2ImgPipeline",
"MarigoldDepthPipeline",
"MarigoldNormalsPipeline",
"MochiPipeline",
"MusicLDMPipeline",
"PaintByExamplePipeline",
"PIAPipeline",
@@ -560,6 +563,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLCogVideoX,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
@@ -579,6 +583,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3UNet,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
ModelMixin,
MotionAdapter,
MultiAdapter,
@@ -766,6 +771,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LuminaText2ImgPipeline,
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
MochiPipeline,
MusicLDMPipeline,
PaintByExamplePipeline,
PIAPipeline,

View File

@@ -29,6 +29,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_cogvideox"] = ["AutoencoderKLCogVideoX"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
@@ -56,6 +57,7 @@ if is_torch_available():
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
@@ -82,6 +84,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLCogVideoX,
AutoencoderKLMochi,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
@@ -106,6 +109,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanDiT2DModel,
LatteTransformer3DModel,
LuminaNextDiT2DModel,
MochiTransformer3DModel,
PixArtTransformer2DModel,
PriorTransformer,
SD3Transformer2DModel,

View File

@@ -134,14 +134,18 @@ class SwiGLU(nn.Module):
bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
"""
def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
def __init__(self, dim_in: int, dim_out: int, bias: bool = True, flip_gate: bool = False):
super().__init__()
self.flip_gate = flip_gate
self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
self.activation = nn.SiLU()
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
hidden_states, gate = hidden_states.chunk(2, dim=-1)
if self.flip_gate:
hidden_states, gate = gate, hidden_states
return hidden_states * self.activation(gate)

View File

@@ -1206,6 +1206,7 @@ class FeedForward(nn.Module):
final_dropout: bool = False,
inner_dim=None,
bias: bool = True,
flip_gate: bool = False,
):
super().__init__()
if inner_dim is None:
@@ -1221,7 +1222,7 @@ class FeedForward(nn.Module):
elif activation_fn == "geglu-approximate":
act_fn = ApproximateGELU(dim, inner_dim, bias=bias)
elif activation_fn == "swiglu":
act_fn = SwiGLU(dim, inner_dim, bias=bias)
act_fn = SwiGLU(dim, inner_dim, bias=bias, flip_gate=flip_gate)
self.net = nn.ModuleList([])
# project in

View File

@@ -120,6 +120,7 @@ class Attention(nn.Module):
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
out_context_dim: int = None,
context_pre_only=None,
pre_only=False,
elementwise_affine: bool = True,
@@ -142,6 +143,7 @@ class Attention(nn.Module):
self.dropout = dropout
self.fused_projections = False
self.out_dim = out_dim if out_dim is not None else query_dim
self.out_context_dim = out_context_dim if out_context_dim is not None else query_dim
self.context_pre_only = context_pre_only
self.pre_only = pre_only
@@ -241,7 +243,7 @@ class Attention(nn.Module):
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
self.to_add_out = nn.Linear(self.inner_dim, self.out_context_dim, bias=out_bias)
if qk_norm is not None and added_kv_proj_dim is not None:
if qk_norm == "fp32_layer_norm":
@@ -1792,6 +1794,7 @@ class FluxAttnProcessor2_0:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
@@ -3078,6 +3081,94 @@ class LuminaAttnProcessor2_0:
return hidden_states
class MochiAttnProcessor2_0:
"""Attention processor used in Mochi."""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("MochiAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
encoder_query = attn.add_q_proj(encoder_hidden_states)
encoder_key = attn.add_k_proj(encoder_hidden_states)
encoder_value = attn.add_v_proj(encoder_hidden_states)
encoder_query = encoder_query.unflatten(2, (attn.heads, -1))
encoder_key = encoder_key.unflatten(2, (attn.heads, -1))
encoder_value = encoder_value.unflatten(2, (attn.heads, -1))
if attn.norm_added_q is not None:
encoder_query = attn.norm_added_q(encoder_query)
if attn.norm_added_k is not None:
encoder_key = attn.norm_added_k(encoder_key)
if image_rotary_emb is not None:
def apply_rotary_emb(x, freqs_cos, freqs_sin):
x_even = x[..., 0::2].float()
x_odd = x[..., 1::2].float()
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
return torch.stack([cos, sin], dim=-1).flatten(-2)
query = apply_rotary_emb(query, *image_rotary_emb)
key = apply_rotary_emb(key, *image_rotary_emb)
query, key, value = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
encoder_query, encoder_key, encoder_value = (
encoder_query.transpose(1, 2),
encoder_key.transpose(1, 2),
encoder_value.transpose(1, 2),
)
sequence_length = query.size(2)
encoder_sequence_length = encoder_query.size(2)
query = torch.cat([query, encoder_query], dim=2)
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if hasattr(attn, "to_add_out"):
encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
return hidden_states, encoder_hidden_states
class FusedAttnProcessor2_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). It uses

View File

@@ -1,6 +1,7 @@
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_tiny import AutoencoderTiny

View File

@@ -0,0 +1,700 @@
# Copyright 2024 The Mochi team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# YiYi to-do: replace this with nn.Conv3d
class Conv1x1(nn.Linear):
"""*1x1 Conv implemented with a linear layer."""
def __init__(self, in_features: int, out_features: int, *args, **kwargs):
super().__init__(in_features, out_features, *args, **kwargs)
def forward(self, x: torch.Tensor):
"""Forward pass.
Args:
x: Input tensor. Shape: [B, C, *] or [B, *, C].
Returns:
x: Output tensor. Shape: [B, C', *] or [B, *, C'].
"""
x = x.movedim(1, -1)
x = super().forward(x)
x = x.movedim(-1, 1)
return x
class MochiChunkedCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in Mochi Model.
It also supports memory-efficient chunked 3D convolutions.
Args:
in_channels (`int`): Number of channels in the input tensor.
out_channels (`int`): Number of output channels produced by the convolution.
kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
stride (`int` or `Tuple[int, int, int]`, defaults to `1`): Stride of the convolution.
padding_mode (`str`, defaults to `"replicate"`): Padding mode.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]],
padding_mode: str = "replicate",
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
if isinstance(stride, int):
stride = (stride,) * 3
_, height_kernel_size, width_kernel_size = kernel_size
self.padding_mode = padding_mode
height_pad = (height_kernel_size - 1) // 2
width_pad = (width_kernel_size - 1) // 2
self.conv = nn.Conv3d(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
dilation=(1, 1, 1),
padding=(0, height_pad, width_pad),
padding_mode=padding_mode,
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
time_kernel_size = self.conv.kernel_size[0]
context_size = time_kernel_size - 1
time_casual_padding = (0, 0, 0, 0, context_size, 0)
hidden_states = F.pad(hidden_states, time_casual_padding, mode=self.padding_mode)
# Memory-efficient chunked operation
memory_count = torch.prod(torch.tensor(hidden_states.shape)).item() * 2 / 1024**3
# YiYI Notes: testing only!! please remove
memory_count = 3
# YiYI Notes: this number 2 should be a config: max_memory_chunk_size (2 is 2GB)
if memory_count > 2:
part_num = int(memory_count / 2) + 1
num_frames = hidden_states.shape[2]
frames_idx = torch.arange(context_size, num_frames)
frames_chunks_idx = torch.chunk(frames_idx, part_num, dim=0)
output_chunks = []
for frames_chunk_idx in frames_chunks_idx:
frames_s = frames_chunk_idx[0] - context_size
frames_e = frames_chunk_idx[-1] + 1
frames_chunk = hidden_states[:, :, frames_s:frames_e, :, :]
output_chunk = self.conv(frames_chunk)
output_chunks.append(output_chunk) # Append each output chunk to the list
# Concatenate all output chunks along the temporal dimension
hidden_states = torch.cat(output_chunks, dim=2)
return hidden_states
else:
return self.conv(hidden_states)
class MochiChunkedGroupNorm3D(nn.Module):
r"""
Applies per-frame group normalization for 5D video inputs. It also supports memory-efficient chunked group
normalization.
Args:
num_channels (int): Number of channels expected in input
num_groups (int, optional): Number of groups to separate the channels into. Default: 32
affine (bool, optional): If True, this module has learnable affine parameters. Default: True
chunk_size (int, optional): Size of each chunk for processing. Default: 8
"""
def __init__(
self,
num_channels: int,
num_groups: int = 32,
affine: bool = True,
chunk_size: int = 8,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=num_channels, num_groups=num_groups, affine=affine)
self.chunk_size = chunk_size
def forward(self, x: torch.Tensor = None) -> torch.Tensor:
batch_size, channels, num_frames, height, width = x.shape
x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
output = torch.cat([self.norm_layer(chunk) for chunk in x.split(self.chunk_size, dim=0)], dim=0)
output = output.view(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
return output
class MochiResnetBlock3D(nn.Module):
r"""
A 3D ResNet block used in the Mochi model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
non_linearity (`str`, defaults to `"swish"`):
Activation function to use.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
act_fn: str = "swish",
):
super().__init__()
out_channels = out_channels or in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.nonlinearity = get_activation(act_fn)
self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels)
self.conv1 = MochiChunkedCausalConv3d(
in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=1
)
self.norm2 = MochiChunkedGroupNorm3D(num_channels=out_channels)
self.conv2 = MochiChunkedCausalConv3d(
in_channels=out_channels, out_channels=out_channels, kernel_size=3, stride=1
)
def forward(
self,
inputs: torch.Tensor,
) -> torch.Tensor:
hidden_states = inputs
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = hidden_states + inputs
return hidden_states
class MochiUpBlock3D(nn.Module):
r"""
An upsampling block used in the Mochi model.
Args:
in_channels (`int`):
Number of input channels.
out_channels (`int`, *optional*):
Number of output channels. If None, defaults to `in_channels`.
num_layers (`int`, defaults to `1`):
Number of resnet blocks in the block.
temporal_expansion (`int`, defaults to `2`):
Temporal expansion factor.
spatial_expansion (`int`, defaults to `2`):
Spatial expansion factor.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
temporal_expansion: int = 2,
spatial_expansion: int = 2,
):
super().__init__()
self.temporal_expansion = temporal_expansion
self.spatial_expansion = spatial_expansion
resnets = []
for i in range(num_layers):
resnets.append(
MochiResnetBlock3D(
in_channels=in_channels,
)
)
self.resnets = nn.ModuleList(resnets)
self.proj = Conv1x1(
in_channels,
out_channels * temporal_expansion * (spatial_expansion**2),
)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
r"""Forward method of the `MochiUpBlock3D` class."""
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet),
hidden_states,
)
else:
hidden_states = resnet(hidden_states)
hidden_states = self.proj(hidden_states)
# Calculate new shape
B, C, T, H, W = hidden_states.shape
st = self.temporal_expansion
sh = self.spatial_expansion
sw = self.spatial_expansion
new_C = C // (st * sh * sw)
# Reshape and permute
hidden_states = hidden_states.view(B, new_C, st, sh, sw, T, H, W)
hidden_states = hidden_states.permute(0, 1, 5, 2, 6, 3, 7, 4)
hidden_states = hidden_states.contiguous().view(B, new_C, T * st, H * sh, W * sw)
if self.temporal_expansion > 1:
# Drop the first self.temporal_expansion - 1 frames.
hidden_states = hidden_states[:, :, self.temporal_expansion - 1 :]
return hidden_states
class MochiMidBlock3D(nn.Module):
r"""
A middle block used in the Mochi model.
Args:
in_channels (`int`):
Number of input channels.
num_layers (`int`, defaults to `3`):
Number of resnet blocks in the block.
"""
def __init__(
self,
in_channels: int, # 768
num_layers: int = 3,
):
super().__init__()
resnets = []
for _ in range(num_layers):
resnets.append(MochiResnetBlock3D(in_channels=in_channels))
self.resnets = nn.ModuleList(resnets)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
r"""Forward method of the `MochiMidBlock3D` class."""
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states)
else:
hidden_states = resnet(hidden_states)
return hidden_states
class MochiDecoder3D(nn.Module):
r"""
The `MochiDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
sample.
Args:
in_channels (`int`, *optional*):
The number of input channels.
out_channels (`int`, *optional*):
The number of output channels.
block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(128, 256, 512, 768)`):
The number of output channels for each block.
layers_per_block (`Tuple[int, ...]`, *optional*, defaults to `(3, 3, 4, 6, 3)`):
The number of resnet blocks for each block.
temporal_expansions (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 3)`):
The temporal expansion factor for each of the up blocks.
spatial_expansions (`Tuple[int, ...]`, *optional*, defaults to `(2, 2, 2)`):
The spatial expansion factor for each of the up blocks.
non_linearity (`str`, *optional*, defaults to `"swish"`):
The non-linearity to use in the decoder.
"""
def __init__(
self,
in_channels: int, # 12
out_channels: int, # 3
block_out_channels: Tuple[int, ...] = (128, 256, 512, 768),
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
act_fn: str = "swish",
):
super().__init__()
self.nonlinearity = get_activation(act_fn)
self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1))
self.block_in = MochiMidBlock3D(
in_channels=block_out_channels[-1],
num_layers=layers_per_block[-1],
)
self.up_blocks = nn.ModuleList([])
for i in range(len(block_out_channels) - 1):
up_block = MochiUpBlock3D(
in_channels=block_out_channels[-i - 1],
out_channels=block_out_channels[-i - 2],
num_layers=layers_per_block[-i - 2],
temporal_expansion=temporal_expansions[-i - 1],
spatial_expansion=spatial_expansions[-i - 1],
)
self.up_blocks.append(up_block)
self.block_out = MochiMidBlock3D(
in_channels=block_out_channels[0],
num_layers=layers_per_block[0],
)
self.conv_out = Conv1x1(block_out_channels[0], out_channels)
self.gradient_checkpointing = False
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Forward method of the `MochiDecoder3D` class."""
hidden_states = self.conv_in(hidden_states)
# 1. Mid
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_forward(*inputs):
return module(*inputs)
return create_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(self.block_in), hidden_states)
for up_block in self.up_blocks:
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), hidden_states)
else:
hidden_states = self.block_in(hidden_states)
for up_block in self.up_blocks:
hidden_states = up_block(hidden_states)
hidden_states = self.block_out(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class AutoencoderKLMochi(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images. Used in
[Mochi 1 preview](https://github.com/genmoai/models).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
out_channels (int, *optional*, defaults to 3): Number of channels in the output.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
Tuple of block output channels.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
scaling_factor (`float`, *optional*, defaults to `1.15258426`):
The component-wise standard deviation of the trained latent space computed using the first batch of the
training set. This is used to scale the latent space to have unit variance when training the diffusion
model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
/ scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["MochiResnetBlock3D"]
@register_to_config
def __init__(
self,
out_channels: int = 3,
block_out_channels: Tuple[int] = (128, 256, 512, 768),
latent_channels: int = 12,
layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3),
act_fn: str = "silu",
temporal_expansions: Tuple[int, ...] = (1, 2, 3),
spatial_expansions: Tuple[int, ...] = (2, 2, 2),
latents_mean: Tuple[float, ...] = (
-0.06730895953510081,
-0.038011381506090416,
-0.07477820912866141,
-0.05565264470995561,
0.012767231469026969,
-0.04703542746246419,
0.043896967884726704,
-0.09346305707025976,
-0.09918314763016893,
-0.008729793427399178,
-0.011931556316503654,
-0.0321993391887285,
),
latents_std: Tuple[float, ...] = (
0.9263795028493863,
0.9248894543193766,
0.9393059390890617,
0.959253732819592,
0.8244560132752793,
0.917259975397747,
0.9294154431013696,
1.3720942357788521,
0.881393668867029,
0.9168315692124348,
0.9185249279345552,
0.9274757570805041,
),
scaling_factor: float = 1.0,
):
super().__init__()
self.decoder = MochiDecoder3D(
in_channels=latent_channels,
out_channels=out_channels,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
temporal_expansions=temporal_expansions,
spatial_expansions=spatial_expansions,
act_fn=act_fn,
)
self.use_slicing = False
self.use_tiling = False
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, MochiDecoder3D):
module.gradient_checkpointing = value
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`int`, *optional*):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
tile_overlap_factor_width (`int`, *optional*):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(
self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
)
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self.decoder(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self.decoder(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
batch_size, num_channels, num_frames, height, width = z.shape
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
frame_batch_size = self.num_latent_frames_batch_size
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
num_batches = max(num_frames // frame_batch_size, 1)
conv_cache = None
time = []
for k in range(num_batches):
remaining_frames = num_frames % frame_batch_size
start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
end_frame = frame_batch_size * (k + 1) + remaining_frames
tile = z[
:,
:,
start_frame:end_frame,
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
if self.post_quant_conv is not None:
tile = self.post_quant_conv(tile)
tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)

View File

@@ -1302,6 +1302,41 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
return conditioning
class MochiCombinedTimestepCaptionEmbedding(nn.Module):
def __init__(
self,
embedding_dim: int,
pooled_projection_dim: int,
text_embed_dim: int,
time_embed_dim: int = 256,
num_attention_heads: int = 8,
) -> None:
super().__init__()
self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
self.pooler = MochiAttentionPool(
num_attention_heads=num_attention_heads, embed_dim=text_embed_dim, output_dim=embedding_dim
)
self.caption_proj = nn.Linear(text_embed_dim, pooled_projection_dim)
def forward(
self,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
encoder_attention_mask: torch.Tensor,
hidden_dtype: Optional[torch.dtype] = None,
):
time_proj = self.time_proj(timestep)
time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
caption_proj = self.caption_proj(encoder_hidden_states)
conditioning = time_emb + pooled_projections
return conditioning, caption_proj
class TextTimeEmbedding(nn.Module):
def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
super().__init__()
@@ -1430,6 +1465,88 @@ class AttentionPooling(nn.Module):
return a[:, 0, :] # cls_token
class MochiAttentionPool(nn.Module):
def __init__(
self,
num_attention_heads: int,
embed_dim: int,
output_dim: Optional[int] = None,
) -> None:
super().__init__()
self.output_dim = output_dim or embed_dim
self.num_attention_heads = num_attention_heads
self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
self.to_q = nn.Linear(embed_dim, embed_dim)
self.to_out = nn.Linear(embed_dim, self.output_dim)
@staticmethod
def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
"""
Pool tokens in x using mask.
NOTE: We assume x does not require gradients.
Args:
x: (B, L, D) tensor of tokens.
mask: (B, L) boolean tensor indicating which tokens are not padding.
Returns:
pooled: (B, D) tensor of pooled tokens.
"""
assert x.size(1) == mask.size(1) # Expected mask to have same length as tokens.
assert x.size(0) == mask.size(0) # Expected mask to have same batch size as tokens.
mask = mask[:, :, None].to(dtype=x.dtype)
mask = mask / mask.sum(dim=1, keepdim=True).clamp(min=1)
pooled = (x * mask).sum(dim=1, keepdim=keepdim)
return pooled
def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
r"""
Args:
x (`torch.Tensor`):
Tensor of shape `(B, S, D)` of input tokens.
mask (`torch.Tensor`):
Boolean ensor of shape `(B, S)` indicating which tokens are not padding.
Returns:
`torch.Tensor`:
`(B, D)` tensor of pooled tokens.
"""
D = x.size(2)
# Construct attention mask, shape: (B, 1, num_queries=1, num_keys=1+L).
attn_mask = mask[:, None, None, :].bool() # (B, 1, 1, L).
attn_mask = F.pad(attn_mask, (1, 0), value=True) # (B, 1, 1, 1+L).
# Average non-padding token features. These will be used as the query.
x_pool = self.pool_tokens(x, mask, keepdim=True) # (B, 1, D)
# Concat pooled features to input sequence.
x = torch.cat([x_pool, x], dim=1) # (B, L+1, D)
# Compute queries, keys, values. Only the mean token is used to create a query.
kv = self.to_kv(x) # (B, L+1, 2 * D)
q = self.to_q(x[:, 0]) # (B, D)
# Extract heads.
head_dim = D // self.num_attention_heads
kv = kv.unflatten(2, (2, self.num_attention_heads, head_dim)) # (B, 1+L, 2, H, head_dim)
kv = kv.transpose(1, 3) # (B, H, 2, 1+L, head_dim)
k, v = kv.unbind(2) # (B, H, 1+L, head_dim)
q = q.unflatten(1, (self.num_attention_heads, head_dim)) # (B, H, head_dim)
q = q.unsqueeze(2) # (B, H, 1, head_dim)
# Compute attention.
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0) # (B, H, 1, head_dim)
# Concatenate heads and run output.
x = x.squeeze(2).flatten(1, 2) # (B, D = H * head_dim)
x = self.to_out(x)
return x
def get_fourier_embeds_from_boundingbox(embed_dim, box):
"""
Args:

View File

@@ -237,6 +237,33 @@ class LuminaRMSNormZero(nn.Module):
return x, gate_msa, scale_mlp, gate_mlp
class MochiRMSNormZero(nn.Module):
r"""
Adaptive RMS Norm used in Mochi.
Parameters:
embedding_dim (`int`): The size of each embedding vector.
"""
def __init__(
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
) -> None:
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, hidden_dim)
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
def forward(
self, hidden_states: torch.Tensor, emb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
emb = self.linear(self.silu(emb))
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
return hidden_states, gate_msa, scale_mlp, gate_mlp
class AdaLayerNormSingle(nn.Module):
r"""
Norm layer adaptive layer norm single (adaLN-single).
@@ -358,20 +385,21 @@ class LuminaLayerNormContinuous(nn.Module):
out_dim: Optional[int] = None,
):
super().__init__()
# AdaLN
self.silu = nn.SiLU()
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
if norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
# linear_2
self.linear_2 = None
if out_dim is not None:
self.linear_2 = nn.Linear(
embedding_dim,
out_dim,
bias=bias,
)
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
def forward(
self,

View File

@@ -16,5 +16,6 @@ if is_torch_available():
from .transformer_2d import Transformer2DModel
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_temporal import TransformerTemporalModel

View File

@@ -0,0 +1,390 @@
# Copyright 2024 The Genmo team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple
import torch
import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention, MochiAttnProcessor2_0
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class MochiTransformerBlock(nn.Module):
r"""
Transformer block used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
activation_fn (`str`, defaults to `"swiglu"`):
Activation function to use in feed-forward.
context_pre_only (`bool`, defaults to `False`):
Whether or not to process context-related conditions with additional layers.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
pooled_projection_dim: int,
qk_norm: str = "rms_norm",
activation_fn: str = "swiglu",
context_pre_only: bool = False,
eps: float = 1e-6,
) -> None:
super().__init__()
self.context_pre_only = context_pre_only
self.ff_inner_dim = (4 * dim * 2) // 3
self.ff_context_inner_dim = (4 * pooled_projection_dim * 2) // 3
self.norm1 = MochiRMSNormZero(dim, 4 * dim, eps=eps, elementwise_affine=False)
if not context_pre_only:
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
else:
self.norm1_context = LuminaLayerNormContinuous(
embedding_dim=pooled_projection_dim,
conditioning_embedding_dim=dim,
eps=eps,
elementwise_affine=False,
norm_type="rms_norm",
out_dim=None,
)
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
heads=num_attention_heads,
dim_head=attention_head_dim,
bias=False,
qk_norm=qk_norm,
added_kv_proj_dim=pooled_projection_dim,
added_proj_bias=False,
out_dim=dim,
out_context_dim=pooled_projection_dim,
context_pre_only=context_pre_only,
processor=MochiAttnProcessor2_0(),
eps=eps,
elementwise_affine=True,
)
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(
dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False, flip_gate=True
)
self.ff_context = None
if not context_pre_only:
self.ff_context = FeedForward(
pooled_projection_dim,
inner_dim=self.ff_context_inner_dim,
activation_fn=activation_fn,
bias=False,
flip_gate=True,
)
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
if not self.context_pre_only:
norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
encoder_hidden_states, temb
)
else:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
attn_hidden_states, context_attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
if not self.context_pre_only:
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
context_attn_hidden_states
) * torch.tanh(enc_gate_msa).unsqueeze(1)
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
context_ff_output = self.ff_context(norm_encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
enc_gate_mlp
).unsqueeze(1)
return hidden_states, encoder_hidden_states
class MochiRoPE(nn.Module):
r"""
RoPE implementation used in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
Args:
base_height (`int`, defaults to `192`):
Base height used to compute interpolation scale for rotary positional embeddings.
base_width (`int`, defaults to `192`):
Base width used to compute interpolation scale for rotary positional embeddings.
"""
def __init__(self, base_height: int = 192, base_width: int = 192) -> None:
super().__init__()
self.target_area = base_height * base_width
def _centers(self, start, stop, num, device, dtype) -> torch.Tensor:
edges = torch.linspace(start, stop, num + 1, device=device, dtype=dtype)
return (edges[:-1] + edges[1:]) / 2
def _get_positions(
self,
num_frames: int,
height: int,
width: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
scale = (self.target_area / (height * width)) ** 0.5
t = torch.arange(num_frames, device=device, dtype=dtype)
h = self._centers(-height * scale / 2, height * scale / 2, height, device, dtype)
w = self._centers(-width * scale / 2, width * scale / 2, width, device, dtype)
grid_t, grid_h, grid_w = torch.meshgrid(t, h, w, indexing="ij")
positions = torch.stack([grid_t, grid_h, grid_w], dim=-1).view(-1, 3)
return positions
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
freqs_cos = torch.cos(freqs)
freqs_sin = torch.sin(freqs)
return freqs_cos, freqs_sin
def forward(
self,
pos_frequencies: torch.Tensor,
num_frames: int,
height: int,
width: int,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
pos = self._get_positions(num_frames, height, width, device, dtype)
rope_cos, rope_sin = self._create_rope(pos_frequencies, pos)
return rope_cos, rope_sin
@maybe_allow_in_graph
class MochiTransformer3DModel(ModelMixin, ConfigMixin):
r"""
A Transformer model for video-like data introduced in [Mochi](https://huggingface.co/genmo/mochi-1-preview).
Args:
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
num_layers (`int`, defaults to `48`):
The number of layers of Transformer blocks to use.
in_channels (`int`, defaults to `12`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `None`):
The number of channels in the output.
qk_norm (`str`, defaults to `"rms_norm"`):
The normalization layer to use.
text_embed_dim (`int`, defaults to `4096`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `256`):
Output dimension of timestep embeddings.
activation_fn (`str`, defaults to `"swiglu"`):
Activation function to use in feed-forward.
max_sequence_length (`int`, defaults to `256`):
The maximum sequence length of text embeddings supported.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
patch_size: int = 2,
num_attention_heads: int = 24,
attention_head_dim: int = 128,
num_layers: int = 48,
pooled_projection_dim: int = 1536,
in_channels: int = 12,
out_channels: Optional[int] = None,
qk_norm: str = "rms_norm",
text_embed_dim: int = 4096,
time_embed_dim: int = 256,
activation_fn: str = "swiglu",
max_sequence_length: int = 256,
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
pos_embed_type=None,
)
self.time_embed = MochiCombinedTimestepCaptionEmbedding(
embedding_dim=inner_dim,
pooled_projection_dim=pooled_projection_dim,
text_embed_dim=text_embed_dim,
time_embed_dim=time_embed_dim,
num_attention_heads=8,
)
self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
self.rope = MochiRoPE()
self.transformer_blocks = nn.ModuleList(
[
MochiTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
pooled_projection_dim=pooled_projection_dim,
qk_norm=qk_norm,
activation_fn=activation_fn,
context_pre_only=i == num_layers - 1,
)
for i in range(num_layers)
]
)
self.norm_out = AdaLayerNormContinuous(
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
return_dict: bool = True,
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
temb, encoder_hidden_states = self.time_embed(
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.patch_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
image_rotary_emb = self.rope(
self.pos_frequencies,
num_frames,
post_patch_height,
post_patch_width,
device=hidden_states.device,
dtype=torch.float32,
)
for i, block in enumerate(self.transformer_blocks):
if self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
)
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -246,6 +246,7 @@ else:
"MarigoldNormalsPipeline",
]
)
_import_structure["mochi"] = ["MochiPipeline"]
_import_structure["musicldm"] = ["MusicLDMPipeline"]
_import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
_import_structure["pia"] = ["PIAPipeline"]
@@ -569,6 +570,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MarigoldDepthPipeline,
MarigoldNormalsPipeline,
)
from .mochi import MochiPipeline
from .musicldm import MusicLDMPipeline
from .pag import (
AnimateDiffPAGPipeline,

View File

@@ -0,0 +1,48 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_mochi"] = ["MochiPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_mochi import MochiPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,712 @@
# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import MochiTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
is_torch_xla_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import MochiPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import MochiPipeline
>>> from diffusers.utils import export_to_video
>>> pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
>>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
>>> export_to_video(frames, "mochi.mp4")
```
"""
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.16,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# from: https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
def linear_quadratic_schedule(num_steps, threshold_noise, linear_steps=None):
if linear_steps is None:
linear_steps = num_steps // 2
linear_sigma_schedule = [i * threshold_noise / linear_steps for i in range(linear_steps)]
threshold_noise_step_diff = linear_steps - threshold_noise * num_steps
quadratic_steps = num_steps - linear_steps
quadratic_coef = threshold_noise_step_diff / (linear_steps * quadratic_steps**2)
linear_coef = threshold_noise / linear_steps - 2 * threshold_noise_step_diff / (quadratic_steps**2)
const = quadratic_coef * (linear_steps**2)
quadratic_sigma_schedule = [
quadratic_coef * (i**2) + linear_coef * i + const for i in range(linear_steps, num_steps)
]
sigma_schedule = linear_sigma_schedule + quadratic_sigma_schedule + [1.0]
sigma_schedule = [1.0 - x for x in sigma_schedule]
return sigma_schedule
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class MochiPipeline(DiffusionPipeline):
r"""
The mochi pipeline for text-to-video generation.
Reference: https://github.com/genmoai/models
Args:
transformer ([`MochiTransformer3DModel`]):
Conditional Transformer architecture to denoise the encoded video latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer (`T5TokenizerFast`):
Second Tokenizer of class
[T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: T5EncoderModel,
tokenizer: T5TokenizerFast,
transformer: MochiTransformer3DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
scheduler=scheduler,
)
# TODO: determine these scaling factors from model parameters
self.vae_spatial_scale_factor = 8
self.vae_temporal_scale_factor = 6
self.patch_size = 2
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
self.tokenizer_max_length = (
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
)
self.default_height = 480
self.default_width = 848
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
def _get_t5_prompt_embeds(
self,
prompt: Union[str, List[str]] = None,
num_videos_per_prompt: int = 1,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
prompt = [prompt] if isinstance(prompt, str) else prompt
batch_size = len(prompt)
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.bool().to(device)
untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because `max_sequence_length` is set to "
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
prompt_attention_mask = prompt_attention_mask.view(batch_size, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_videos_per_prompt, 1)
return prompt_embeds, prompt_attention_mask
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
negative_prompt: Optional[Union[str, List[str]]] = None,
do_classifier_free_guidance: bool = True,
num_videos_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
max_sequence_length: int = 256,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds, prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_prompt_attention_mask = self._get_t5_prompt_embeds(
prompt=negative_prompt,
num_videos_per_prompt=num_videos_per_prompt,
max_sequence_length=max_sequence_length,
device=device,
dtype=dtype,
)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
def check_inputs(
self,
prompt,
height,
width,
num_frames,
callback_on_step_end_tensor_inputs=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (num_frames - 1) % self.vae_temporal_scale_factor != 0:
raise ValueError(f"Expected `num_frames - 1` to be divisible by {self.vae_temporal_scale_factor=}")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def prepare_latents(
self,
batch_size,
num_channels_latents,
height,
width,
num_frames,
dtype,
device,
generator,
latents=None,
):
height = height // self.vae_spatial_scale_factor
width = width // self.vae_spatial_scale_factor
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
shape = (batch_size, num_channels_latents, num_frames, height, width)
if latents is not None:
return latents.to(device=device, dtype=dtype)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: Optional[Union[str, List[str]]] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_frames: int = 16,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 4.5,
num_videos_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image. This is set to 1024 by default for the best results.
num_frames (`int`, defaults to 16):
The number of video frames to generate
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, defaults to `4.5`):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_videos_per_prompt (`int`, *optional*, defaults to 1):
The number of videos to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will ge generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
prompt_attention_mask (`torch.Tensor`, *optional*):
Pre-generated attention mask for text embeddings.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not
provided, negative_prompt_embeds will be generated from `negative_prompt` input argument.
negative_prompt_attention_mask (`torch.FloatTensor`, *optional*):
Pre-generated attention mask for negative text embeddings.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.mochi.MochiPipelineOutput`] instead of a plain tuple.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to `256`):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.mochi.MochiPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.mochi.MochiPipelineOutput`] is returned, otherwise a `tuple`
is returned where the first element is a list with the generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
height = height or self.default_height
width = width or self.default_width
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
height=height,
width=width,
num_frames=num_frames,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
self._guidance_scale = guidance_scale
self._interrupt = False
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# 3. Prepare text embeddings
(
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds,
negative_prompt_attention_mask,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
num_videos_per_prompt=num_videos_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_videos_per_prompt,
num_channels_latents,
height,
width,
num_frames,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
threshold_noise = 0.025
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
sigmas = np.array(sigmas)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
timesteps,
sigmas,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
if output_type == "latent":
video = latents
else:
video = self.vae.decode(latents, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return MochiPipelineOutput(frames=video)

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class MochiPipelineOutput(BaseOutput):
r"""
Output class for CogVideo pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
"""
frames: torch.Tensor

View File

@@ -71,6 +71,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
max_shift: Optional[float] = 1.15,
base_image_seq_len: Optional[int] = 256,
max_image_seq_len: Optional[int] = 4096,
invert_sigmas: bool = False,
):
timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)
@@ -204,9 +205,15 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)
timesteps = sigmas * self.config.num_train_timesteps
self.timesteps = timesteps.to(device=device)
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
if self.config.invert_sigmas:
sigmas = 1.0 - sigmas
timesteps = sigmas * self.config.num_train_timesteps
sigmas = torch.cat([sigmas, torch.ones(1, device=sigmas.device)])
else:
sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
self.timesteps = timesteps.to(device=device)
self.sigmas = sigmas
self._step_index = None
self._begin_index = None

View File

@@ -62,6 +62,21 @@ class AutoencoderKLCogVideoX(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLMochi(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
_backends = ["torch"]
@@ -347,6 +362,21 @@ class LuminaNextDiT2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class MochiTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ModelMixin(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1037,6 +1037,21 @@ class MarigoldNormalsPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class MochiPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class MusicLDMPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,80 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import MochiTransformer3DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class MochiTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = MochiTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
batch_size = 2
num_channels = 4
num_frames = 2
height = 16
width = 16
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
}
@property
def input_shape(self):
return (4, 2, 16, 16)
@property
def output_shape(self):
return (4, 2, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"patch_size": 2,
"num_attention_heads": 2,
"attention_head_dim": 8,
"num_layers": 2,
"pooled_projection_dim": 16,
"in_channels": 4,
"out_channels": None,
"qk_norm": "rms_norm",
"text_embed_dim": 16,
"time_embed_dim": 4,
"activation_fn": "swiglu",
"max_sequence_length": 16,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict