mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-08 09:41:50 +08:00
Compare commits
2 Commits
helios-mod
...
ltx2-3-pip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e90b90a3cc | ||
|
|
6c7e720dd8 |
@@ -178,6 +178,10 @@ class LTX2AudioVideoAttnProcessor:
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
@@ -212,6 +216,112 @@ class LTX2AudioVideoAttnProcessor:
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2PerturbedAttnProcessor:
|
||||
r"""
|
||||
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
raise ValueError(
|
||||
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "LTX2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
if all_perturbed is None:
|
||||
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
|
||||
|
||||
if all_perturbed:
|
||||
# Skip attention, use the value projection value
|
||||
hidden_states = value
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if query_rotary_emb is not None:
|
||||
if attn.rope_type == "interleaved":
|
||||
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_interleaved_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
elif attn.rope_type == "split":
|
||||
query = apply_split_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_split_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if perturbation_mask is not None:
|
||||
value = value.flatten(2, 3)
|
||||
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
@@ -224,7 +334,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor]
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -240,6 +350,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
norm_eps: float = 1e-6,
|
||||
norm_elementwise_affine: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
apply_gated_attention: bool = False,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -266,6 +377,12 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if apply_gated_attention:
|
||||
# Per head gate values
|
||||
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
@@ -321,6 +438,10 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
audio_num_attention_heads: int,
|
||||
audio_attention_head_dim,
|
||||
audio_cross_attention_dim: int,
|
||||
video_gated_attn: bool = False,
|
||||
video_cross_attn_adaln: bool = False,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_adaln: bool = False,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = True,
|
||||
@@ -343,6 +464,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -356,6 +479,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
# 2. Prompt Cross-Attention
|
||||
@@ -370,6 +495,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -383,6 +510,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
@@ -398,6 +527,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
|
||||
@@ -412,6 +543,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
# 4. Feedforward layers
|
||||
@@ -422,14 +555,37 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
|
||||
|
||||
# 5. Per-Layer Modulation Parameters
|
||||
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
|
||||
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
|
||||
# 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q
|
||||
self.video_cross_attn_adaln = video_cross_attn_adaln
|
||||
self.audio_cross_attn_adaln = audio_cross_attn_adaln
|
||||
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
|
||||
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
|
||||
|
||||
# Prompt cross-attn (attn2) additional modulation params
|
||||
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
|
||||
if self.cross_attn_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
|
||||
|
||||
# Per-layer a2v, v2a Cross-Attention mod params
|
||||
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
|
||||
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
|
||||
|
||||
@staticmethod
|
||||
def get_mod_params(
|
||||
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
ada_values = (
|
||||
scale_shift_table[None, None].to(temb.device)
|
||||
+ temb.reshape(batch_size, temb.shape[1], num_ada_params, -1)
|
||||
)
|
||||
ada_params = ada_values.unbind(dim=2)
|
||||
return ada_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -442,6 +598,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
temb_ca_audio_scale_shift: torch.Tensor,
|
||||
temb_ca_gate: torch.Tensor,
|
||||
temb_ca_audio_gate: torch.Tensor,
|
||||
temb_prompt: torch.Tensor | None = None,
|
||||
temb_prompt_audio: torch.Tensor | None = None,
|
||||
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
@@ -454,13 +612,13 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Video and Audio Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
# 1.1. Video Self-Attention
|
||||
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
|
||||
if self.video_cross_attn_adaln:
|
||||
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.size(1), num_ada_params, -1
|
||||
)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
@@ -470,15 +628,15 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
|
||||
|
||||
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
|
||||
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
|
||||
batch_size, temb_audio.size(1), num_audio_ada_params, -1
|
||||
)
|
||||
# 1.2. Audio Self-Attention
|
||||
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
|
||||
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
|
||||
audio_ada_values.unbind(dim=2)
|
||||
audio_ada_params[:6]
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
|
||||
|
||||
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn1(
|
||||
@@ -488,63 +646,74 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
)
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
|
||||
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
|
||||
if self.cross_attn_adaln:
|
||||
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
|
||||
shift_text_kv, scale_text_kv = video_prompt_ada_params
|
||||
|
||||
audio_prompt_ada_params = self.get_mod_params(self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size)
|
||||
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
|
||||
|
||||
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Test)
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
if self.video_cross_attn_adaln:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.video_cross_attn_adaln:
|
||||
attn_hidden_states = attn_hidden_states * gate_text_q
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
# 2.2. Audio-Text Cross-Attention
|
||||
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
|
||||
if self.audio_cross_attn_adaln:
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn2(
|
||||
norm_audio_hidden_states,
|
||||
encoder_hidden_states=audio_encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=audio_encoder_attention_mask,
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
norm_hidden_states = self.audio_to_video_norm(hidden_states)
|
||||
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
|
||||
|
||||
# Combine global and per-layer cross attention modulation parameters
|
||||
# 3.1. Combine global and per-layer cross attention modulation parameters
|
||||
# Video
|
||||
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
video_ca_scale_shift_table = (
|
||||
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
|
||||
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
video_ca_gate = (
|
||||
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
|
||||
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
|
||||
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
|
||||
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
|
||||
a2v_gate = video_ca_gate[0].squeeze(2)
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
|
||||
a2v_gate = video_ca_gate_param[0].squeeze(2)
|
||||
|
||||
# Audio
|
||||
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
audio_ca_scale_shift_table = (
|
||||
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
|
||||
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
audio_ca_gate = (
|
||||
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
|
||||
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
audio_ca_ada_params = self.get_mod_params(audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size)
|
||||
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
|
||||
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
|
||||
v2a_gate = audio_ca_gate[0].squeeze(2)
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
|
||||
v2a_gate = audio_ca_gate_param[0].squeeze(2)
|
||||
|
||||
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
@@ -562,7 +731,7 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
|
||||
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
|
||||
|
||||
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
@@ -1103,6 +1272,7 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_timestep: torch.LongTensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
self_attention_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
@@ -1135,6 +1305,8 @@ class LTX2VideoTransformer3DModel(
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
|
||||
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
|
||||
self_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`.
|
||||
num_frames (`int`, *optional*):
|
||||
The number of latent video frames. Used if calculating the video coordinates for RoPE.
|
||||
height (`int`, *optional*):
|
||||
@@ -1175,6 +1347,18 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
|
||||
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
if self_attention_mask is not None and self_attention_mask.ndim == 3:
|
||||
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
|
||||
# number and positive values are mapped to their logarithm.
|
||||
dtype_finfo = torch.finfo(hidden_states.dtype)
|
||||
additive_self_attn_mask = torch.full_like(self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype)
|
||||
unmasked_entries = self_attention_mask > 0
|
||||
if torch.any(unmasked_entries):
|
||||
additive_self_attn_mask[unmasked_entries] = torch.log(
|
||||
self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
|
||||
).to(hidden_states.dtype)
|
||||
self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Prepare RoPE positional embeddings
|
||||
|
||||
Reference in New Issue
Block a user