Compare commits

...

2 Commits

Author SHA1 Message Date
Daniel Gu
e90b90a3cc Update DiT block for LTX 2.3 + add self_attention_mask 2026-03-07 03:32:03 +01:00
Daniel Gu
6c7e720dd8 Initial implementation of perturbed attn processor for LTX 2.3 2026-03-06 05:02:02 +01:00

View File

@@ -178,6 +178,10 @@ class LTX2AudioVideoAttnProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -212,6 +216,112 @@ class LTX2AudioVideoAttnProcessor:
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LTX2PerturbedAttnProcessor:
r"""
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool | None = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
value = attn.to_v(encoder_hidden_states)
if all_perturbed is None:
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
if all_perturbed:
# Skip attention, use the value projection value
hidden_states = value
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
if attn.rope_type == "interleaved":
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
key = apply_interleaved_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
elif attn.rope_type == "split":
query = apply_split_rotary_emb(query, query_rotary_emb)
key = apply_split_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if perturbation_mask is not None:
value = value.flatten(2, 3)
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
@@ -224,7 +334,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
"""
_default_processor_cls = LTX2AudioVideoAttnProcessor
_available_processors = [LTX2AudioVideoAttnProcessor]
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
def __init__(
self,
@@ -240,6 +350,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
rope_type: str = "interleaved",
apply_gated_attention: bool = False,
processor=None,
):
super().__init__()
@@ -266,6 +377,12 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if apply_gated_attention:
# Per head gate values
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
else:
self.to_gate_logits = None
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
@@ -321,6 +438,10 @@ class LTX2VideoTransformerBlock(nn.Module):
audio_num_attention_heads: int,
audio_attention_head_dim,
audio_cross_attention_dim: int,
video_gated_attn: bool = False,
video_cross_attn_adaln: bool = False,
audio_gated_attn: bool = False,
audio_cross_attn_adaln: bool = False,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
@@ -343,6 +464,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=LTX2AudioVideoAttnProcessor(),
)
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -356,6 +479,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=LTX2AudioVideoAttnProcessor(),
)
# 2. Prompt Cross-Attention
@@ -370,6 +495,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=LTX2AudioVideoAttnProcessor(),
)
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -383,6 +510,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=LTX2AudioVideoAttnProcessor(),
)
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -398,6 +527,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=LTX2AudioVideoAttnProcessor(),
)
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
@@ -412,6 +543,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=LTX2AudioVideoAttnProcessor(),
)
# 4. Feedforward layers
@@ -422,14 +555,37 @@ class LTX2VideoTransformerBlock(nn.Module):
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
# 5. Per-Layer Modulation Parameters
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
# 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q
self.video_cross_attn_adaln = video_cross_attn_adaln
self.audio_cross_attn_adaln = audio_cross_attn_adaln
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
# Prompt cross-attn (attn2) additional modulation params
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
if self.cross_attn_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
# Per-layer a2v, v2a Cross-Attention mod params
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
@staticmethod
def get_mod_params(
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
) -> tuple[torch.Tensor, ...]:
num_ada_params = scale_shift_table.shape[0]
ada_values = (
scale_shift_table[None, None].to(temb.device)
+ temb.reshape(batch_size, temb.shape[1], num_ada_params, -1)
)
ada_params = ada_values.unbind(dim=2)
return ada_params
def forward(
self,
hidden_states: torch.Tensor,
@@ -442,6 +598,8 @@ class LTX2VideoTransformerBlock(nn.Module):
temb_ca_audio_scale_shift: torch.Tensor,
temb_ca_gate: torch.Tensor,
temb_ca_audio_gate: torch.Tensor,
temb_prompt: torch.Tensor | None = None,
temb_prompt_audio: torch.Tensor | None = None,
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
@@ -454,13 +612,13 @@ class LTX2VideoTransformerBlock(nn.Module):
batch_size = hidden_states.size(0)
# 1. Video and Audio Self-Attention
norm_hidden_states = self.norm1(hidden_states)
# 1.1. Video Self-Attention
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
if self.video_cross_attn_adaln:
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
@@ -470,15 +628,15 @@ class LTX2VideoTransformerBlock(nn.Module):
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
batch_size, temb_audio.size(1), num_audio_ada_params, -1
)
# 1.2. Audio Self-Attention
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
audio_ada_values.unbind(dim=2)
audio_ada_params[:6]
)
if self.audio_cross_attn_adaln:
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
attn_audio_hidden_states = self.audio_attn1(
@@ -488,63 +646,74 @@ class LTX2VideoTransformerBlock(nn.Module):
)
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
# 2. Video and Audio Cross-Attention with the text embeddings
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
if self.cross_attn_adaln:
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
shift_text_kv, scale_text_kv = video_prompt_ada_params
audio_prompt_ada_params = self.get_mod_params(self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size)
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Test)
norm_hidden_states = self.norm2(hidden_states)
if self.video_cross_attn_adaln:
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
if self.cross_attn_adaln:
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
attn_hidden_states = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
if self.video_cross_attn_adaln:
attn_hidden_states = attn_hidden_states * gate_text_q
hidden_states = hidden_states + attn_hidden_states
# 2.2. Audio-Text Cross-Attention
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
if self.audio_cross_attn_adaln:
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
if self.cross_attn_adaln:
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
attn_audio_hidden_states = self.audio_attn2(
norm_audio_hidden_states,
encoder_hidden_states=audio_encoder_hidden_states,
query_rotary_emb=None,
attention_mask=audio_encoder_attention_mask,
)
if self.audio_cross_attn_adaln:
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
norm_hidden_states = self.audio_to_video_norm(hidden_states)
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
# Combine global and per-layer cross attention modulation parameters
# 3.1. Combine global and per-layer cross attention modulation parameters
# Video
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
video_ca_scale_shift_table = (
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
).unbind(dim=2)
video_ca_gate = (
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
).unbind(dim=2)
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
a2v_gate = video_ca_gate[0].squeeze(2)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
a2v_gate = video_ca_gate_param[0].squeeze(2)
# Audio
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
audio_ca_scale_shift_table = (
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
).unbind(dim=2)
audio_ca_gate = (
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
).unbind(dim=2)
audio_ca_ada_params = self.get_mod_params(audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size)
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
v2a_gate = audio_ca_gate[0].squeeze(2)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
v2a_gate = audio_ca_gate_param[0].squeeze(2)
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
2
)
@@ -562,7 +731,7 @@ class LTX2VideoTransformerBlock(nn.Module):
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
2
)
@@ -1103,6 +1272,7 @@ class LTX2VideoTransformer3DModel(
audio_timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
self_attention_mask: torch.Tensor | None = None,
num_frames: int | None = None,
height: int | None = None,
width: int | None = None,
@@ -1135,6 +1305,8 @@ class LTX2VideoTransformer3DModel(
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
self_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`.
num_frames (`int`, *optional*):
The number of latent video frames. Used if calculating the video coordinates for RoPE.
height (`int`, *optional*):
@@ -1175,6 +1347,18 @@ class LTX2VideoTransformer3DModel(
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
if self_attention_mask is not None and self_attention_mask.ndim == 3:
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
# number and positive values are mapped to their logarithm.
dtype_finfo = torch.finfo(hidden_states.dtype)
additive_self_attn_mask = torch.full_like(self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype)
unmasked_entries = self_attention_mask > 0
if torch.any(unmasked_entries):
additive_self_attn_mask[unmasked_entries] = torch.log(
self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
).to(hidden_states.dtype)
self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
batch_size = hidden_states.size(0)
# 1. Prepare RoPE positional embeddings