Compare commits

...

6 Commits

Author      SHA1        Message  Date
Dhruv Nair  fcc59d01a9  update   2024-11-23 17:15:18 +01:00
Dhruv Nair  21b09979dc  update   2024-11-22 13:21:32 +01:00
Dhruv Nair  79380ca719  update   2024-11-20 19:41:08 +01:00
Dhruv Nair  10275feacd  update   2024-11-20 13:57:41 +01:00
Dhruv Nair  30dd9f6845  update   2024-11-18 17:50:51 +01:00
Dhruv Nair  27f81bd54f  update   2024-11-18 17:30:24 +01:00
5 changed files with 170 additions and 63 deletions

View File

@@ -3572,16 +3572,36 @@ class MochiAttnProcessor2_0:
             encoder_value.transpose(1, 2),
         )
 
-        sequence_length = query.size(2)
-        encoder_sequence_length = encoder_query.size(2)
+        batch_size, heads, sequence_length, dim = query.shape
+        encoder_sequence_length = encoder_query.shape[2]
+        total_length = sequence_length + encoder_sequence_length
 
         query = torch.cat([query, encoder_query], dim=2)
         key = torch.cat([key, encoder_key], dim=2)
         value = torch.cat([value, encoder_value], dim=2)
 
+        # Zero out tokens based on the attention mask
+        # query = query * attention_mask[:, None, :, None]
+        # key = key * attention_mask[:, None, :, None]
+        # value = value * attention_mask[:, None, :, None]
+
+        query = query.view(1, query.size(1), -1, query.size(-1))
+        key = key.view(1, key.size(1), -1, key.size(-1))
+        value = value.view(1, value.size(1), -1, key.size(-1))
+
+        select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
+        query = torch.index_select(query, 2, select_index)
+        key = torch.index_select(key, 2, select_index)
+        value = torch.index_select(value, 2, select_index)
+
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
-        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).squeeze(0)
+
+        output = torch.zeros(
+            batch_size * total_length, dim * heads, device=hidden_states.device, dtype=hidden_states.dtype
+        )
+        output.scatter_(0, select_index.unsqueeze(1).expand(-1, dim * heads), hidden_states)
+        hidden_states = output.view(batch_size, total_length, dim * heads)
 
         hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
             (sequence_length, encoder_sequence_length), dim=1
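
Note: the new processor packs the batch into a single sequence, drops every padded text token before calling scaled_dot_product_attention, and scatters the attention outputs back into a zero-filled tensor of the original length. A minimal standalone sketch of that pattern (toy shapes and names, not the library code itself):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 1, 2, 6, 8
query = torch.randn(batch, heads, seq_len, head_dim)
key = torch.randn(batch, heads, seq_len, head_dim)
value = torch.randn(batch, heads, seq_len, head_dim)
# True for real tokens, False for padded prompt tokens
attention_mask = torch.tensor([[True, True, True, True, False, False]])

# Keep only the valid positions along the sequence dimension
select_index = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
q = torch.index_select(query, 2, select_index)
k = torch.index_select(key, 2, select_index)
v = torch.index_select(value, 2, select_index)

out = F.scaled_dot_product_attention(q, k, v)        # (1, heads, n_valid, head_dim)
out = out.transpose(1, 2).flatten(2, 3).squeeze(0)   # (n_valid, heads * head_dim)

# Scatter results back so padded positions stay zero
full = torch.zeros(batch * seq_len, heads * head_dim, dtype=out.dtype)
full.scatter_(0, select_index.unsqueeze(1).expand(-1, heads * head_dim), out)
full = full.view(batch, seq_len, heads * head_dim)   # (batch, seq_len, heads * head_dim)
```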

View File

@@ -262,7 +262,6 @@ class PatchEmbed(nn.Module):
             height, width = latent.shape[-2:]
         else:
             height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
-
         latent = self.proj(latent)
         if self.flatten:
             latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC

View File

@@ -256,7 +256,9 @@ class MochiRMSNormZero(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
-        hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
+        scale_msa = scale_msa.float()
+        _hidden_states = self.norm(hidden_states).float() * (1 + scale_msa[:, None])
+        hidden_states = _hidden_states.to(hidden_states.dtype)
 
         return hidden_states, gate_msa, scale_mlp, gate_mlp
 
@@ -538,7 +540,7 @@ class RMSNorm(nn.Module):
                 hidden_states = hidden_states.to(self.weight.dtype)
             hidden_states = hidden_states * self.weight
         else:
-            hidden_states = hidden_states.to(input_dtype)
+            hidden_states = hidden_states  # .to(input_dtype)
 
         return hidden_states
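
Note: both hunks push the norm-and-modulate math into float32. MochiRMSNormZero upcasts before applying (1 + scale) and casts back once at the end, while RMSNorm (with elementwise_affine=False) no longer downcasts its output, leaving that to the caller. A small illustrative sketch of the upcast-modulate-downcast pattern (stand-in norm layer, not the diffusers class):

```python
import torch
import torch.nn as nn

# Stand-in norm layer; the actual block uses diffusers' RMSNorm.
norm = nn.LayerNorm(8, elementwise_affine=False)

hidden_states = torch.randn(2, 16, 8, dtype=torch.bfloat16)
scale_msa = torch.randn(2, 8, dtype=torch.bfloat16)

# Modulate in float32 so bfloat16 rounding does not creep into (1 + scale) ...
scale_msa = scale_msa.float()
_hidden_states = norm(hidden_states).float() * (1 + scale_msa[:, None])
# ... then cast back to the activation dtype once, at the end.
hidden_states = _hidden_states.to(hidden_states.dtype)
```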

View File

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numbers
 from typing import Any, Dict, Optional, Tuple
 
 import torch
@@ -26,12 +27,50 @@ from ..attention_processor import Attention, MochiAttnProcessor2_0
 from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous, LuminaLayerNormContinuous, MochiRMSNormZero, RMSNorm
+from ..normalization import (
+    AdaLayerNormContinuous,
+    LuminaLayerNormContinuous,
+    MochiRMSNormZero,
+)
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states, scale=None):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if scale is not None:
+            hidden_states = hidden_states * scale
+
+        if self.weight is not None:
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+        else:
+            hidden_states = hidden_states.to(input_dtype)
+        return hidden_states
+
+
 @maybe_allow_in_graph
 class MochiTransformerBlock(nn.Module):
     r"""
@@ -103,11 +142,11 @@ class MochiTransformerBlock(nn.Module):
         )
 
         # TODO(aryan): norm_context layers are not needed when `context_pre_only` is True
-        self.norm2 = RMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm2_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm2 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm2_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
-        self.norm3 = RMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm3_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm3 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm3_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
         self.ff = FeedForward(dim, inner_dim=self.ff_inner_dim, activation_fn=activation_fn, bias=False)
         self.ff_context = None
@@ -119,8 +158,8 @@ class MochiTransformerBlock(nn.Module):
                 bias=False,
             )
 
-        self.norm4 = RMSNorm(dim, eps=eps, elementwise_affine=False)
-        self.norm4_context = RMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
+        self.norm4 = MochiRMSNorm(dim, eps=eps, elementwise_affine=False)
+        self.norm4_context = MochiRMSNorm(pooled_projection_dim, eps=eps, elementwise_affine=False)
 
     def forward(
         self,
@@ -128,6 +167,7 @@ class MochiTransformerBlock(nn.Module):
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        joint_attention_mask=None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
@@ -136,28 +176,45 @@ class MochiTransformerBlock(nn.Module):
                 encoder_hidden_states, temb
             )
         else:
-            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb).to(
+                encoder_hidden_states.dtype
+            )
 
         attn_hidden_states, context_attn_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=joint_attention_mask,
         )
 
-        hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
-        norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
+        # hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
+        # norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
+        # ff_output = self.ff(norm_hidden_states)
+        # hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
+
+        hidden_states = hidden_states + self.norm2(attn_hidden_states, torch.tanh(gate_msa).unsqueeze(1))
+        norm_hidden_states = self.norm3(hidden_states, (1 + scale_mlp.unsqueeze(1).float()))
         ff_output = self.ff(norm_hidden_states)
-        hidden_states = hidden_states + self.norm4(ff_output) * torch.tanh(gate_mlp).unsqueeze(1)
+        hidden_states = hidden_states + self.norm4(ff_output, torch.tanh(gate_mlp).unsqueeze(1))
 
         if not self.context_pre_only:
-            encoder_hidden_states = encoder_hidden_states + self.norm2_context(
-                context_attn_hidden_states
-            ) * torch.tanh(enc_gate_msa).unsqueeze(1)
-            norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
-            context_ff_output = self.ff_context(norm_encoder_hidden_states)
-            encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
-                enc_gate_mlp
-            ).unsqueeze(1)
+            # encoder_hidden_states = encoder_hidden_states + self.norm2_context(
+            #     context_attn_hidden_states
+            # ) * torch.tanh(enc_gate_msa).unsqueeze(1)
+            # norm_encoder_hidden_states = self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
+            # context_ff_output = self.ff_context(norm_encoder_hidden_states)
+            # encoder_hidden_states = encoder_hidden_states + self.norm4_context(context_ff_output) * torch.tanh(
+            #     enc_gate_mlp
+            # ).unsqueeze(1)
+
+            encoder_hidden_states = encoder_hidden_states + self.norm2_context(
+                context_attn_hidden_states, torch.tanh(enc_gate_msa).unsqueeze(1)
+            )
+            norm_encoder_hidden_states = self.norm3_context(
+                encoder_hidden_states, (1 + enc_scale_mlp.unsqueeze(1).float())
+            )
+            context_ff_output = self.ff_context(norm_encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states + self.norm4_context(
+                context_ff_output, torch.tanh(enc_gate_mlp).unsqueeze(1)
+            )
 
         return hidden_states, encoder_hidden_states
@@ -308,7 +365,11 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         )
 
         self.norm_out = AdaLayerNormContinuous(
-            inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
+            inner_dim,
+            inner_dim,
+            elementwise_affine=False,
+            eps=1e-6,
+            norm_type="layer_norm",
         )
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
@@ -324,6 +385,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_attention_mask: torch.Tensor,
+        joint_attention_mask=None,
         return_dict: bool = True,
     ) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
@@ -333,7 +395,10 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
         post_patch_width = width // p
 
         temb, encoder_hidden_states = self.time_embed(
-            timestep, encoder_hidden_states, encoder_attention_mask, hidden_dtype=hidden_states.dtype
+            timestep,
+            encoder_hidden_states,
+            encoder_attention_mask,
+            hidden_dtype=hidden_states.dtype,
         )
 
         hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
@@ -373,8 +438,8 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_mask=joint_attention_mask,
                 )
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

View File

@@ -17,10 +17,11 @@ from typing import Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
+import torch.nn.functional as F
 from transformers import T5EncoderModel, T5TokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...models.autoencoders import AutoencoderKL
+from ...models.autoencoders import AutoencoderKLMochi
 from ...models.transformers import MochiTransformer3DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -55,7 +56,7 @@ EXAMPLE_DOC_STRING = """
         >>> pipe.enable_model_cpu_offload()
         >>> pipe.enable_vae_tiling()
         >>> prompt = "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k."
-        >>> frames = pipe(prompt, num_inference_steps=28, guidance_scale=3.5).frames[0]
+        >>> frames = pipe(prompt, num_inference_steps=50, guidance_scale=3.5).frames[0]
         >>> export_to_video(frames, "mochi.mp4")
         ```
 """
@@ -163,8 +164,8 @@ class MochiPipeline(DiffusionPipeline):
            Conditional Transformer architecture to denoise the encoded video latents.
        scheduler ([`FlowMatchEulerDiscreteScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
-       vae ([`AutoencoderKL`]):
-           Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+       vae ([`AutoencoderKLMochi`]):
+           Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
        text_encoder ([`T5EncoderModel`]):
            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
@@ -183,7 +184,7 @@ class MochiPipeline(DiffusionPipeline):
     def __init__(
         self,
         scheduler: FlowMatchEulerDiscreteScheduler,
-        vae: AutoencoderKL,
+        vae: AutoencoderKLMochi,
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,
@@ -197,17 +198,11 @@ class MochiPipeline(DiffusionPipeline):
             transformer=transformer,
             scheduler=scheduler,
         )
-        # TODO: determine these scaling factors from model parameters
-        self.vae_spatial_scale_factor = 8
-        self.vae_temporal_scale_factor = 6
-        self.patch_size = 2
-        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
-        self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
-        )
-        self.default_height = 480
-        self.default_width = 848
+        self.vae_scale_factor_spatial = vae.spatial_compression_ratio if hasattr(self, "vae") else 8
+        self.vae_scale_factor_temporal = vae.temporal_compression_ratio if hasattr(self, "vae") else 6
+
+        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
 
     # Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -245,7 +240,7 @@ class MochiPipeline(DiffusionPipeline):
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
 
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -340,7 +335,12 @@ class MochiPipeline(DiffusionPipeline):
                 dtype=dtype,
             )
 
-        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
+        return (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        )
 
     def check_inputs(
         self,
@@ -424,6 +424,13 @@ class MochiPipeline(DiffusionPipeline):
         """
         self.vae.disable_tiling()
 
+    def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
+        batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
+        num_latents = latent_frames * latent_height * latent_width
+        num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
+
+        mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
+        return mask
+
     def prepare_latents(
         self,
         batch_size,
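
Note: a hedged sketch of what the prepare_joint_attention_mask helper above produces (toy shapes, made up for illustration). Every latent patch token is always valid, so the prompt mask is prepended with True entries and padded text tokens remain the only masked positions in the joint visual-plus-text sequence:

```python
import torch
import torch.nn.functional as F

prompt_attention_mask = torch.tensor([[True, True, True, False, False]])  # (batch, text_len)
latents = torch.randn(1, 12, 2, 8, 8)                                     # (B, C, T, H, W)
patch_size = 2

num_latents = latents.shape[2] * latents.shape[3] * latents.shape[4]      # 2 * 8 * 8 = 128
num_visual_tokens = num_latents // patch_size**2                          # 128 / 4 = 32

# Visual tokens come first in the joint sequence, so pad on the left with True
joint_mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
assert joint_mask.shape == (1, num_visual_tokens + prompt_attention_mask.shape[1])  # (1, 37)
```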
@@ -436,9 +443,9 @@ class MochiPipeline(DiffusionPipeline):
         generator,
         latents=None,
     ):
-        height = height // self.vae_spatial_scale_factor
-        width = width // self.vae_spatial_scale_factor
-        num_frames = (num_frames - 1) // self.vae_temporal_scale_factor + 1
+        height = height // self.vae_scale_factor_spatial
+        width = width // self.vae_scale_factor_spatial
+        num_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
 
         shape = (batch_size, num_channels_latents, num_frames, height, width)
 
@@ -478,7 +485,7 @@ class MochiPipeline(DiffusionPipeline):
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_frames: int = 19,
-        num_inference_steps: int = 28,
+        num_inference_steps: int = 50,
         timesteps: List[int] = None,
         guidance_scale: float = 4.5,
         num_videos_per_prompt: Optional[int] = 1,
@@ -501,13 +508,13 @@ class MochiPipeline(DiffusionPipeline):
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            height (`int`, *optional*, defaults to `self.default_height`):
+            height (`int`, *optional*, defaults to `self.transformer.config.sample_height * self.vae.spatial_compression_ratio`):
                 The height in pixels of the generated image. This is set to 480 by default for the best results.
-            width (`int`, *optional*, defaults to `self.default_width`):
+            width (`int`, *optional*, defaults to `self.transformer.config.sample_width * self.vae.spatial_compression_ratio`):
                 The width in pixels of the generated image. This is set to 848 by default for the best results.
             num_frames (`int`, defaults to `19`):
                 The number of video frames to generate
-            num_inference_steps (`int`, *optional*, defaults to 50):
+            num_inference_steps (`int`, *optional*, defaults to `50`):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
             timesteps (`List[int]`, *optional*):
@@ -567,8 +574,8 @@ class MochiPipeline(DiffusionPipeline):
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
-        height = height or self.default_height
-        width = width or self.default_width
+        height = height or 480  # self.transformer.config.sample_height * self.vae_scaling_factor_spatial
+        width = width or 848  # self.transformer.config.sample_width * self.vae_scaling_factor_spatial
 
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
@@ -594,7 +601,6 @@ class MochiPipeline(DiffusionPipeline):
         batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-
         # 3. Prepare text embeddings
         (
             prompt_embeds,
@@ -613,9 +619,9 @@ class MochiPipeline(DiffusionPipeline):
             max_sequence_length=max_sequence_length,
             device=device,
         )
-        if self.do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
-            prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
+        # if self.do_classifier_free_guidance:
+        #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+        #     prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
@@ -637,6 +643,9 @@ class MochiPipeline(DiffusionPipeline):
         sigmas = linear_quadratic_schedule(num_inference_steps, threshold_noise)
         sigmas = np.array(sigmas)
 
+        joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
+        negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
+
         timesteps, num_inference_steps = retrieve_timesteps(
             self.scheduler,
             num_inference_steps,
@@ -653,21 +662,34 @@ class MochiPipeline(DiffusionPipeline):
                 if self.interrupt:
                     continue
 
-                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                # latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+                # # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+                # timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+
+                latent_model_input = latents
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
 
-                noise_pred = self.transformer(
+                noise_pred_text = self.transformer(
                     hidden_states=latent_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timestep,
                     encoder_attention_mask=prompt_attention_mask,
+                    joint_attention_mask=joint_attention_mask,
                     return_dict=False,
                 )[0]
 
                 if self.do_classifier_free_guidance:
-                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred_uncond = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=negative_prompt_attention_mask,
+                        joint_attention_mask=negative_joint_attention_mask,
+                        return_dict=False,
+                    )[0]
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+                else:
+                    noise_pred = noise_pred_text
 
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
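
Note: because the conditional and unconditional branches now carry different joint attention masks, the loop above runs two separate transformer calls per step instead of one batched call. The guidance combination itself is unchanged; a generic sketch of that arithmetic (illustrative helper name, not pipeline API):

```python
import torch

def apply_classifier_free_guidance(
    noise_pred_text: torch.Tensor, noise_pred_uncond: torch.Tensor, guidance_scale: float
) -> torch.Tensor:
    # Standard CFG: move the prediction away from the unconditional branch
    # and toward the text-conditioned branch by `guidance_scale`.
    return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```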
@@ -693,7 +715,6 @@ class MochiPipeline(DiffusionPipeline):
                 if XLA_AVAILABLE:
                     xm.mark_step()
 
-
         if output_type == "latent":
             video = latents
         else: