Compare commits

...

6 Commits

4 changed files with 701 additions and 66 deletions

View File

@@ -237,7 +237,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
class LTXVideoDownsampler3d(nn.Module):
class LTX2VideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -285,10 +285,11 @@ class LTXVideoDownsampler3d(nn.Module):
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
class LTXVideoUpsampler3d(nn.Module):
class LTX2VideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
stride: int | tuple[int, int, int] = 1,
residual: bool = False,
upscale_factor: int = 1,
@@ -300,7 +301,8 @@ class LTXVideoUpsampler3d(nn.Module):
self.residual = residual
self.upscale_factor = upscale_factor
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
out_channels = out_channels or in_channels
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
self.conv = LTX2VideoCausalConv3d(
in_channels=in_channels,
@@ -408,7 +410,7 @@ class LTX2VideoDownBlock3D(nn.Module):
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
@@ -417,7 +419,7 @@ class LTX2VideoDownBlock3D(nn.Module):
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
@@ -426,7 +428,7 @@ class LTX2VideoDownBlock3D(nn.Module):
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
@@ -580,6 +582,7 @@ class LTX2VideoUpBlock3d(nn.Module):
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
upsample_type: str = "spatiotemporal",
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
@@ -609,17 +612,38 @@ class LTX2VideoUpBlock3d(nn.Module):
self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList(
[
LTXVideoUpsampler3d(
out_channels * upscale_factor,
self.upsamplers = nn.ModuleList()
if upsample_type == "spatial":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(1, 2, 2),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
)
elif upsample_type == "temporal":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(2, 1, 1),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
)
elif upsample_type == "spatiotemporal":
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=(2, 2, 2),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
]
)
)
resnets = []
for _ in range(num_layers):
@@ -862,6 +886,7 @@ class LTX2VideoDecoder3d(nn.Module):
block_out_channels: tuple[int, ...] = (256, 512, 1024),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
@@ -917,6 +942,7 @@ class LTX2VideoDecoder3d(nn.Module):
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
upsample_type=upsample_type[i],
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
@@ -1062,6 +1088,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
upsample_residual: tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[int, ...] = (2, 2, 2),
timestep_conditioning: bool = False,
@@ -1098,6 +1125,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
block_out_channels=decoder_block_out_channels,
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
layers_per_block=decoder_layers_per_block,
upsample_type=upsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,

View File

