mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-08 21:44:27 +08:00
Compare commits
10 Commits
attn-refac
...
mochi-tran
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
dded24364c | ||
|
|
3ffa711db1 | ||
|
|
66a5f59ca1 | ||
|
|
1782d0241a | ||
|
|
fcc59d01a9 | ||
|
|
21b09979dc | ||
|
|
79380ca719 | ||
|
|
10275feacd | ||
|
|
30dd9f6845 | ||
|
|
27f81bd54f |
@@ -16,6 +16,7 @@ import math
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch._higher_order_ops.flex_attention import sdpa_dense
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
@@ -3554,11 +3555,11 @@ class MochiAttnProcessor2_0:
|
||||
if image_rotary_emb is not None:
|
||||
|
||||
def apply_rotary_emb(x, freqs_cos, freqs_sin):
|
||||
x_even = x[..., 0::2].float()
|
||||
x_odd = x[..., 1::2].float()
|
||||
x_even = x[..., 0::2]
|
||||
x_odd = x[..., 1::2]
|
||||
|
||||
cos = (x_even * freqs_cos - x_odd * freqs_sin).to(x.dtype)
|
||||
sin = (x_even * freqs_sin + x_odd * freqs_cos).to(x.dtype)
|
||||
cos = (x_even * freqs_cos.float() - x_odd * freqs_sin.float()).to(x.dtype)
|
||||
sin = (x_even * freqs_sin.float() + x_odd * freqs_cos.float()).to(x.dtype)
|
||||
|
||||
return torch.stack([cos, sin], dim=-1).flatten(-2)
|
||||
|
||||
@@ -3572,16 +3573,39 @@ class MochiAttnProcessor2_0:
|
||||
encoder_value.transpose(1, 2),
|
||||
)
|
||||
|
||||
sequence_length = query.size(2)
|
||||
encoder_sequence_length = encoder_query.size(2)
|
||||
batch_size, heads, sequence_length, dim = query.shape
|
||||
encoder_sequence_length = encoder_query.shape[2]
|
||||
total_length = sequence_length + encoder_sequence_length
|
||||
|
||||
query = torch.cat([query, encoder_query], dim=2)
|
||||
key = torch.cat([key, encoder_key], dim=2)
|
||||
value = torch.cat([value, encoder_value], dim=2)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
# Zero out tokens based on the attention mask
|
||||
# query = query * attention_mask[:, None, :, None]
|
||||
# key = key * attention_mask[:, None, :, None]
|
||||
# value = value * attention_mask[:, None, :, None]
|
||||
|
||||
query = query.view(1, query.size(1), -1, query.size(-1))
|
||||
key = key.view(1, key.size(1), -1, key.size(-1))
|
||||
value = value.view(1, value.size(1), -1, key.size(-1))
|
||||
|
||||
select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||
|
||||
query = torch.index_select(query, 2, select_index)
|
||||
key = torch.index_select(key, 2, select_index)
|
||||
value = torch.index_select(value, 2, select_index)
|
||||
|
||||
from torch.nn.attention import SDPBackend, sdpa_kernel
|
||||
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
|
||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0)
|
||||
output = torch.zeros(
|
||||
batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states)
|
||||
hidden_states = output.view(batch_size, total_length, dim * heads)
|
||||
|
||||
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
||||
(sequence_length, encoder_sequence_length), dim=1
|
||||
|
||||
@@ -262,7 +262,6 @@ class PatchEmbed(nn.Module):
|
||||
height, width = latent.shape[-2:]
|
||||
else:
|
||||
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
||||
|
||||
latent = self.proj(latent)
|
||||
if self.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
|
||||
@@ -256,7 +256,9 @@ class MochiRMSNormZero(nn.Module):
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
|
||||
scale_msa = scale_msa.float()
|
||||
_hidden_states = self.norm(hidden_states).float() * (1 + scale_msa[:, None])
|
||||
hidden_states = _hidden_states.to(hidden_states.dtype)
|
||||
|
||||
return hidden_states, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
@@ -530,15 +532,14 @@ class RMSNorm(nn.Module):
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps)
|
||||
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states * self.weight
|
||||
else:
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -13,10 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from operator import ipow
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch._prims_common import is_low_precision_dtype
|
||||
import torch.nn as nn
|
||||
from transformers.tokenization_utils_base import import_protobuf_decode_error
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import is_torch_version, logging
|
||||
@@ -26,10 +29,110 @@ from ..attention_processor import Attention, MochiAttnProcessor2_0
|
||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
|
||||
from ..normalization import (
|
||||
AdaLayerNormContinuous,
|
||||
LuminaLayerNormContinuous,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-n
|
||||
|
||||
|
||||
class FP32ModulatedRMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
|
||||
super().__init__()
|
||||
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, hidden_states, scale=None):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states.float() * torch.rsqrt(variance + self.eps)
|
||||
|
||||
if scale is not None:
|
||||
hidden_states = hidden_states * scale
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MochiLayerNormContinuous(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
# NOTE: It is a bit weird that the norm layer can be configured to have scale and shift parameters
|
||||
# because the output is immediately scaled and shifted by the projected conditioning embeddings.
|
||||
# Note that AdaLayerNorm does not let the norm layer have scale and shift parameters.
|
||||
# However, this is how it was implemented in the original code, and it's rather likely you should
|
||||
# set `elementwise_affine` to False.
|
||||
elementwise_affine=True,
|
||||
eps=1e-5,
|
||||
bias=True,
|
||||
norm_type="layer_norm",
|
||||
out_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# AdaLN
|
||||
self.silu = nn.SiLU()
|
||||
self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
|
||||
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
self.linear_2 = None
|
||||
if out_dim is not None:
|
||||
self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
conditioning_embedding: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
output_dtype = x.dtype
|
||||
# convert back to the original dtype in case `conditioning_embedding`` is upcasted to float32 (needed for hunyuanDiT)
|
||||
emb = self.linear_1(self.silu(conditioning_embedding).to(x.dtype))
|
||||
scale = emb
|
||||
x = self.norm(x, (1 + scale.unsqueeze(1).float()))
|
||||
|
||||
if self.linear_2 is not None:
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x.to(output_dtype)
|
||||
|
||||
|
||||
class MochiRMSNormZero(nn.Module):
|
||||
r"""
|
||||
Adaptive RMS Norm used in Mochi.
|
||||
|
||||
Parameters:
|
||||
embedding_dim (`int`): The size of each embedding vector.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, embedding_dim: int, hidden_dim: int, eps: float = 1e-5, elementwise_affine: bool = False
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, hidden_dim)
|
||||
self.norm = FP32ModulatedRMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, emb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
hidden_states_dtype = hidden_states.dtype
|
||||
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||
|
||||
hidden_states = self.norm(hidden_states, (1 + scale_msa[:, None].float()))
|
||||
hidden_states = hidden_states.to(hidden_states_dtype)
|
||||
|
||||
return hidden_states, gate_msa, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
@@ -76,7 +179,7 @@ class MochiTransformerBlock(nn.Module):
|
||||
if not context_pre_only:
|
||||
self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
else:
|
||||
self.norm1_context = LuminaLayerNormContinuous(
|
||||
self.norm1_context = MochiLayerNormContinuous(
|
||||
embedding_dim=pooled_projection_dim,
|
||||
conditioning_embedding_dim=dim,
|
||||
eps=eps,
|
||||
@@ -98,16 +201,16 @@ class MochiTransformerBlock(nn.Module):
|
||||
out_context_dim=pooled_projection_dim,
|
||||
context_pre_only=context_pre_only,
|
||||
processor=MochiAttnProcessor2_0(),
|
||||
eps=eps,
|
||||
eps=1e-5,
|
||||
elementwise_affine=True,
|
||||
)
|
||||
|
||||
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
|
||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm2_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm3_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
|
||||
self.ff_context = None
|
||||
@@ -119,8 +222,8 @@ class MochiTransformerBlock(nn.Module):
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
self.norm4 = FP32ModulatedRMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.norm4_context = FP32ModulatedRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -128,6 +231,7 @@ class MochiTransformerBlock(nn.Module):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
joint_attention_mask=None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||
|
||||
@@ -142,22 +246,29 @@ class MochiTransformerBlock(nn.Module):
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=joint_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
|
||||
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
|
||||
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1)).to(
|
||||
hidden_states.dtype
|
||||
)
|
||||
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float())).to(hidden_states.dtype)
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
|
||||
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1)).to(
|
||||
hidden_states.dtype
|
||||
)
|
||||
|
||||
if not self.context_pre_only:
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||
context_attn_hidden_states
|
||||
) * torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
|
||||
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||
).to(encoder_hidden_states.dtype)
|
||||
norm_encoder_hidden_states = self.norm3_context(
|
||||
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float())
|
||||
).to(encoder_hidden_states.dtype)
|
||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
|
||||
enc_gate_mlp
|
||||
).unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
|
||||
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
|
||||
).to(encoder_hidden_states.dtype)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
@@ -202,7 +313,8 @@ class MochiRoPE(nn.Module):
|
||||
return positions
|
||||
|
||||
def _create_rope(self, freqs: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
|
||||
freqs = torch.einsum("nd,dhf->nhf", pos, freqs.float())
|
||||
with torch.autocast("cuda", enabled=False):
|
||||
freqs = torch.einsum("nd,dhf->nhf", pos.to(freqs), freqs)
|
||||
freqs_cos = torch.cos(freqs)
|
||||
freqs_sin = torch.sin(freqs)
|
||||
return freqs_cos, freqs_sin
|
||||
@@ -308,7 +420,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
self.norm_out = AdaLayerNormContinuous(
|
||||
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
|
||||
inner_dim,
|
||||
inner_dim,
|
||||
elementwise_affine=False,
|
||||
eps=1e-6,
|
||||
norm_type="layer_norm",
|
||||
)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||
|
||||
@@ -324,6 +440,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_attention_mask: torch.Tensor,
|
||||
joint_attention_mask=None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
@@ -333,7 +450,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
post_patch_width = width // p
|
||||
|
||||
temb, encoder_hidden_states = self.time_embed(
|
||||
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
|
||||
timestep,
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
@@ -373,8 +493,8 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_mask=joint_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
|
||||
@@ -17,10 +17,11 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...models.autoencoders import AutoencoderKL
|
||||
from ...models.autoencoders import AutoencoderKLMochi
|
||||
from ...models.transformers import MochiTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
@@ -55,7 +56,7 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
>>> pipe.enable_vae_tiling()
|
||||
>>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||
>>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
|
||||
>>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0]
|
||||
>>> export_to_video(frames, "mochi.mp4")
|
||||
```
|
||||
"""
|
||||
@@ -163,8 +164,8 @@ class MochiPipeline(DiffusionPipeline):
|
||||
Conditional Transformer architecture to denoise the encoded video latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
vae ([`AutoencoderKLMochi`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||
@@ -183,7 +184,7 @@ class MochiPipeline(DiffusionPipeline):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
vae: AutoencoderKLMochi,
|
||||
text_encoder: T5EncoderModel,
|
||||
tokenizer: T5TokenizerFast,
|
||||
transformer: MochiTransformer3DModel,
|
||||
@@ -197,17 +198,11 @@ class MochiPipeline(DiffusionPipeline):
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
# TODO: determine these scaling factors from model parameters
|
||||
self.vae_spatial_scale_factor = 8
|
||||
self.vae_temporal_scale_factor = 6
|
||||
self.patch_size = 2
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
|
||||
self.tokenizer_max_length = (
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_height = 480
|
||||
self.default_width = 848
|
||||
self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8
|
||||
self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6
|
||||
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||
|
||||
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
||||
def _get_t5_prompt_embeds(
|
||||
@@ -245,7 +240,7 @@ class MochiPipeline(DiffusionPipeline):
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
@@ -340,7 +335,12 @@ class MochiPipeline(DiffusionPipeline):
|
||||
dtype=dtype,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
return (
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
)
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
@@ -424,6 +424,13 @@ class MochiPipeline(DiffusionPipeline):
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
|
||||
batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
|
||||
num_latents = latent_frames * latent_height * latent_width
|
||||
num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
|
||||
mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
|
||||
return mask
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
@@ -436,9 +443,9 @@ class MochiPipeline(DiffusionPipeline):
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = height // self.vae_spatial_scale_factor
|
||||
width = width // self.vae_spatial_scale_factor
|
||||
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
|
||||
height = height // self.vae_scale_factor_spatial
|
||||
width = width // self.vae_scale_factor_spatial
|
||||
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||
|
||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||
|
||||
@@ -478,7 +485,7 @@ class MochiPipeline(DiffusionPipeline):
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_frames: int = 19,
|
||||
num_inference_steps: int = 28,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.5,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
@@ -501,13 +508,13 @@ class MochiPipeline(DiffusionPipeline):
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
height (`int`, *optional*, defaults to `self.default_height`):
|
||||
height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`):
|
||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||
width (`int`, *optional*, defaults to `self.default_width`):
|
||||
width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`):
|
||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||
num_frames (`int`, defaults to `19`):
|
||||
The number of video frames to generate
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
num_inference_steps (`int`, *optional*, defaults to `50`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
@@ -567,8 +574,8 @@ class MochiPipeline(DiffusionPipeline):
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
height = height or self.default_height
|
||||
width = width or self.default_width
|
||||
height = height or 480 # self.transformer.config.sample_height * self.vae_scaling_factor_spatial
|
||||
width = width or 848 # self.transformer.config.sample_width * self.vae_scaling_factor_spatial
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
@@ -594,28 +601,28 @@ class MochiPipeline(DiffusionPipeline):
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
if self.do_classifier_free_guidance:
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
with torch.autocast("cuda", torch.float32):
|
||||
# 3. Prepare text embeddings
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds,
|
||||
negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
num_videos_per_prompt=num_videos_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
# if self.do_classifier_free_guidance:
|
||||
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
# prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
@@ -637,6 +644,9 @@ class MochiPipeline(DiffusionPipeline):
|
||||
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
|
||||
sigmas = np.array(sigmas)
|
||||
|
||||
joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
|
||||
negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
|
||||
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
@@ -653,25 +663,41 @@ class MochiPipeline(DiffusionPipeline):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
# latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||
# # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
# timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
|
||||
latent_model_input = latents
|
||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||
|
||||
noise_pred = self.transformer(
|
||||
noise_pred_text = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=prompt_attention_mask,
|
||||
joint_attention_mask=joint_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
timestep=timestep,
|
||||
encoder_attention_mask=negative_prompt_attention_mask,
|
||||
joint_attention_mask=negative_joint_attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_pred_uncond.float() + self.guidance_scale * (
|
||||
noise_pred_text.float() - noise_pred_uncond.float()
|
||||
)
|
||||
else:
|
||||
noise_pred = noise_pred_text
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents_dtype = latents.dtype
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
latents = self.scheduler.step(noise_pred, t, latents.float(), return_dict=False)[0]
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
@@ -693,27 +719,33 @@ class MochiPipeline(DiffusionPipeline):
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if output_type == "latent":
|
||||
video = latents
|
||||
else:
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
with torch.autocast("cuda", torch.float32):
|
||||
# unscale/denormalize the latents
|
||||
# denormalize with the mean and std if available and not None
|
||||
has_latents_mean = (
|
||||
hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
|
||||
if has_latents_mean and has_latents_std:
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, 12, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, 12, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
|
||||
else:
|
||||
latents = latents / self.vae.config.scaling_factor
|
||||
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
Reference in New Issue
Block a user