Compare commits

...

2 Commits

Author        SHA1          Message   Date
Dhruv Nair    30dd9f6845    update    2024-11-18 17:50:51 +01:00
Dhruv Nair    27f81bd54f    update    2024-11-18 17:30:24 +01:00
4 changed files with 30 additions and 8 deletions

View File

@@ -3579,9 +3579,16 @@ class MochiAttnProcessor2_0:
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)
# Zero out tokens based on the attention mask
query = query * attention_mask[:, None, :, None]
key = key * attention_mask[:, None, :, None]
value = value * attention_mask[:, None, :, None]
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# Zero out tokens based on attention mask
hidden_states = hidden_states * attention_mask[:, :, None]
hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
(sequence_length, encoder_sequence_length), dim=1
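
Reviewer note: the added lines zero out padded text tokens on both sides of the attention call, broadcasting the (batch, tokens) mask over the head and channel dimensions before scaled_dot_product_attention and over the channel dimension on the output. A minimal, self-contained sketch of that broadcasting pattern (shapes and values below are illustrative and not taken from the diff):

import torch
import torch.nn.functional as F

# Illustrative sizes: batch=2, heads=4, tokens=6, head_dim=8
query = torch.randn(2, 4, 6, 8)
key = torch.randn(2, 4, 6, 8)
value = torch.randn(2, 4, 6, 8)

# Token-level mask: True = real token, False = padding
attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0],
                               [1, 1, 1, 0, 0, 0]], dtype=torch.bool)

# (batch, tokens) -> (batch, 1, tokens, 1): zero whole token rows across heads and channels
query = query * attention_mask[:, None, :, None]
key = key * attention_mask[:, None, :, None]
value = value * attention_mask[:, None, :, None]

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)  # (batch, tokens, heads * head_dim)

# Padded rows of the output are zeroed as well, broadcasting over channels
hidden_states = hidden_states * attention_mask[:, :, None]

Because the padded value rows are zero they contribute nothing to the attended output, and the final multiplication zeroes the output rows that correspond to the padded tokens themselves.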

View File

@@ -262,7 +262,6 @@ class PatchEmbed(nn.Module):
height, width = latent.shape[-2:]
else:
height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
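
Reviewer note: for context, this path projects the latent with a strided convolution and then flattens the spatial grid into a token sequence. A small standalone sketch of the BCHW -> BNC flattening (patch size and channel counts below are made up for illustration and are not the Mochi configuration):

import torch
import torch.nn as nn

patch_size, in_channels, embed_dim = 2, 12, 64
proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

latent = torch.randn(1, in_channels, 16, 16)  # BCHW
height, width = latent.shape[-2] // patch_size, latent.shape[-1] // patch_size

latent = proj(latent)                       # (1, embed_dim, height, width)
latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC: (1, height * width, embed_dim)
assert latent.shape == (1, height * width, embed_dim)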

View File

@@ -128,6 +128,7 @@ class MochiTransformerBlock(nn.Module):
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: Optional[torch.Tensor] = None,
joint_attention_mask=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
@@ -137,11 +138,11 @@ class MochiTransformerBlock(nn.Module):
)
else:
norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
attn_hidden_states, context_attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=joint_attention_mask,
)
hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
@@ -324,6 +325,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_attention_mask: torch.Tensor,
joint_attention_mask=None,
return_dict: bool = True,
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
@@ -373,6 +375,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
encoder_hidden_states=encoder_hidden_states,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_mask=joint_attention_mask,
)
hidden_states = self.norm_out(hidden_states, temb)
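
Reviewer note: the transformer changes are pure plumbing. A new optional joint_attention_mask keyword is accepted by MochiTransformer3DModel.forward, forwarded to every MochiTransformerBlock, and handed to attn1 as attention_mask. A stripped-down sketch of that threading pattern (the class names and the mask usage inside the block are placeholders, not the diffusers implementation):

from typing import Optional

import torch
import torch.nn as nn

class Block(nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        joint_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if joint_attention_mask is not None:
            # stand-in for passing attention_mask=joint_attention_mask to the attention layer
            hidden_states = hidden_states * joint_attention_mask[:, :, None]
        return hidden_states

class Model(nn.Module):
    def __init__(self, num_blocks: int = 2):
        super().__init__()
        self.blocks = nn.ModuleList(Block() for _ in range(num_blocks))

    def forward(
        self,
        hidden_states: torch.Tensor,
        joint_attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        for block in self.blocks:
            # the new keyword is forwarded unchanged to every block
            hidden_states = block(hidden_states, joint_attention_mask=joint_attention_mask)
        return hidden_states

out = Model()(torch.randn(2, 5, 8), joint_attention_mask=torch.ones(2, 5, dtype=torch.bool))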

View File

@@ -17,6 +17,7 @@ from typing import Callable, Dict, List, Optional, Union
import numpy as np
import torch
import torch.nn.functional as F
from transformers import T5EncoderModel, T5TokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -245,7 +246,7 @@ class MochiPipeline(DiffusionPipeline):
f" {max_sequence_length} tokens: {removed_text}"
)
prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
# duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -258,6 +259,14 @@ class MochiPipeline(DiffusionPipeline):
return prompt_embeds, prompt_attention_mask
def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
num_latents = latent_frames * latent_height * latent_width
num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)
return mask
# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
def encode_prompt(
self,
@@ -613,10 +622,6 @@ class MochiPipeline(DiffusionPipeline):
max_sequence_length=max_sequence_length,
device=device,
)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
@@ -630,6 +635,13 @@ class MochiPipeline(DiffusionPipeline):
generator,
latents,
)
joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0)
# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
@@ -662,6 +674,7 @@ class MochiPipeline(DiffusionPipeline):
encoder_hidden_states=prompt_embeds,
timestep=timestep,
encoder_attention_mask=prompt_attention_mask,
joint_attention_mask=joint_attention_mask,
return_dict=False,
)[0]
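
Reviewer note: in the pipeline, the text encoder is now called with the prompt's attention_mask, and prepare_joint_attention_mask builds the mask the transformer expects by prepending always-attended entries for the visual tokens to the text mask; the classifier-free-guidance concatenation moves after latent preparation so the negative and positive joint masks can be batched in the same order (negative first) as prompt_embeds. A hedged sketch of the mask arithmetic with made-up sizes (the real values come from the latent shape and the transformer's patch_size config):

import torch
import torch.nn.functional as F

patch_size = 2  # illustrative; the pipeline reads self.transformer.config.patch_size

# latents: (batch, channels, latent_frames, latent_height, latent_width), sizes made up
latents = torch.randn(2, 12, 4, 8, 8)
# tokenizer-style text mask: 1 = real token, 0 = padding
prompt_attention_mask = torch.tensor([[1, 1, 1, 0],
                                      [1, 1, 0, 0]])

batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
num_latents = latent_frames * latent_height * latent_width
num_visual_tokens = num_latents // patch_size**2

# Prepend one always-attended entry per visual token; the text mask stays at the end
joint_attention_mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=1).bool()
assert joint_attention_mask.shape == (batch_size, num_visual_tokens + prompt_attention_mask.shape[1])

# For classifier-free guidance the unconditional copy is concatenated first;
# a real negative mask would come from the negative prompt's tokenization
negative_joint_attention_mask = torch.ones_like(joint_attention_mask)
cfg_joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0)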