mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-15 08:54:20 +08:00
Compare commits
6 Commits
cache-docs
...
mochi-drop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fcc59d01a9 | ||
|
|
21b09979dc | ||
|
|
79380ca719 | ||
|
|
10275feacd | ||
|
|
30dd9f6845 | ||
|
|
27f81bd54f |
@@ -3572,16 +3572,36 @@ class MochiAttnProcessor2_0:
|
|||||||
encoder_value.transpose(1, 2),
|
encoder_value.transpose(1, 2),
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_length = query.size(2)
|
batch_size, heads, sequence_length, dim = query.shape
|
||||||
encoder_sequence_length = encoder_query.size(2)
|
encoder_sequence_length = encoder_query.shape[2]
|
||||||
|
total_length = sequence_length + encoder_sequence_length
|
||||||
|
|
||||||
query = torch.cat([query, encoder_query], dim=2)
|
query = torch.cat([query, encoder_query], dim=2)
|
||||||
key = torch.cat([key, encoder_key], dim=2)
|
key = torch.cat([key, encoder_key], dim=2)
|
||||||
value = torch.cat([value, encoder_value], dim=2)
|
value = torch.cat([value, encoder_value], dim=2)
|
||||||
|
|
||||||
|
# Zero out tokens based on the attention mask
|
||||||
|
# query = query * attention_mask[:, None, :, None]
|
||||||
|
# key = key * attention_mask[:, None, :, None]
|
||||||
|
# value = value * attention_mask[:, None, :, None]
|
||||||
|
|
||||||
|
query = query.view(1, query.size(1), -1, query.size(-1))
|
||||||
|
key = key.view(1, key.size(1), -1, key.size(-1))
|
||||||
|
value = value.view(1, value.size(1), -1, key.size(-1))
|
||||||
|
|
||||||
|
select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
|
||||||
|
|
||||||
|
query = torch.index_select(query, 2, select_index)
|
||||||
|
key = torch.index_select(key, 2, select_index)
|
||||||
|
value = torch.index_select(value, 2, select_index)
|
||||||
|
|
||||||
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
|
||||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0)
|
||||||
hidden_states = hidden_states.to(query.dtype)
|
output = torch.zeros(
|
||||||
|
batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype
|
||||||
|
)
|
||||||
|
output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states)
|
||||||
|
hidden_states = output.view(batch_size, total_length, dim * heads)
|
||||||
|
|
||||||
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
|
||||||
(sequence_length, encoder_sequence_length), dim=1
|
(sequence_length, encoder_sequence_length), dim=1
|
||||||
|
|||||||
@@ -262,7 +262,6 @@ class PatchEmbed(nn.Module):
|
|||||||
height, width = latent.shape[-2:]
|
height, width = latent.shape[-2:]
|
||||||
else:
|
else:
|
||||||
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
|
||||||
|
|
||||||
latent = self.proj(latent)
|
latent = self.proj(latent)
|
||||||
if self.flatten:
|
if self.flatten:
|
||||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||||
|
|||||||
@@ -256,7 +256,9 @@ class MochiRMSNormZero(nn.Module):
|
|||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
emb = self.linear(self.silu(emb))
|
emb = self.linear(self.silu(emb))
|
||||||
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
|
||||||
hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
|
scale_msa = scale_msa.float()
|
||||||
|
_hidden_states = self.norm(hidden_states).float() * (1 + scale_msa[:, None])
|
||||||
|
hidden_states = _hidden_states.to(hidden_states.dtype)
|
||||||
|
|
||||||
return hidden_states, gate_msa, scale_mlp, gate_mlp
|
return hidden_states, gate_msa, scale_mlp, gate_mlp
|
||||||
|
|
||||||
@@ -538,7 +540,7 @@ class RMSNorm(nn.Module):
|
|||||||
hidden_states = hidden_states.to(self.weight.dtype)
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
hidden_states = hidden_states * self.weight
|
hidden_states = hidden_states * self.weight
|
||||||
else:
|
else:
|
||||||
hidden_states = hidden_states.to(input_dtype)
|
hidden_states = hidden_states # .to(input_dtype)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import numbers
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -26,12 +27,50 @@ from ..attention_processor import Attention, MochiAttnProcessor2_0
|
|||||||
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
|
||||||
from ..modeling_outputs import Transformer2DModelOutput
|
from ..modeling_outputs import Transformer2DModelOutput
|
||||||
from ..modeling_utils import ModelMixin
|
from ..modeling_utils import ModelMixin
|
||||||
from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
|
from ..normalization import (
|
||||||
|
AdaLayerNormContinuous,
|
||||||
|
LuminaLayerNormContinuous,
|
||||||
|
MochiRMSNormZero,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
|
class MochiRMSNorm(nn.Module):
|
||||||
|
def __init__(self, dim, eps: float, elementwise_affine: bool = True):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
|
if isinstance(dim, numbers.Integral):
|
||||||
|
dim = (dim,)
|
||||||
|
|
||||||
|
self.dim = torch.Size(dim)
|
||||||
|
|
||||||
|
if elementwise_affine:
|
||||||
|
self.weight = nn.Parameter(torch.ones(dim))
|
||||||
|
else:
|
||||||
|
self.weight = None
|
||||||
|
|
||||||
|
def forward(self, hidden_states, scale=None):
|
||||||
|
input_dtype = hidden_states.dtype
|
||||||
|
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||||
|
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||||
|
if scale is not None:
|
||||||
|
hidden_states = hidden_states * scale
|
||||||
|
|
||||||
|
if self.weight is not None:
|
||||||
|
# convert into half-precision if necessary
|
||||||
|
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||||
|
hidden_states = hidden_states.to(self.weight.dtype)
|
||||||
|
hidden_states = hidden_states * self.weight
|
||||||
|
else:
|
||||||
|
hidden_states = hidden_states.to(input_dtype)
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@maybe_allow_in_graph
|
@maybe_allow_in_graph
|
||||||
class MochiTransformerBlock(nn.Module):
|
class MochiTransformerBlock(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
@@ -103,11 +142,11 @@ class MochiTransformerBlock(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
|
# TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
|
||||||
self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
self.norm2 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||||
self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
self.norm2_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||||
|
|
||||||
self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
self.norm3 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||||
self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
self.norm3_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||||
|
|
||||||
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
|
self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
|
||||||
self.ff_context = None
|
self.ff_context = None
|
||||||
@@ -119,8 +158,8 @@ class MochiTransformerBlock(nn.Module):
|
|||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
|
self.norm4 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||||
self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
self.norm4_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -128,6 +167,7 @@ class MochiTransformerBlock(nn.Module):
|
|||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
temb: torch.Tensor,
|
temb: torch.Tensor,
|
||||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||||
|
joint_attention_mask=None,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
|
||||||
|
|
||||||
@@ -136,28 +176,45 @@ class MochiTransformerBlock(nn.Module):
|
|||||||
encoder_hidden_states, temb
|
encoder_hidden_states, temb
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
|
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb).to(
|
||||||
|
encoder_hidden_states.dtype
|
||||||
|
)
|
||||||
attn_hidden_states, context_attn_hidden_states = self.attn1(
|
attn_hidden_states, context_attn_hidden_states = self.attn1(
|
||||||
hidden_states=norm_hidden_states,
|
hidden_states=norm_hidden_states,
|
||||||
encoder_hidden_states=norm_encoder_hidden_states,
|
encoder_hidden_states=norm_encoder_hidden_states,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
attention_mask=joint_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
|
# hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
|
||||||
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
|
# norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
|
||||||
|
# ff_output = self.ff(norm_hidden_states)
|
||||||
|
# hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
|
||||||
|
|
||||||
|
hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
|
||||||
|
norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float()))
|
||||||
ff_output = self.ff(norm_hidden_states)
|
ff_output = self.ff(norm_hidden_states)
|
||||||
hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
|
hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
|
||||||
|
|
||||||
if not self.context_pre_only:
|
if not self.context_pre_only:
|
||||||
|
# encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||||
|
# context_attn_hidden_states
|
||||||
|
# ) * torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||||
|
# norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
|
||||||
|
# context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||||
|
# encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
|
||||||
|
# enc_gate_mlp
|
||||||
|
# ).unsqueeze(1)
|
||||||
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
encoder_hidden_states = encoder_hidden_states + self.norm2_context(
|
||||||
context_attn_hidden_states
|
context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
|
||||||
) * torch.tanh(enc_gate_msa).unsqueeze(1)
|
)
|
||||||
norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
|
norm_encoder_hidden_states = self.norm3_context(
|
||||||
|
encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float())
|
||||||
|
)
|
||||||
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
context_ff_output = self.ff_context(norm_encoder_hidden_states)
|
||||||
encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
|
encoder_hidden_states = encoder_hidden_states + self.norm4_context(
|
||||||
enc_gate_mlp
|
context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
|
||||||
).unsqueeze(1)
|
)
|
||||||
|
|
||||||
return hidden_states, encoder_hidden_states
|
return hidden_states, encoder_hidden_states
|
||||||
|
|
||||||
@@ -308,7 +365,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.norm_out = AdaLayerNormContinuous(
|
self.norm_out = AdaLayerNormContinuous(
|
||||||
inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
|
inner_dim,
|
||||||
|
inner_dim,
|
||||||
|
elementwise_affine=False,
|
||||||
|
eps=1e-6,
|
||||||
|
norm_type="layer_norm",
|
||||||
)
|
)
|
||||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
|
||||||
|
|
||||||
@@ -324,6 +385,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
encoder_hidden_states: torch.Tensor,
|
encoder_hidden_states: torch.Tensor,
|
||||||
timestep: torch.LongTensor,
|
timestep: torch.LongTensor,
|
||||||
encoder_attention_mask: torch.Tensor,
|
encoder_attention_mask: torch.Tensor,
|
||||||
|
joint_attention_mask=None,
|
||||||
return_dict: bool = True,
|
return_dict: bool = True,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||||
@@ -333,7 +395,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
post_patch_width = width // p
|
post_patch_width = width // p
|
||||||
|
|
||||||
temb, encoder_hidden_states = self.time_embed(
|
temb, encoder_hidden_states = self.time_embed(
|
||||||
timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
|
timestep,
|
||||||
|
encoder_hidden_states,
|
||||||
|
encoder_attention_mask,
|
||||||
|
hidden_dtype=hidden_states.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||||
@@ -373,8 +438,8 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
|
|||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
temb=temb,
|
temb=temb,
|
||||||
image_rotary_emb=image_rotary_emb,
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
joint_attention_mask=joint_attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = self.norm_out(hidden_states, temb)
|
hidden_states = self.norm_out(hidden_states, temb)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
|
|
||||||
|
|||||||
@@ -17,10 +17,11 @@ from typing import Callable, Dict, List, Optional, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from transformers import T5EncoderModel, T5TokenizerFast
|
from transformers import T5EncoderModel, T5TokenizerFast
|
||||||
|
|
||||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||||
from ...models.autoencoders import AutoencoderKL
|
from ...models.autoencoders import AutoencoderKLMochi
|
||||||
from ...models.transformers import MochiTransformer3DModel
|
from ...models.transformers import MochiTransformer3DModel
|
||||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||||
from ...utils import (
|
from ...utils import (
|
||||||
@@ -55,7 +56,7 @@ EXAMPLE_DOC_STRING = """
|
|||||||
>>> pipe.enable_model_cpu_offload()
|
>>> pipe.enable_model_cpu_offload()
|
||||||
>>> pipe.enable_vae_tiling()
|
>>> pipe.enable_vae_tiling()
|
||||||
>>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
>>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
|
||||||
>>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
|
>>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0]
|
||||||
>>> export_to_video(frames, "mochi.mp4")
|
>>> export_to_video(frames, "mochi.mp4")
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@@ -163,8 +164,8 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
Conditional Transformer architecture to denoise the encoded video latents.
|
Conditional Transformer architecture to denoise the encoded video latents.
|
||||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||||
vae ([`AutoencoderKL`]):
|
vae ([`AutoencoderKLMochi`]):
|
||||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||||
text_encoder ([`T5EncoderModel`]):
|
text_encoder ([`T5EncoderModel`]):
|
||||||
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
|
||||||
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
|
||||||
@@ -183,7 +184,7 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||||
vae: AutoencoderKL,
|
vae: AutoencoderKLMochi,
|
||||||
text_encoder: T5EncoderModel,
|
text_encoder: T5EncoderModel,
|
||||||
tokenizer: T5TokenizerFast,
|
tokenizer: T5TokenizerFast,
|
||||||
transformer: MochiTransformer3DModel,
|
transformer: MochiTransformer3DModel,
|
||||||
@@ -197,17 +198,11 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
transformer=transformer,
|
transformer=transformer,
|
||||||
scheduler=scheduler,
|
scheduler=scheduler,
|
||||||
)
|
)
|
||||||
# TODO: determine these scaling factors from model parameters
|
|
||||||
self.vae_spatial_scale_factor = 8
|
|
||||||
self.vae_temporal_scale_factor = 6
|
|
||||||
self.patch_size = 2
|
|
||||||
|
|
||||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
|
self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8
|
||||||
self.tokenizer_max_length = (
|
self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6
|
||||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
|
||||||
)
|
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
|
||||||
self.default_height = 480
|
|
||||||
self.default_width = 848
|
|
||||||
|
|
||||||
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
|
||||||
def _get_t5_prompt_embeds(
|
def _get_t5_prompt_embeds(
|
||||||
@@ -245,7 +240,7 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
f" {max_sequence_length} tokens: {removed_text}"
|
f" {max_sequence_length} tokens: {removed_text}"
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
|
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
|
||||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||||
|
|
||||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||||
@@ -340,7 +335,12 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
return (
|
||||||
|
prompt_embeds,
|
||||||
|
prompt_attention_mask,
|
||||||
|
negative_prompt_embeds,
|
||||||
|
negative_prompt_attention_mask,
|
||||||
|
)
|
||||||
|
|
||||||
def check_inputs(
|
def check_inputs(
|
||||||
self,
|
self,
|
||||||
@@ -424,6 +424,13 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
"""
|
"""
|
||||||
self.vae.disable_tiling()
|
self.vae.disable_tiling()
|
||||||
|
|
||||||
|
def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
|
||||||
|
batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
|
||||||
|
num_latents = latent_frames * latent_height * latent_width
|
||||||
|
num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
|
||||||
|
mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
|
||||||
|
return mask
|
||||||
|
|
||||||
def prepare_latents(
|
def prepare_latents(
|
||||||
self,
|
self,
|
||||||
batch_size,
|
batch_size,
|
||||||
@@ -436,9 +443,9 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
generator,
|
generator,
|
||||||
latents=None,
|
latents=None,
|
||||||
):
|
):
|
||||||
height = height // self.vae_spatial_scale_factor
|
height = height // self.vae_scale_factor_spatial
|
||||||
width = width // self.vae_spatial_scale_factor
|
width = width // self.vae_scale_factor_spatial
|
||||||
num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
|
num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
|
||||||
|
|
||||||
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
shape = (batch_size, num_channels_latents, num_frames, height, width)
|
||||||
|
|
||||||
@@ -478,7 +485,7 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
height: Optional[int] = None,
|
height: Optional[int] = None,
|
||||||
width: Optional[int] = None,
|
width: Optional[int] = None,
|
||||||
num_frames: int = 19,
|
num_frames: int = 19,
|
||||||
num_inference_steps: int = 28,
|
num_inference_steps: int = 50,
|
||||||
timesteps: List[int] = None,
|
timesteps: List[int] = None,
|
||||||
guidance_scale: float = 4.5,
|
guidance_scale: float = 4.5,
|
||||||
num_videos_per_prompt: Optional[int] = 1,
|
num_videos_per_prompt: Optional[int] = 1,
|
||||||
@@ -501,13 +508,13 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
prompt (`str` or `List[str]`, *optional*):
|
prompt (`str` or `List[str]`, *optional*):
|
||||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||||
instead.
|
instead.
|
||||||
height (`int`, *optional*, defaults to `self.default_height`):
|
height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`):
|
||||||
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
The height in pixels of the generated image. This is set to 480 by default for the best results.
|
||||||
width (`int`, *optional*, defaults to `self.default_width`):
|
width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`):
|
||||||
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
The width in pixels of the generated image. This is set to 848 by default for the best results.
|
||||||
num_frames (`int`, defaults to `19`):
|
num_frames (`int`, defaults to `19`):
|
||||||
The number of video frames to generate
|
The number of video frames to generate
|
||||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
num_inference_steps (`int`, *optional*, defaults to `50`):
|
||||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||||
expense of slower inference.
|
expense of slower inference.
|
||||||
timesteps (`List[int]`, *optional*):
|
timesteps (`List[int]`, *optional*):
|
||||||
@@ -567,8 +574,8 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||||
|
|
||||||
height = height or self.default_height
|
height = height or 480 # self.transformer.config.sample_height * self.vae_scaling_factor_spatial
|
||||||
width = width or self.default_width
|
width = width or 848 # self.transformer.config.sample_width * self.vae_scaling_factor_spatial
|
||||||
|
|
||||||
# 1. Check inputs. Raise error if not correct
|
# 1. Check inputs. Raise error if not correct
|
||||||
self.check_inputs(
|
self.check_inputs(
|
||||||
@@ -594,7 +601,6 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
batch_size = prompt_embeds.shape[0]
|
batch_size = prompt_embeds.shape[0]
|
||||||
|
|
||||||
device = self._execution_device
|
device = self._execution_device
|
||||||
|
|
||||||
# 3. Prepare text embeddings
|
# 3. Prepare text embeddings
|
||||||
(
|
(
|
||||||
prompt_embeds,
|
prompt_embeds,
|
||||||
@@ -613,9 +619,9 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
max_sequence_length=max_sequence_length,
|
max_sequence_length=max_sequence_length,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
if self.do_classifier_free_guidance:
|
# if self.do_classifier_free_guidance:
|
||||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
# prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
# prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||||
|
|
||||||
# 4. Prepare latent variables
|
# 4. Prepare latent variables
|
||||||
num_channels_latents = self.transformer.config.in_channels
|
num_channels_latents = self.transformer.config.in_channels
|
||||||
@@ -637,6 +643,9 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
|
sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
|
||||||
sigmas = np.array(sigmas)
|
sigmas = np.array(sigmas)
|
||||||
|
|
||||||
|
joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
|
||||||
|
negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
|
||||||
|
|
||||||
timesteps, num_inference_steps = retrieve_timesteps(
|
timesteps, num_inference_steps = retrieve_timesteps(
|
||||||
self.scheduler,
|
self.scheduler,
|
||||||
num_inference_steps,
|
num_inference_steps,
|
||||||
@@ -653,21 +662,34 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
if self.interrupt:
|
if self.interrupt:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
# latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
|
||||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
# # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||||
|
# timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||||
|
|
||||||
|
latent_model_input = latents
|
||||||
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
|
||||||
|
|
||||||
noise_pred = self.transformer(
|
noise_pred_text = self.transformer(
|
||||||
hidden_states=latent_model_input,
|
hidden_states=latent_model_input,
|
||||||
encoder_hidden_states=prompt_embeds,
|
encoder_hidden_states=prompt_embeds,
|
||||||
timestep=timestep,
|
timestep=timestep,
|
||||||
encoder_attention_mask=prompt_attention_mask,
|
encoder_attention_mask=prompt_attention_mask,
|
||||||
|
joint_attention_mask=joint_attention_mask,
|
||||||
return_dict=False,
|
return_dict=False,
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
if self.do_classifier_free_guidance:
|
if self.do_classifier_free_guidance:
|
||||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
noise_pred_uncond = self.transformer(
|
||||||
|
hidden_states=latent_model_input,
|
||||||
|
encoder_hidden_states=negative_prompt_embeds,
|
||||||
|
timestep=timestep,
|
||||||
|
encoder_attention_mask=negative_prompt_attention_mask,
|
||||||
|
joint_attention_mask=negative_joint_attention_mask,
|
||||||
|
return_dict=False,
|
||||||
|
)[0]
|
||||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||||
|
else:
|
||||||
|
noise_pred = noise_pred_text
|
||||||
|
|
||||||
# compute the previous noisy sample x_t -> x_t-1
|
# compute the previous noisy sample x_t -> x_t-1
|
||||||
latents_dtype = latents.dtype
|
latents_dtype = latents.dtype
|
||||||
@@ -693,7 +715,6 @@ class MochiPipeline(DiffusionPipeline):
|
|||||||
|
|
||||||
if XLA_AVAILABLE:
|
if XLA_AVAILABLE:
|
||||||
xm.mark_step()
|
xm.mark_step()
|
||||||
|
|
||||||
if output_type == "latent":
|
if output_type == "latent":
|
||||||
video = latents
|
video = latents
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user