mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-09 18:21:48 +08:00
Support prompt timestep embeds and prompt cross attn modulation
This commit is contained in:
@@ -1093,6 +1093,8 @@ class LTX2VideoTransformer3DModel(
|
||||
pos_embed_max_pos: int = 20,
|
||||
base_height: int = 2048,
|
||||
base_width: int = 2048,
|
||||
gated_attn: bool = False,
|
||||
cross_attn_mod: bool = False,
|
||||
audio_in_channels: int = 128, # Audio Arguments
|
||||
audio_out_channels: int | None = 128,
|
||||
audio_patch_size: int = 1,
|
||||
@@ -1104,6 +1106,8 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_pos_embed_max_pos: int = 20,
|
||||
audio_sampling_rate: int = 16000,
|
||||
audio_hop_length: int = 160,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_mod: bool = False,
|
||||
num_layers: int = 48, # Shared arguments
|
||||
activation_fn: str = "gelu-approximate",
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
@@ -1118,6 +1122,7 @@ class LTX2VideoTransformer3DModel(
|
||||
timestep_scale_multiplier: int = 1000,
|
||||
cross_attn_timestep_scale_multiplier: int = 1000,
|
||||
rope_type: str = "interleaved",
|
||||
prompt_modulation: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1170,6 +1175,13 @@ class LTX2VideoTransformer3DModel(
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5)
|
||||
|
||||
# 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3)
|
||||
if prompt_modulation:
|
||||
self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False)
|
||||
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
|
||||
inner_dim, num_mod_params=2, use_additional_conditions=False
|
||||
)
|
||||
|
||||
# 4. Rotary Positional Embeddings (RoPE)
|
||||
# Self-Attention
|
||||
self.rope = LTX2AudioVideoRotaryPosEmbed(
|
||||
@@ -1246,6 +1258,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_attention_heads=audio_num_attention_heads,
|
||||
audio_attention_head_dim=audio_attention_head_dim,
|
||||
audio_cross_attention_dim=audio_cross_attention_dim,
|
||||
video_gated_attn=gated_attn,
|
||||
video_cross_attn_adaln=cross_attn_mod,
|
||||
audio_gated_attn=audio_gated_attn,
|
||||
audio_cross_attn_adaln=audio_cross_attn_mod,
|
||||
qk_norm=qk_norm,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
@@ -1276,6 +1292,8 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
audio_timestep: torch.LongTensor | None = None,
|
||||
sigma: torch.Tensor | None = None,
|
||||
audio_sigma: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
self_attention_mask: torch.Tensor | None = None,
|
||||
@@ -1307,6 +1325,13 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_timestep (`torch.Tensor`, *optional*):
|
||||
Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
|
||||
params. This is only used by certain pipelines such as the I2V pipeline.
|
||||
sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape (batch_size,). Used for video prompt cross attention modulation in
|
||||
models such as LTX-2.3.
|
||||
audio_sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in
|
||||
models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to
|
||||
the provided `sigma` value.
|
||||
encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
|
||||
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
@@ -1343,6 +1368,7 @@ class LTX2VideoTransformer3DModel(
|
||||
"""
|
||||
# Determine timestep for audio.
|
||||
audio_timestep = audio_timestep if audio_timestep is not None else timestep
|
||||
audio_sigma = audio_sigma if audio_sigma is not None else sigma
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
@@ -1413,6 +1439,19 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
|
||||
audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
|
||||
|
||||
if self.config.prompt_modulation:
|
||||
# LTX-2.3
|
||||
temb_prompt, _ = self.prompt_adaln(
|
||||
sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
temb_prompt_audio, _ = self.audio_prompt_adaln(
|
||||
audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype
|
||||
)
|
||||
temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1))
|
||||
temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1))
|
||||
else:
|
||||
temb_prompt = temb_prompt_audio = None
|
||||
|
||||
# 3.2. Prepare global modality cross attention modulation parameters
|
||||
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
|
||||
timestep.flatten(),
|
||||
@@ -1466,8 +1505,8 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_cross_attn_scale_shift,
|
||||
video_cross_attn_a2v_gate,
|
||||
audio_cross_attn_v2a_gate,
|
||||
None, # temb_prompt
|
||||
None, # temb_prompt_audio
|
||||
temb_prompt,
|
||||
temb_prompt_audio,
|
||||
video_rotary_emb,
|
||||
audio_rotary_emb,
|
||||
video_cross_attn_rotary_emb,
|
||||
@@ -1487,8 +1526,8 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
|
||||
temb_ca_gate=video_cross_attn_a2v_gate,
|
||||
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
|
||||
temb_prompt=None,
|
||||
temb_prompt_audio=None,
|
||||
temb_prompt=temb_prompt,
|
||||
temb_prompt_audio=temb_prompt_audio,
|
||||
video_rotary_emb=video_rotary_emb,
|
||||
audio_rotary_emb=audio_rotary_emb,
|
||||
ca_video_rotary_emb=video_cross_attn_rotary_emb,
|
||||
|
||||
Reference in New Issue
Block a user