@@ -178,6 +178,10 @@ class LTX2AudioVideoAttnProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -212,6 +216,112 @@ class LTX2AudioVideoAttnProcessor:
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LTX2PerturbedAttnProcessor:
r"""
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool | None = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
value = attn.to_v(encoder_hidden_states)
if all_perturbed is None:
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
if all_perturbed:
# Skip attention, use the value projection value
hidden_states = value
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
if attn.rope_type == "interleaved":
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
key = apply_interleaved_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
elif attn.rope_type == "split":
query = apply_split_rotary_emb(query, query_rotary_emb)
key = apply_split_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if perturbation_mask is not None:
value = value.flatten(2, 3)
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
@@ -224,7 +334,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
"""
_default_processor_cls = LTX2AudioVideoAttnProcessor
_available_processors = [LTX2AudioVideoAttnProcessor]
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
def __init__(
self,
@@ -240,6 +350,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
rope_type: str = "interleaved",
apply_gated_attention: bool = False,
processor=None,
):
super().__init__()
@@ -266,6 +377,12 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if apply_gated_attention:
# Per head gate values
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
else:
self.to_gate_logits = None
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
@@ -321,6 +438,10 @@ class LTX2VideoTransformerBlock(nn.Module):
audio_num_attention_heads: int,
audio_attention_head_dim,
audio_cross_attention_dim: int,
video_gated_attn: bool = False,
video_cross_attn_adaln: bool = False,
audio_gated_attn: bool = False,
audio_cross_attn_adaln: bool = False,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
@@ -328,9 +449,15 @@ class LTX2VideoTransformerBlock(nn.Module):
eps: float = 1e-6,
elementwise_affine: bool = False,
rope_type: str = "interleaved",
perturbed_attn: bool = False,
):
super().__init__()
if perturbed_attn:
attn_processor_cls = LTX2PerturbedAttnProcessor
else:
attn_processor_cls = LTX2AudioVideoAttnProcessor
# 1. Self-Attention (video and audio)
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn1 = LTX2Attention(
@@ -343,6 +470,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -356,6 +485,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 2. Prompt Cross-Attention
@@ -370,6 +501,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -383,6 +516,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -398,6 +533,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
@@ -412,6 +549,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 4. Feedforward layers
@@ -422,14 +561,37 @@ class LTX2VideoTransformerBlock(nn.Module):
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
# 5. Per-Layer Modulation Parameters
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
# 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q
self.video_cross_attn_adaln = video_cross_attn_adaln
self.audio_cross_attn_adaln = audio_cross_attn_adaln
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
# Prompt cross-attn (attn2) additional modulation params
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
if self.cross_attn_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
# Per-layer a2v, v2a Cross-Attention mod params
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
@staticmethod
def get_mod_params(
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
) -> tuple[torch.Tensor, ...]:
num_ada_params = scale_shift_table.shape[0]
ada_values = (
scale_shift_table[None, None].to(temb.device)
+ temb.reshape(batch_size, temb.shape[1], num_ada_params, -1)
)
ada_params = ada_values.unbind(dim=2)
return ada_params
def forward(
self,
hidden_states: torch.Tensor,
@@ -442,6 +604,8 @@ class LTX2VideoTransformerBlock(nn.Module):
temb_ca_audio_scale_shift: torch.Tensor,
temb_ca_gate: torch.Tensor,
temb_ca_audio_gate: torch.Tensor,
temb_prompt: torch.Tensor | None = None,
temb_prompt_audio: torch.Tensor | None = None,
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
@@ -454,13 +618,13 @@ class LTX2VideoTransformerBlock(nn.Module):
batch_size = hidden_states.size(0)
# 1. Video and Audio Self-Attention
norm_hidden_states = self.norm1(hidden_states)
# 1.1. Video Self-Attention
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
if self.video_cross_attn_adaln:
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
@@ -470,15 +634,15 @@ class LTX2VideoTransformerBlock(nn.Module):
)
hidden_states = hidden_states + attn_hidden_states * gate_msa
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
batch_size, temb_audio.size(1), num_audio_ada_params, -1
)
# 1.2. Audio Self-Attention
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
audio_ada_values.unbind(dim=2)
audio_ada_params[:6]
)
if self.audio_cross_attn_adaln:
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
attn_audio_hidden_states = self.audio_attn1(
@@ -488,63 +652,74 @@ class LTX2VideoTransformerBlock(nn.Module):
)
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
# 2. Video and Audio Cross-Attention with the text embeddings
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
if self.cross_attn_adaln:
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
shift_text_kv, scale_text_kv = video_prompt_ada_params
audio_prompt_ada_params = self.get_mod_params(self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size)
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Test)
norm_hidden_states = self.norm2(hidden_states)
if self.video_cross_attn_adaln:
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
if self.cross_attn_adaln:
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
attn_hidden_states = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
if self.video_cross_attn_adaln:
attn_hidden_states = attn_hidden_states * gate_text_q
hidden_states = hidden_states + attn_hidden_states
# 2.2. Audio-Text Cross-Attention
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
if self.audio_cross_attn_adaln:
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
if self.cross_attn_adaln:
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
attn_audio_hidden_states = self.audio_attn2(
norm_audio_hidden_states,
encoder_hidden_states=audio_encoder_hidden_states,
query_rotary_emb=None,
attention_mask=audio_encoder_attention_mask,
)
if self.audio_cross_attn_adaln:
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
norm_hidden_states = self.audio_to_video_norm(hidden_states)
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
# Combine global and per-layer cross attention modulation parameters
# 3.1. Combine global and per-layer cross attention modulation parameters
# Video
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
video_ca_scale_shift_table = (
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
).unbind(dim=2)
video_ca_gate = (
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
).unbind(dim=2)
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
a2v_gate = video_ca_gate[0].squeeze(2)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
a2v_gate = video_ca_gate_param[0].squeeze(2)
# Audio
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
audio_ca_scale_shift_table = (
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
).unbind(dim=2)
audio_ca_gate = (
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
).unbind(dim=2)
audio_ca_ada_params = self.get_mod_params(audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size)
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
v2a_gate = audio_ca_gate[0].squeeze(2)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
v2a_gate = audio_ca_gate_param[0].squeeze(2)
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
2
)
@@ -562,7 +737,7 @@ class LTX2VideoTransformerBlock(nn.Module):
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
2
)
@@ -1103,6 +1278,7 @@ class LTX2VideoTransformer3DModel(
audio_timestep: torch.LongTensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
self_attention_mask: torch.Tensor | None = None,
num_frames: int | None = None,
height: int | None = None,
width: int | None = None,
@@ -1135,6 +1311,8 @@ class LTX2VideoTransformer3DModel(
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
self_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`.
num_frames (`int`, *optional*):
The number of latent video frames. Used if calculating the video coordinates for RoPE.
height (`int`, *optional*):
@@ -1175,6 +1353,18 @@ class LTX2VideoTransformer3DModel(
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
if self_attention_mask is not None and self_attention_mask.ndim == 3:
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
# number and positive values are mapped to their logarithm.
dtype_finfo = torch.finfo(hidden_states.dtype)
additive_self_attn_mask = torch.full_like(self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype)
unmasked_entries = self_attention_mask > 0
if torch.any(unmasked_entries):
additive_self_attn_mask[unmasked_entries] = torch.log(
self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
).to(hidden_states.dtype)
self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
batch_size = hidden_states.size(0)
# 1. Prepare RoPE positional embeddings

View File

@@ -44,7 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_ltx2_condition import LTX2ConditionPipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
else:
import sys

View File

@@ -8,6 +8,175 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
"""
Creates a Kaiser sinc kernel for low-pass filtering.
Args:
cutoff (`float`):
Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist
frequency).
half_width (`float`):
Used to determine the Kaiser window's beta parameter.
kernel_size:
Size of the Kaiser window (and ultimately the Kaiser sinc kernel).
Returns:
`torch.Tensor` of shape `(kernel_size,)`:
The Kaiser sinc kernel.
"""
delta_f = 4 * half_width
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if amplitude > 50.0:
beta = 0.1102 * (amplitude - 8.7)
elif amplitude >= 21.0:
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
even = kernel_size % 2 == 0
half_size = kernel_size // 2
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
if cutoff == 0.0:
filter = torch.zeros_like(time)
else:
time = 2 * cutoff * time
sinc = torch.where(
time == 0,
torch.ones_like(time),
torch.sin(math.pi * time) / math.pi / time,
)
filter = 2 * cutoff * window * sinc
filter = filter / filter.sum()
return filter
class DownSample1d(nn.Module):
"""1D low-pass filter for antialias downsampling."""
def __init__(
self,
ratio: int = 2,
kernel_size: int | None = None,
use_padding: bool = True,
padding_mode: str = "replicate",
persistent: bool = True,
):
super().__init__()
self.ratio = ratio
self.kernel_size = kernel_size or int(6 * ratio // 2) * 2
self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1
self.pad_right = self.kernel_size // 2
self.use_padding = use_padding
self.padding_mode = padding_mode
cutoff = 0.5 / ratio
half_width = 0.6 / ratio
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
self.register_buffer("filter", low_pass_filter.view(1, 1, self,kernel_size), persistent=persistent)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x expected shape: [batch_size, num_channels, hidden_dim]
num_channels = x.shape[1]
if self.use_padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels)
return x_filtered
class UpSample1d(nn.Module):
def __init__(
self,
ratio: int = 2,
kernel_size: int | None = None,
window_type: str = "kaiser",
padding_mode: str = "replicate",
persistent: bool = True,
):
super().__init__()
self.ratio = ratio
self.padding_mode = padding_mode
if window_type == "hann":
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser sinc filter is BigVGAN default
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
sinc_filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size,
)
self.register_buffer("filter", sinc_filter, persistent=persistent)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x expected shape: [batch_size, num_channels, hidden_dim]
num_channels = x.shape[1]
x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode)
low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1)
x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels)
return x[..., self.pad_left:-self.pad_right]
class SnakeBeta(nn.Module):
"""
Implements the Snake and SnakeBeta activations, which help with learning periodic patterns.
"""
def __init__(
self,
channels: int,
alpha: float = 1.0,
eps: float = 1e-9,
trainable_params: bool = True,
logscale: bool = True,
use_beta: bool = True,
):
self.eps = eps
self.logscale = logscale
self.use_beta = use_beta
self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
self.alpha.requires_grad = trainable_params
if use_beta:
self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
self.beta.requires_grad = trainable_params
def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
broadcast_shape = [1] * hidden_states.ndim
broadcast_shape[channel_dim] = -1
alpha = self.alpha.view(broadcast_shape)
if self.use_beta:
beta = self.beta.view(broadcast_shape)
if self.logscale:
alpha = torch.exp(alpha)
if self.use_beta:
beta = torch.exp(beta)
amplitude = beta if self.use_beta else alpha
hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2)
return hidden_states
class ResBlock(nn.Module):
def __init__(
self,
@@ -15,12 +184,15 @@ class ResBlock(nn.Module):
kernel_size: int = 3,
stride: int = 1,
dilations: tuple[int, ...] = (1, 3, 5),
act_fn: str = "leaky_relu",
leaky_relu_negative_slope: float = 0.1,
antialias: bool = False,
antialias_ratio: int = 2,
antialias_kernel_size: int = 12,
padding_mode: str = "same",
):
super().__init__()
self.dilations = dilations
self.negative_slope = leaky_relu_negative_slope
self.convs1 = nn.ModuleList(
[
@@ -28,6 +200,22 @@ class ResBlock(nn.Module):
for dilation in dilations
]
)
self.acts1 = nn.ModuleList()
for _ in range(len(self.convs1)):
if act_fn == "snakebeta":
act = SnakeBeta(channels, use_beta=True)
elif act_fn == "snake":
act = SnakeBeta(channels, use_beta=False)
else:
act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
if antialias:
act = nn.Sequential(
UpSample1d(ratio=antialias_ratio, kernel_size=antialias_kernel_size),
act,
DownSample1d(ratio=antialias_ratio, kernel_size=antialias_kernel_size),
)
self.acts1.append(act)
self.convs2 = nn.ModuleList(
[
@@ -35,12 +223,28 @@ class ResBlock(nn.Module):
for _ in range(len(dilations))
]
)
self.acts2 = nn.ModuleList()
for _ in range(len(self.convs2)):
if act_fn == "snakebeta":
act = SnakeBeta(channels, use_beta=True)
elif act_fn == "snake":
act = SnakeBeta(channels, use_beta=False)
else:
act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
if antialias:
act = nn.Sequential(
UpSample1d(ratio=antialias_ratio, kernel_size=antialias_kernel_size),
act,
DownSample1d(ratio=antialias_ratio, kernel_size=antialias_kernel_size),
)
self.acts2.append(act)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2):
xt = act1(xt)
xt = conv1(xt)
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
xt = act2(xt)
xt = conv2(xt)
x = x + xt
return x
@@ -61,7 +265,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
upsample_factors: list[int] = [6, 5, 2, 2, 2],
resnet_kernel_sizes: list[int] = [3, 7, 11],
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
act_fn: str = "leaky_relu",
leaky_relu_negative_slope: float = 0.1,
antialias: bool = False,
antialias_ratio: int = 2,
antialias_kernel_size: int = 12,
final_act_fn: str | None = "tanh",
final_bias: bool = True,
output_sampling_rate: int = 24000,
):
super().__init__()
@@ -69,7 +279,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
self.resnets_per_upsample = len(resnet_kernel_sizes)
self.out_channels = out_channels
self.total_upsample_factor = math.prod(upsample_factors)
self.act_fn = act_fn
self.negative_slope = leaky_relu_negative_slope
self.final_act_fn = final_act_fn
if self.num_upsample_layers != len(upsample_factors):
raise ValueError(
@@ -83,6 +295,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
)
supported_act_fns = ["snakebeta", "snake", "leaky_relu"]
if self.act_fn not in supported_act_fns:
raise ValueError(
f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are "
f"{supported_act_fns}."
)
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
self.upsamplers = nn.ModuleList()
@@ -103,15 +322,30 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
self.resnets.append(
ResBlock(
output_channels,
kernel_size,
channels=output_channels,
kernel_size=kernel_size,
dilations=dilations,
act_fn=act_fn,
leaky_relu_negative_slope=leaky_relu_negative_slope,
antialias=antialias,
antialias_ratio=antialias_ratio,
antialias_kernel_size=antialias_kernel_size,
)
)
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
if act_fn == "snakebeta" or act_fn == "snake":
# Always use antialiasing
self.act_out = nn.Sequential(
UpSample1d(ratio=antialias_ratio, kernel_size=antialias_kernel_size),
SnakeBeta(channels=out_channels, use_beta=True),
DownSample1d(ratio=antialias_ratio, kernel_size=antialias_kernel_size),
)
elif act_fn == "leaky_relu":
# NOTE: does NOT use self.negative_slope, following the original code
self.act_out = nn.LeakyReLU()
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
@@ -139,7 +373,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
hidden_states = self.conv_in(hidden_states)
for i in range(self.num_upsample_layers):
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
if self.act_fn == "leaky_relu":
# Other activations are inside each upsampling block
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
hidden_states = self.upsamplers[i](hidden_states)
# Run all resnets in parallel on hidden_states
@@ -149,10 +385,191 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
hidden_states = torch.mean(resnet_outputs, dim=0)
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
hidden_states = self.act_out(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = torch.tanh(hidden_states)
if self.final_act_fn == "tanh":
hidden_states = torch.tanh(hidden_states)
else:
hidden_states = torch.clamp(hidden_states, -1, 1)
return hidden_states
class CausalSTFT(nn.Module):
"""
Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases
multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the
exact buffers should be loaded from the checkpoint in bfloat16.
"""
def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512):
super().__init__()
self.hop_length = hop_length
self.window_length = window_length
n_freqs = filter_length // 2 + 1
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)
def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if waveform.ndim == 2:
waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples]
left_pad = max(0, self.window_length - self.hop_length) # causal: left-only
waveform = F.pad(waveform, (left_pad, 0))
spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real**2 + imag**2)
phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype)
return magnitude, phase
class MelSTFT(nn.Module):
"""
Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be
loaded from the checkpoint in bfloat16.
"""
def __init__(
self,
filter_length: int = 512,
hop_length: int = 80,
window_length: int = 512,
num_mel_channels: int = 64,
):
super().__init__()
self.stft_fn = CausalSTFT(filter_length, hop_length, window_length)
num_freqs = filter_length // 2 + 1
self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True)
def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
magnitude, phase = self.stft_fn(waveform)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy
class LTX2VocoderWithBWE(ModelMixin, ConfigMixin):
"""
LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the
BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same
architecture as the original vocoder.
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
hidden_channels: int = 1536,
out_channels: int = 2,
upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4],
upsample_factors: list[int] = [5, 2, 2, 2, 2, 2],
resnet_kernel_sizes: list[int] = [3, 7, 11],
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
act_fn: str = "snakebeta",
leaky_relu_negative_slope: float = 0.1,
antialias: bool = True,
antialias_ratio: int = 2,
antialias_kernel_size: int = 12,
final_act_fn: str | None = None,
final_bias: bool = False,
bwe_in_channels: int = 128,
bwe_hidden_channels: int = 512,
bwe_out_channels: int = 2,
bwe_upsample_kernel_sizes: list[int] = [12, 11, 8, 4, 4],
bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2],
bwe_resnet_kernel_sizes: list[int] = [3, 7, 11],
bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
bwe_act_fn: str = "snakebeta",
bwe_leaky_relu_negative_slope: float = 0.1,
bwe_antialias: bool = True,
bwe_antialias_ratio: int = 2,
bwe_antialias_kernel_size: int = 12,
bwe_final_act_fn: str | None = None,
bwe_final_bias: bool = False,
filter_length: int = 512,
hop_length: int = 80,
window_length: int = 512,
num_mel_channels: int = 64,
input_sampling_rate: int = 16000,
output_sampling_rate: int = 48000,
):
super().__init__()
self.vocoder = LTX2Vocoder(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
upsample_kernel_sizes=upsample_kernel_sizes,
upsample_factors=upsample_factors,
resnet_kernel_sizes=resnet_kernel_sizes,
resnet_dilations=resnet_dilations,
act_fn=act_fn,
leaky_relu_negative_slope=leaky_relu_negative_slope,
antialias=antialias,
antialias_ratio=antialias_ratio,
antialias_kernel_size=antialias_kernel_size,
final_act_fn=final_act_fn,
final_bias=final_bias,
output_sampling_rate=input_sampling_rate,
)
self.bwe_generator = LTX2Vocoder(
in_channels=bwe_in_channels,
hidden_channels=bwe_hidden_channels,
out_channels=bwe_out_channels,
upsample_kernel_sizes=bwe_upsample_kernel_sizes,
upsample_factors=bwe_upsample_factors,
resnet_kernel_sizes=bwe_resnet_kernel_sizes,
resnet_dilations=bwe_resnet_dilations,
act_fn=bwe_act_fn,
leaky_relu_negative_slope=bwe_leaky_relu_negative_slope,
antialias=bwe_antialias,
antialias_ratio=bwe_antialias_ratio,
antialias_kernel_size=bwe_antialias_kernel_size,
final_act_fn=bwe_final_act_fn,
final_bias=bwe_final_bias,
output_sampling_rate=output_sampling_rate,
)
self.mel_stft = MelSTFT(
filter_length=filter_length,
hop_length=hop_length,
window_length=window_length,
num_mel_channels=num_mel_channels,
)
self.resampler = UpSample1d(
ratio=output_sampling_rate // input_sampling_rate,
window_type="hann",
persistent=False,
)
def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
# 1. Run stage 1 vocoder to get low sampling rate waveform
x = self.vocoder(mel_spec)
batch_size, num_channels, num_samples = x.shape
# Pad to exact multiple of hop_length for exact mel frame count
remainder = num_samples % self.config.hop_length
if remainder != 0:
x = F.pad(x, (0, self.hop_length - remainder))
# 2. Compute mel spectrogram on vocoder output
x = x.flatten(0, 1)
mel, _, _, _ = self.mel_stft(x)
mel = mel.unflatten(0, (-1, num_channels))
# 3. Run bandwidth extender (BWE) on new mel spectrogram
mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins]
residual = self.bwe_generator(mel_for_bwe)
# 4. Residual connection with resampler
skip = self.resampler(x)
waveform = torch.clamp(residual + skip, -1, 1)
output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate
waveform = waveform[..., :output_samples]
return waveform