Compare commits

...

10 Commits

Author       SHA1         Message   Date
Dhruv Nair   dded24364c   update    2024-11-25 08:50:27 +01:00
Dhruv Nair   3ffa711db1   update    2024-11-25 08:34:33 +01:00
Dhruv Nair   66a5f59ca1   update    2024-11-25 08:25:54 +01:00
Dhruv Nair   1782d0241a   update    2024-11-25 08:11:11 +01:00
Dhruv Nair   fcc59d01a9   update    2024-11-23 17:15:18 +01:00
Dhruv Nair   21b09979dc   update    2024-11-22 13:21:32 +01:00
Dhruv Nair   79380ca719   update    2024-11-20 19:41:08 +01:00
Dhruv Nair   10275feacd   update    2024-11-20 13:57:41 +01:00
Dhruv Nair   30dd9f6845   update    2024-11-18 17:50:51 +01:00
Dhruv Nair   27f81bd54f   update    2024-11-18 17:30:24 +01:00
5 changed files with 283 additions and 107 deletions

View File

@@ -16,6 +16,7 @@ import math
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
+from torch._higher_order_ops.flex_attention import sdpa_dense
 import torch.nn.functional as F
 from torch import nn
@@ -3554,11 +3555,11 @@ class MochiAttnProcessor2_0:
         if image_rotary_emb is not None:
 
             def apply_rotary_emb(x, freqs_cos, freqs_sin):
-                x_even = x[..., 0::2].float()
-                x_odd = x[..., 1::2].float()
+                x_even = x[..., 0::2]
+                x_odd = x[..., 1::2]
 
-                cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
-                sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
+                cos = (x_even * freqs_cos.float() - x_odd * freqs_sin.float()).to(x.dtype)
+                sin = (x_even * freqs_sin.float() + x_odd * freqs_cos.float()).to(x.dtype)
 
                 return torch.stack([cos, sin], dim=-1).flatten(-2)
@@ -3572,16 +3573,39 @@ class MochiAttnProcessor2_0:
             encoder_value.transpose(1, 2),
         )
 
-        sequence_length = query.size(2)
-        encoder_sequence_length = encoder_query.size(2)
+        batch_size, heads, sequence_length, dim = query.shape
+        encoder_sequence_length = encoder_query.shape[2]
+        total_length = sequence_length + encoder_sequence_length
 
         query = torch.cat([query, encoder_query], dim=2)
         key = torch.cat([key, encoder_key], dim=2)
         value = torch.cat([value, encoder_value], dim=2)
 
-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
+        # Zero out tokens based on the attention mask
+        # query = query * attention_mask[:, None, :, None]
+        # key = key * attention_mask[:, None, :, None]
+        # value = value * attention_mask[:, None, :, None]
+
+        query = query.view(1, query.size(1), -1, query.size(-1))
+        key = key.view(1, key.size(1), -1, key.size(-1))
+        value = value.view(1, value.size(1), -1, key.size(-1))
+
+        select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        query = torch.index_select(query, 2, select_index)
+        key = torch.index_select(key, 2, select_index)
+        value = torch.index_select(value, 2, select_index)
+
+        from torch.nn.attention import SDPBackend, sdpa_kernel
+
+        with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0)
+
+        output = torch.zeros(
+            batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype
+        )
+        output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states)
+        hidden_states = output.view(batch_size, total_length, dim * heads)
 
         hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
             (sequence_length, encoder_sequence_length), dim=1
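
Note on the hunk above: it replaces the single padded scaled_dot_product_attention call with a gather/compute/scatter scheme: valid tokens are picked out with index_select, attention runs only over that packed sequence, and the outputs are scattered back into the padded layout. Below is a minimal, self-contained sketch of the same idea with toy sizes; none of the names or shapes come from the model, and it is an illustration rather than the exact processor code.

import torch
import torch.nn.functional as F

# Toy sizes: 1 sequence of 6 tokens (the last 2 are padding), 4 heads, head dim 8.
batch, heads, seq_len, dim = 1, 4, 6, 8
q, k, v = [torch.randn(batch, heads, seq_len, dim) for _ in range(3)]

# Boolean mask marking the valid (non-padded) tokens.
mask = torch.tensor([[True, True, True, True, False, False]])

# Indices of the valid tokens after flattening batch and sequence together.
select = torch.nonzero(mask.flatten(), as_tuple=False).flatten()
q_p = torch.index_select(q.view(1, heads, -1, dim), 2, select)
k_p = torch.index_select(k.view(1, heads, -1, dim), 2, select)
v_p = torch.index_select(v.view(1, heads, -1, dim), 2, select)

# Attention runs only over the packed (unpadded) tokens.
out = F.scaled_dot_product_attention(q_p, k_p, v_p, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).flatten(2, 3).squeeze(0)  # (num_valid, heads * dim)

# Scatter the outputs back into the padded layout; padded rows stay zero.
full = torch.zeros(batch * seq_len, heads * dim, dtype=out.dtype)
full.scatter_(0, select.unsqueeze(1).expand(-1, heads * dim), out)
full = full.view(batch, seq_len, heads * dim)
print(full.shape)  # torch.Size([1, 6, 32])

One caveat worth keeping in mind: flattening batch and sequence together before the gather lets tokens from different batch entries attend to one another when the batch size is larger than one; the sketch sidesteps that by using a single sequence.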

View File

@@ -262,7 +262,6 @@ class PatchEmbed(nn.Module):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
         latent = self.proj(latent)
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC

View File

@@ -256,7 +256,9 @@ class MochiRMSNormZero(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
-        hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
+        scale_msa = scale_msa.float()
+        _hidden_states = self.norm(hidden_states).float() * (1 + scale_msa[:, None])
+        hidden_states = _hidden_states.to(hidden_states.dtype)
 
         return hidden_states, gate_msa, scale_mlp, gate_mlp
@@ -530,15 +532,14 @@ class RMSNorm(nn.Module):
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
         variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
-        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+        hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps)
 
         if self.weight is not None:
             # convert into half-precision if necessary
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
 
             hidden_states = hidden_states * self.weight
-        else:
-            hidden_states = hidden_states.to(input_dtype)
+        hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
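
Both normalization hunks above follow the same pattern: do the rsqrt multiply and the (1 + scale) modulation in float32 and cast back to the activation dtype only once, at the end. A small toy illustration of the rounding gap this avoids when activations are bfloat16 (the values and shapes here are invented, not taken from the diff):

import torch

torch.manual_seed(0)
dtype = torch.bfloat16
eps = 1e-5

hidden = torch.randn(2, 4, 8, dtype=dtype)
scale = torch.randn(2, 8, dtype=dtype) * 5  # AdaLN-style modulation, exaggerated for effect

def rms_norm_fp32(h):
    variance = h.float().pow(2).mean(-1, keepdim=True)
    return h.float() * torch.rsqrt(variance + eps)

normed = rms_norm_fp32(hidden)

# Modulate in bf16: the normalized values and the scale are rounded before the multiply.
low_precision = normed.to(dtype) * (1 + scale.unsqueeze(1))

# Modulate in fp32 and cast once at the end, as the patched code does.
high_precision = (normed * (1 + scale.unsqueeze(1).float())).to(dtype)

print((low_precision.float() - high_precision.float()).abs().max())  # usually a nonzero rounding gap

The per-element gap is tiny, but it compounds across transformer blocks and denoising steps, which is presumably the drift these dtype changes are chasing.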

View File

@@ -13,10 +13,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from operator import ipow
 from typing import Any, Dict, Optional, Tuple
 
 import torch
+from torch._prims_common import is_low_precision_dtype
 import torch.nn as nn
+from transformers.tokenization_utils_base import import_protobuf_decode_error
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
@@ -26,10 +29,110 @@ from ..attention_processor import Attention, MochiAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
+from ..normalization import (
+    AdaLayerNormContinuous,
+    LuminaLayerNormContinuous,
+)
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class FP32ModulatedRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+    def forward(self, hidden_states, scale=None):
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps)
+
+        if scale is not None:
+            hidden_states = hidden_states * scale
+
+        return hidden_states
+
+
+class MochiLayerNormContinuous(nn.Module):
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        # NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
+        # because the output is immediately scaled and shifted by the projected conditioning embeddings.
+        # Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
+        # However, this is how it was implemented in the original code, and it's rather likely you should
+        # set `elementwise_affine` to False.
+        elementwise_affine=True,
+        eps=1e-5,
+        bias=True,
+        norm_type="layer_norm",
+        out_dim: Optional[int] = None,
+    ):
+        super().__init__()
+
+        # AdaLN
+        self.silu = nn.SiLU()
+        self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+        self.linear_2 = None
+        if out_dim is not None:
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        conditioning_embedding: torch.Tensor,
+    ) -> torch.Tensor:
+        output_dtype = x.dtype
+
+        # convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
+        emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
+        scale = emb
+        x = self.norm(x, (1 + scale.unsqueeze(1).float()))
+
+        if self.linear_2 is not None:
+            x = self.linear_2(x)
+
+        return x.to(output_dtype)
+
+
+class MochiRMSNormZero(nn.Module):
+    r"""
+    Adaptive RMS Norm used in Mochi.
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+    """
+
+    def __init__(
+        self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
+    ) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, hidden_dim)
+        self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
+
+    def forward(
+        self, hidden_states: torch.Tensor, emb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        hidden_states_dtype = hidden_states.dtype
+
+        emb = self.linear(self.silu(emb))
+        scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
+        hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].float()))
+        hidden_states = hidden_states.to(hidden_states_dtype)
+
+        return hidden_states, gate_msa, scale_mlp, gate_mlp
+
+
 @maybe_allow_in_graph
@@ -76,7 +179,7 @@ class MochiTransformerBlock(nn.Module):
         if not context_pre_only:
             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
         else:
-            self.norm1_context = LuminaLayerNormContinuous(
+            self.norm1_context = MochiLayerNormContinuous(
                 embedding_dim=pooled_projection_dim,
                 conditioning_embedding_dim=dim,
                 eps=eps,
@@ -98,16 +201,16 @@ class MochiTransformerBlock(nn.Module):
             out_context_dim=pooled_projection_dim,
             context_pre_only=context_pre_only,
             processor=MochiAttnProcessor2_0(),
-            eps=eps,
+            eps=1e-5,
             elementwise_affine=True,
         )
 
         # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
-        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
-        self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm2 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm2_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm3 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm3_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
@@ -119,8 +222,8 @@ class MochiTransformerBlock(nn.Module):
                 bias=False,
             )
 
-        self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm4 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm4_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
     def forward(
         self,
@@ -128,6 +231,7 @@ class MochiTransformerBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        joint_attention_mask=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
@@ -142,22 +246,29 @@ class MochiTransformerBlock(nn.Module):
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=joint_attention_mask,
         )
 
-        hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
-        norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
+        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)).to(
+            hidden_states.dtype
+        )
+        norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())).to(hidden_states.dtype)
         ff_output = self.ff(norm_hidden_states)
-        hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
+        hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)).to(
+            hidden_states.dtype
+        )
 
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + self.norm2_context(
-                context_attn_hidden_states
-            ) * torch.tanh(enc_gate_msa).unsqueeze(1)
-            norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
+                context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
+            ).to(encoder_hidden_states.dtype)
+            norm_encoder_hidden_states = self.norm3_context(
+                encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float())
+            ).to(encoder_hidden_states.dtype)
             context_ff_output = self.ff_context(norm_encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
-                enc_gate_mlp
-            ).unsqueeze(1)
+            encoder_hidden_states = encoder_hidden_states + self.norm4_context(
+                context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
+            ).to(encoder_hidden_states.dtype)
 
         return hidden_states, encoder_hidden_states
@@ -202,7 +313,8 @@ class MochiRoPE(nn.Module):
         return positions
 
     def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
-        freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
+        with torch.autocast("cuda", enabled=False):
+            freqs = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs)
         freqs_cos = torch.cos(freqs)
         freqs_sin = torch.sin(freqs)
         return freqs_cos, freqs_sin
@@ -308,7 +420,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         )
 
         self.norm_out = AdaLayerNormContinuous(
-            inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
+            inner_dim,
+            inner_dim,
+            elementwise_affine=False,
+            eps=1e-6,
+            norm_type="layer_norm",
         )
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
@@ -324,6 +440,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
+        joint_attention_mask=None,
         return_dict: bool = True,
     ) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
@@ -333,7 +450,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         post_patch_width = width // p
 
         temb, encoder_hidden_states = self.time_embed(
-            timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
+            timestep,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            hidden_dtype=hidden_states.dtype,
         )
 
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
@@ -373,8 +493,8 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_mask=joint_attention_mask,
                 )
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
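
One small change above wraps the RoPE frequency einsum in torch.autocast("cuda", enabled=False) so the angles and the following cos/sin stay in the dtype of the frequency table even when the surrounding forward pass runs under autocast. A standalone sketch of that pattern, with arbitrary toy shapes rather than the real Mochi dimensions:

import torch

def create_rope(freqs: torch.Tensor, pos: torch.Tensor):
    # Disable any ambient autocast so the einsum and cos/sin run in the dtype of
    # `freqs` (float32 here) instead of being downcast to fp16/bf16.
    with torch.autocast("cuda", enabled=False):
        angles = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs)
        return torch.cos(angles), torch.sin(angles)

pos = torch.randn(16, 3)       # 16 tokens, 3 positional axes (frame, height, width)
freqs = torch.randn(3, 8, 4)   # per-axis frequencies: 8 heads, 4 frequency pairs
cos, sin = create_rope(freqs, pos)
print(cos.shape, sin.shape)    # torch.Size([16, 8, 4]) twice

cos/sin of large, position-dependent angles lose precision quickly in half precision, so keeping this step in float32 is cheap insurance.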

View File

@@ -17,10 +17,11 @@ from typing import Callable, Dict, List, Optional, Union
 import numpy as np
 import torch
+import torch.nn.functional as F
 from transformers import T5EncoderModel, T5TokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...models.autoencoders import AutoencoderKL
+from ...models.autoencoders import AutoencoderKLMochi
 from ...models.transformers import MochiTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -55,7 +56,7 @@ EXAMPLE_DOC_STRING = """
         >>> pipe.enable_model_cpu_offload()
         >>> pipe.enable_vae_tiling()
         >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
-        >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
+        >>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0]
         >>> export_to_video(frames, "mochi.mp4")
         ```
 """
@@ -163,8 +164,8 @@ class MochiPipeline(DiffusionPipeline):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
-       vae ([`AutoencoderKL`]):
-           Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+       vae ([`AutoencoderKLMochi`]):
+           Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
@@ -183,7 +184,7 @@ class MochiPipeline(DiffusionPipeline):
     def __init__(
         self,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
+        vae: AutoencoderKLMochi,
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,
@@ -197,17 +198,11 @@ class MochiPipeline(DiffusionPipeline):
             transformer=transformer,
             scheduler=scheduler,
         )
-        # TODO: determine these scaling factors from model parameters
-        self.vae_spatial_scale_factor = 8
-        self.vae_temporal_scale_factor = 6
-        self.patch_size = 2
 
-        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
-        self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
-        )
-        self.default_height = 480
-        self.default_width = 848
+        self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8
+        self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
     # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -245,7 +240,7 @@ class MochiPipeline(DiffusionPipeline):
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -340,7 +335,12 @@ class MochiPipeline(DiffusionPipeline):
             dtype=dtype,
         )
 
-        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+        return (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        )
 
     def check_inputs(
         self,
@@ -424,6 +424,13 @@ class MochiPipeline(DiffusionPipeline):
         """
         self.vae.disable_tiling()
 
+    def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
+        batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
+        num_latents = latent_frames * latent_height * latent_width
+        num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
+        mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
+        return mask
+
     def prepare_latents(
         self,
         batch_size,
@@ -436,9 +443,9 @@ class MochiPipeline(DiffusionPipeline):
         generator,
         latents=None,
     ):
-        height = height // self.vae_spatial_scale_factor
-        width = width // self.vae_spatial_scale_factor
-        num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
+        height = height // self.vae_scale_factor_spatial
+        width = width // self.vae_scale_factor_spatial
+        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
 
         shape = (batch_size, num_channels_latents, num_frames, height, width)
@@ -478,7 +485,7 @@ class MochiPipeline(DiffusionPipeline):
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_frames: int = 19,
-        num_inference_steps: int = 28,
+        num_inference_steps: int = 50,
         timesteps: List[int] = None,
         guidance_scale: float = 4.5,
         num_videos_per_prompt: Optional[int] = 1,
@@ -501,13 +508,13 @@
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            height (`int`, *optional*, defaults to `self.default_height`):
+            height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`):
                 The height in pixels of the generated image. This is set to 480 by default for the best results.
-            width (`int`, *optional*, defaults to `self.default_width`):
+            width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`):
                 The width in pixels of the generated image. This is set to 848 by default for the best results.
             num_frames (`int`, defaults to `19`):
                 The number of video frames to generate
-            num_inference_steps (`int`, *optional*, defaults to 50):
+            num_inference_steps (`int`, *optional*, defaults to `50`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
             timesteps (`List[int]`, *optional*):
@@ -567,8 +574,8 @@ class MochiPipeline(DiffusionPipeline):
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.default_height
-        width = width or self.default_width
+        height = height or 480  # self.transformer.config.sample_height * self.vae_scaling_factor_spatial
+        width = width or 848  # self.transformer.config.sample_width * self.vae_scaling_factor_spatial
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -594,28 +601,28 @@
         batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
 
+        with torch.autocast("cuda", torch.float32):
         # 3. Prepare text embeddings
         (
             prompt_embeds,
             prompt_attention_mask,
             negative_prompt_embeds,
             negative_prompt_attention_mask,
         ) = self.encode_prompt(
             prompt=prompt,
             negative_prompt=negative_prompt,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
             num_videos_per_prompt=num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
             prompt_attention_mask=prompt_attention_mask,
             negative_prompt_attention_mask=negative_prompt_attention_mask,
             max_sequence_length=max_sequence_length,
             device=device,
         )
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        # if self.do_classifier_free_guidance:
+        #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        #     prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
@@ -637,6 +644,9 @@
         sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
         sigmas = np.array(sigmas)
 
+        joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
+        negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
+
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
@@ -653,25 +663,41 @@
                 if self.interrupt:
                     continue
 
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                # timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                latent_model_input = latents
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
 
-                noise_pred = self.transformer(
+                noise_pred_text = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     encoder_attention_mask=prompt_attention_mask,
+                    joint_attention_mask=joint_attention_mask,
                     return_dict=False,
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=negative_prompt_attention_mask,
+                        joint_attention_mask=negative_joint_attention_mask,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = noise_pred_uncond.float() + self.guidance_scale * (
+                        noise_pred_text.float() - noise_pred_uncond.float()
+                    )
+                else:
+                    noise_pred = noise_pred_text
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
-                latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
+                latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0]
+                latents = latents.to(latents_dtype)
 
                 if latents.dtype != latents_dtype:
                     if torch.backends.mps.is_available():
@@ -693,27 +719,33 @@
             if XLA_AVAILABLE:
                 xm.mark_step()
 
         if output_type == "latent":
             video = latents
         else:
-            # unscale/denormalize the latents
-            # denormalize with the mean and std if available and not None
-            has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
-            has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
-            if has_latents_mean and has_latents_std:
-                latents_mean = (
-                    torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
-                )
-                latents_std = (
-                    torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
-                )
-                latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
-            else:
-                latents = latents / self.vae.config.scaling_factor
+            with torch.autocast("cuda", torch.float32):
+                # unscale/denormalize the latents
+                # denormalize with the mean and std if available and not None
+                has_latents_mean = (
+                    hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
+                )
+                has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
+                if has_latents_mean and has_latents_std:
+                    latents_mean = (
+                        torch.tensor(self.vae.config.latents_mean)
+                        .view(1, 12, 1, 1, 1)
+                        .to(latents.device, latents.dtype)
+                    )
+                    latents_std = (
+                        torch.tensor(self.vae.config.latents_std)
+                        .view(1, 12, 1, 1, 1)
+                        .to(latents.device, latents.dtype)
+                    )
+                    latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
+                else:
+                    latents = latents / self.vae.config.scaling_factor
 
             video = self.vae.decode(latents, return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
 
         # Offload all models
         self.maybe_free_model_hooks()
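
For reference, the new prepare_joint_attention_mask helper builds one mask over the concatenated visual-plus-text token sequence: every latent patch token is marked valid and the text portion reuses the prompt padding mask. A rough sketch of the shape arithmetic with toy sizes (not the real model configuration):

import torch
import torch.nn.functional as F

# Toy latent video: batch 1, 12 channels, 3 latent frames, 8x8 latent grid.
latents = torch.randn(1, 12, 3, 8, 8)
patch_size = 2

# Prompt mask: 6 text tokens, the last 2 are padding.
prompt_attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]], dtype=torch.bool)

batch, _, frames, height, width = latents.shape
num_latents = frames * height * width                 # 3 * 8 * 8 = 192 latent positions
num_visual_tokens = num_latents // (patch_size**2)    # 192 // 4 = 48 patch tokens

# Visual tokens come first in the joint sequence and are always attended to;
# the text tail keeps its original padding mask.
joint_mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
print(joint_mask.shape)  # torch.Size([1, 54]): 48 visual + 6 text tokens

This is the mask that ends up as attention_mask in MochiAttnProcessor2_0, where it drives the index_select packing shown in the first file of this diff.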