From 4dfcfeb3aaebcffe23e9f33ffa51602eb16e28c1 Mon Sep 17 00:00:00 2001 From: Daniel Gu Date: Mon, 9 Mar 2026 08:28:24 +0100 Subject: [PATCH] Support prompt timestep embeds and prompt cross attn modulation --- .../models/transformers/transformer_ltx2.py | 47 +++++++++++++++++-- 1 file changed, 43 insertions(+), 4 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_ltx2.py b/src/diffusers/models/transformers/transformer_ltx2.py index 84db3f6eac..c00bcf2489 100644 --- a/src/diffusers/models/transformers/transformer_ltx2.py +++ b/src/diffusers/models/transformers/transformer_ltx2.py @@ -1093,6 +1093,8 @@ class LTX2VideoTransformer3DModel( pos_embed_max_pos: int = 20, base_height: int = 2048, base_width: int = 2048, + gated_attn: bool = False, + cross_attn_mod: bool = False, audio_in_channels: int = 128, # Audio Arguments audio_out_channels: int | None = 128, audio_patch_size: int = 1, @@ -1104,6 +1106,8 @@ class LTX2VideoTransformer3DModel( audio_pos_embed_max_pos: int = 20, audio_sampling_rate: int = 16000, audio_hop_length: int = 160, + audio_gated_attn: bool = False, + audio_cross_attn_mod: bool = False, num_layers: int = 48, # Shared arguments activation_fn: str = "gelu-approximate", qk_norm: str = "rms_norm_across_heads", @@ -1118,6 +1122,7 @@ class LTX2VideoTransformer3DModel( timestep_scale_multiplier: int = 1000, cross_attn_timestep_scale_multiplier: int = 1000, rope_type: str = "interleaved", + prompt_modulation: bool = False, ) -> None: super().__init__() @@ -1170,6 +1175,13 @@ class LTX2VideoTransformer3DModel( self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5) self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5) + # 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3) + if prompt_modulation: + self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False) + self.audio_prompt_adaln = LTX2AdaLayerNormSingle( + inner_dim, num_mod_params=2, use_additional_conditions=False + ) + # 4. Rotary Positional Embeddings (RoPE) # Self-Attention self.rope = LTX2AudioVideoRotaryPosEmbed( @@ -1246,6 +1258,10 @@ class LTX2VideoTransformer3DModel( audio_num_attention_heads=audio_num_attention_heads, audio_attention_head_dim=audio_attention_head_dim, audio_cross_attention_dim=audio_cross_attention_dim, + video_gated_attn=gated_attn, + video_cross_attn_adaln=cross_attn_mod, + audio_gated_attn=audio_gated_attn, + audio_cross_attn_adaln=audio_cross_attn_mod, qk_norm=qk_norm, activation_fn=activation_fn, attention_bias=attention_bias, @@ -1276,6 +1292,8 @@ class LTX2VideoTransformer3DModel( audio_encoder_hidden_states: torch.Tensor, timestep: torch.LongTensor, audio_timestep: torch.LongTensor | None = None, + sigma: torch.Tensor | None = None, + audio_sigma: torch.Tensor | None = None, encoder_attention_mask: torch.Tensor | None = None, audio_encoder_attention_mask: torch.Tensor | None = None, self_attention_mask: torch.Tensor | None = None, @@ -1307,6 +1325,13 @@ class LTX2VideoTransformer3DModel( audio_timestep (`torch.Tensor`, *optional*): Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation params. This is only used by certain pipelines such as the I2V pipeline. + sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). Used for video prompt cross attention modulation in + models such as LTX-2.3. + audio_sigma (`torch.Tensor`, *optional*): + Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in + models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to + the provided `sigma` value. encoder_attention_mask (`torch.Tensor`, *optional*): Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`. audio_encoder_attention_mask (`torch.Tensor`, *optional*): @@ -1343,6 +1368,7 @@ class LTX2VideoTransformer3DModel( """ # Determine timestep for audio. audio_timestep = audio_timestep if audio_timestep is not None else timestep + audio_sigma = audio_sigma if audio_sigma is not None else sigma # convert encoder_attention_mask to a bias the same way we do for attention_mask if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: @@ -1413,6 +1439,19 @@ class LTX2VideoTransformer3DModel( temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1)) audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1)) + if self.config.prompt_modulation: + # LTX-2.3 + temb_prompt, _ = self.prompt_adaln( + sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype + ) + temb_prompt_audio, _ = self.audio_prompt_adaln( + audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype + ) + temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1)) + temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1)) + else: + temb_prompt = temb_prompt_audio = None + # 3.2. Prepare global modality cross attention modulation parameters video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift( timestep.flatten(), @@ -1466,8 +1505,8 @@ class LTX2VideoTransformer3DModel( audio_cross_attn_scale_shift, video_cross_attn_a2v_gate, audio_cross_attn_v2a_gate, - None, # temb_prompt - None, # temb_prompt_audio + temb_prompt, + temb_prompt_audio, video_rotary_emb, audio_rotary_emb, video_cross_attn_rotary_emb, @@ -1487,8 +1526,8 @@ class LTX2VideoTransformer3DModel( temb_ca_audio_scale_shift=audio_cross_attn_scale_shift, temb_ca_gate=video_cross_attn_a2v_gate, temb_ca_audio_gate=audio_cross_attn_v2a_gate, - temb_prompt=None, - temb_prompt_audio=None, + temb_prompt=temb_prompt, + temb_prompt_audio=temb_prompt_audio, video_rotary_emb=video_rotary_emb, audio_rotary_emb=audio_rotary_emb, ca_video_rotary_emb=video_cross_attn_rotary_emb,