mirror of https://github.com/huggingface/diffusers.git
synced 2025-12-06 20:44:33 +08:00
Compare commits
2 Commits
attn-refac ... mochi-attn
| Author | SHA1 | Date |
|---|---|---|
|  | 30dd9f6845 |  |
|  | 27f81bd54f |  |
@@ -3579,9 +3579,16 @@ class MochiAttnProcessor2_0:
key = torch.cat([key, encoder_key], dim=2)
value = torch.cat([value, encoder_value], dim=2)

# Zero out tokens based on the attention mask
query = query * attention_mask[:, None, :, None]
key = key * attention_mask[:, None, :, None]
value = value * attention_mask[:, None, :, None]

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
# Zero out tokens based on attention mask
hidden_states = hidden_states * attention_mask[:, :, None]

hidden_states, encoder_hidden_states = hidden_states.split_with_sizes(
    (sequence_length, encoder_sequence_length), dim=1
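For reference, a minimal, self-contained sketch of the masking pattern in the processor above: a `[batch, seq]` boolean mask is broadcast over `[batch, heads, seq, head_dim]` tensors to zero out padded token rows before `scaled_dot_product_attention`, and the flattened output is re-masked afterwards. The tensor sizes here are toy values, not Mochi's real dimensions.

```python
import torch
import torch.nn.functional as F

# Toy sizes for illustration only (not Mochi's real dimensions).
batch, heads, seq, head_dim = 2, 4, 6, 8

query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# [batch, seq] mask: True for valid tokens, False for padding.
attention_mask = torch.ones(batch, seq, dtype=torch.bool)
attention_mask[:, -2:] = False  # pretend the last two tokens are padding

# Broadcast [batch, seq] -> [batch, 1, seq, 1] and zero out padded rows.
query = query * attention_mask[:, None, :, None]
key = key * attention_mask[:, None, :, None]
value = value * attention_mask[:, None, :, None]

hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

# [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim], then re-mask the output.
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states * attention_mask[:, :, None]
```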
@@ -262,7 +262,6 @@ class PatchEmbed(nn.Module):
    height, width = latent.shape[-2:]
else:
    height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

latent = self.proj(latent)
if self.flatten:
    latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
@@ -128,6 +128,7 @@ class MochiTransformerBlock(nn.Module):
    encoder_hidden_states: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[torch.Tensor] = None,
    joint_attention_mask=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
@@ -137,11 +138,11 @@ class MochiTransformerBlock(nn.Module):
    )
else:
    norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)

attn_hidden_states, context_attn_hidden_states = self.attn1(
    hidden_states=norm_hidden_states,
    encoder_hidden_states=norm_encoder_hidden_states,
    image_rotary_emb=image_rotary_emb,
    attention_mask=joint_attention_mask,
)

hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
@@ -324,6 +325,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
    encoder_hidden_states: torch.Tensor,
    timestep: torch.LongTensor,
    encoder_attention_mask: torch.Tensor,
    joint_attention_mask=None,
    return_dict: bool = True,
) -> torch.Tensor:
    batch_size, num_channels, num_frames, height, width = hidden_states.shape
@@ -373,6 +375,7 @@ class MochiTransformer3DModel(ModelMixin, ConfigMixin):
    encoder_hidden_states=encoder_hidden_states,
    temb=temb,
    image_rotary_emb=image_rotary_emb,
    joint_attention_mask=joint_attention_mask,
)

hidden_states = self.norm_out(hidden_states, temb)
@@ -17,6 +17,7 @@ from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers import T5EncoderModel, T5TokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -245,7 +246,7 @@ class MochiPipeline(DiffusionPipeline):
    f" {max_sequence_length} tokens: {removed_text}"
)

prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)[0]
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

# duplicate text embeddings for each generation per prompt, using mps friendly method
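As context for the change above (passing the prompt attention mask into the T5 encoder), here is a rough standalone sketch of encoding a prompt with an attention mask. The checkpoint name is just an illustrative stand-in, not the pipeline's actual text encoder.

```python
import torch
from transformers import T5EncoderModel, T5TokenizerFast

# Illustrative small checkpoint; the Mochi pipeline uses a larger T5 variant.
model_id = "google/t5-v1_1-small"
tokenizer = T5TokenizerFast.from_pretrained(model_id)
text_encoder = T5EncoderModel.from_pretrained(model_id)

inputs = tokenizer(
    ["a cat playing piano"],
    padding="max_length",
    max_length=32,
    truncation=True,
    return_tensors="pt",
)

with torch.no_grad():
    # Passing attention_mask lets the encoder ignore padded positions
    # instead of attending to them as if they were real tokens.
    prompt_embeds = text_encoder(
        inputs.input_ids, attention_mask=inputs.attention_mask
    )[0]

print(prompt_embeds.shape)  # [1, 32, hidden_size]
```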
@@ -258,6 +259,14 @@ class MochiPipeline(DiffusionPipeline):
    return prompt_embeds, prompt_attention_mask

def prepare_joint_attention_mask(self, prompt_attention_mask, latents):
    batch_size, channels, latent_frames, latent_height, latent_width = latents.shape
    num_latents = latent_frames * latent_height * latent_width
    num_visual_tokens = num_latents // (self.transformer.config.patch_size**2)
    mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)

    return mask

# Adapted from diffusers.pipelines.cogvideo.pipeline_cogvideox.CogVideoXPipeline.encode_prompt
def encode_prompt(
    self,
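A small sketch of what the new `prepare_joint_attention_mask` helper computes, using made-up latent dimensions: the text mask is left-padded with True so the visual tokens (which are never padding, and which come first in the joint sequence) are always attended to, yielding one mask over the joint visual+text sequence.

```python
import torch
import torch.nn.functional as F

# Made-up shapes for illustration only.
latent_frames, latent_height, latent_width = 3, 8, 8
patch_size = 2

# [batch, prompt_len] text mask with some padding on the second prompt.
prompt_attention_mask = torch.tensor(
    [[1, 1, 1, 1, 1],
     [1, 1, 1, 0, 0]], dtype=torch.bool
)

num_latents = latent_frames * latent_height * latent_width
num_visual_tokens = num_latents // (patch_size**2)  # 3 * 8 * 8 // 4 = 48

# Left-pad with True: visual tokens precede text tokens in the joint sequence and are always valid.
joint_attention_mask = F.pad(prompt_attention_mask, (num_visual_tokens, 0), value=True)

print(joint_attention_mask.shape)  # torch.Size([2, 53]) = 48 visual + 5 text tokens
```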
@@ -613,10 +622,6 @@ class MochiPipeline(DiffusionPipeline):
    max_sequence_length=max_sequence_length,
    device=device,
)
if self.do_classifier_free_guidance:
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)

# 4. Prepare latent variables
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
@@ -630,6 +635,13 @@ class MochiPipeline(DiffusionPipeline):
    generator,
    latents,
)
joint_attention_mask = self.prepare_joint_attention_mask(prompt_attention_mask, latents)
negative_joint_attention_mask = self.prepare_joint_attention_mask(negative_prompt_attention_mask, latents)

if self.do_classifier_free_guidance:
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
    prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
    joint_attention_mask = torch.cat([negative_joint_attention_mask, joint_attention_mask], dim=0)

# 5. Prepare timestep
# from https://github.com/genmoai/models/blob/075b6e36db58f1242921deff83a1066887b9c9e1/src/mochi_preview/infer.py#L77
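And a tiny sketch of the classifier-free-guidance batching shown above, with placeholder masks: the negative (unconditional) and positive joint masks are stacked along the batch dimension, in the same order as the prompt embeddings, so one forward pass covers both branches.

```python
import torch

# Placeholder joint masks: batch of 1 prompt, 10 joint (visual + text) tokens each.
joint_attention_mask = torch.ones(1, 10, dtype=torch.bool)
negative_joint_attention_mask = torch.ones(1, 10, dtype=torch.bool)
negative_joint_attention_mask[:, -3:] = False  # shorter negative prompt -> padded text tail

do_classifier_free_guidance = True
if do_classifier_free_guidance:
    # Unconditional entries come first, matching the prompt_embeds concat order.
    joint_attention_mask = torch.cat(
        [negative_joint_attention_mask, joint_attention_mask], dim=0
    )

print(joint_attention_mask.shape)  # torch.Size([2, 10])
```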
@@ -662,6 +674,7 @@ class MochiPipeline(DiffusionPipeline):
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
    encoder_attention_mask=prompt_attention_mask,
    joint_attention_mask=joint_attention_mask,
    return_dict=False,
)[0]