mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-18 22:48:08 +08:00
Compare commits
10 Commits
ltx2-3-pip
...
fa4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ae76da7cdb | ||
|
|
9cffd7a6d2 | ||
|
|
8e4b5607ed | ||
|
|
c6f72ad2f6 | ||
|
|
9a28c2f020 | ||
|
|
11a3284cee | ||
|
|
16e7067647 | ||
|
|
d1b3555c29 | ||
|
|
9677859ebf | ||
|
|
ed31974c3e |
@@ -22,6 +22,8 @@
|
||||
title: Reproducibility
|
||||
- local: using-diffusers/schedulers
|
||||
title: Schedulers
|
||||
- local: using-diffusers/guiders
|
||||
title: Guiders
|
||||
- local: using-diffusers/automodel
|
||||
title: AutoModel
|
||||
- local: using-diffusers/other-formats
|
||||
@@ -110,8 +112,6 @@
|
||||
title: ModularPipeline
|
||||
- local: modular_diffusers/components_manager
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/guiders
|
||||
title: Guiders
|
||||
- local: modular_diffusers/custom_blocks
|
||||
title: Building Custom Blocks
|
||||
- local: modular_diffusers/mellon
|
||||
|
||||
@@ -99,7 +99,7 @@ To update guider configuration, you can run `pipe.guider = pipe.guider.new(...)`
|
||||
pipe.guider = pipe.guider.new(guidance_scale=5.0)
|
||||
```
|
||||
|
||||
Read more on Guider [here](../../modular_diffusers/guiders).
|
||||
Read more on Guider [here](../../using-diffusers/guiders).
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ HunyuanImage-2.1 comes in the following variants:
|
||||
|
||||
## HunyuanImage-2.1
|
||||
|
||||
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../modular_diffusers/guiders.md)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
|
||||
HunyuanImage-2.1 applies [Adaptive Projected Guidance (APG)](https://huggingface.co/papers/2410.02416) combined with Classifier-Free Guidance (CFG) in the denoising loop. `HunyuanImagePipeline` has a `guider` component (read more about [Guider](../../using-diffusers/guiders)) and does not take a `guidance_scale` parameter at runtime. To change guider-related parameters, e.g., `guidance_scale`, you can update the `guider` configuration instead.
|
||||
|
||||
```python
|
||||
import torch
|
||||
|
||||
@@ -338,7 +338,7 @@ guider = ClassifierFreeGuidance(guidance_scale=5.0)
|
||||
pipeline.update_components(guider=guider)
|
||||
```
|
||||
|
||||
See the [Guiders](./guiders) guide for more details on available guiders and how to configure them.
|
||||
See the [Guiders](../using-diffusers/guiders) guide for more details on available guiders and how to configure them.
|
||||
|
||||
## Splitting a pipeline into stages
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ The Modular Diffusers docs are organized as shown below.
|
||||
|
||||
- [ModularPipeline](./modular_pipeline) shows you how to create and convert pipeline blocks into an executable [`ModularPipeline`].
|
||||
- [ComponentsManager](./components_manager) shows you how to manage and reuse components across multiple pipelines.
|
||||
- [Guiders](./guiders) shows you how to use different guidance methods in the pipeline.
|
||||
- [Guiders](../using-diffusers/guiders) shows you how to use different guidance methods in the pipeline.
|
||||
|
||||
## Mellon Integration
|
||||
|
||||
|
||||
@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
@@ -482,144 +482,6 @@ print(
|
||||
) # (2880, 1, 960, 320) having a stride of 1 for the 2nd dimension proves that it works
|
||||
```
|
||||
|
||||
## torch.jit.trace
|
||||
|
||||
[torch.jit.trace](https://pytorch.org/docs/stable/generated/torch.jit.trace.html) records the operations a model performs on a sample input and creates a new, optimized representation of the model based on the recorded execution path. During tracing, the model is optimized to reduce overhead from Python and dynamic control flows and operations are fused together for more efficiency. The returned executable or [ScriptFunction](https://pytorch.org/docs/stable/generated/torch.jit.ScriptFunction.html) can be compiled.
|
||||
|
||||
```py
|
||||
import time
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import functools
|
||||
|
||||
# torch disable grad
|
||||
torch.set_grad_enabled(False)
|
||||
|
||||
# set variables
|
||||
n_experiments = 2
|
||||
unet_runs_per_experiment = 50
|
||||
|
||||
# load sample inputs
|
||||
def generate_inputs():
|
||||
sample = torch.randn((2, 4, 64, 64), device="cuda", dtype=torch.float16)
|
||||
timestep = torch.rand(1, device="cuda", dtype=torch.float16) * 999
|
||||
encoder_hidden_states = torch.randn((2, 77, 768), device="cuda", dtype=torch.float16)
|
||||
return sample, timestep, encoder_hidden_states
|
||||
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
unet = pipeline.unet
|
||||
unet.eval()
|
||||
unet.to(memory_format=torch.channels_last) # use channels_last memory format
|
||||
unet.forward = functools.partial(unet.forward, return_dict=False) # set return_dict=False as default
|
||||
|
||||
# warmup
|
||||
for _ in range(3):
|
||||
with torch.inference_mode():
|
||||
inputs = generate_inputs()
|
||||
orig_output = unet(*inputs)
|
||||
|
||||
# trace
|
||||
print("tracing..")
|
||||
unet_traced = torch.jit.trace(unet, inputs)
|
||||
unet_traced.eval()
|
||||
print("done tracing")
|
||||
|
||||
# warmup and optimize graph
|
||||
for _ in range(5):
|
||||
with torch.inference_mode():
|
||||
inputs = generate_inputs()
|
||||
orig_output = unet_traced(*inputs)
|
||||
|
||||
# benchmarking
|
||||
with torch.inference_mode():
|
||||
for _ in range(n_experiments):
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
for _ in range(unet_runs_per_experiment):
|
||||
orig_output = unet_traced(*inputs)
|
||||
torch.cuda.synchronize()
|
||||
print(f"unet traced inference took {time.time() - start_time:.2f} seconds")
|
||||
for _ in range(n_experiments):
|
||||
torch.cuda.synchronize()
|
||||
start_time = time.time()
|
||||
for _ in range(unet_runs_per_experiment):
|
||||
orig_output = unet(*inputs)
|
||||
torch.cuda.synchronize()
|
||||
print(f"unet inference took {time.time() - start_time:.2f} seconds")
|
||||
|
||||
# save the model
|
||||
unet_traced.save("unet_traced.pt")
|
||||
```
|
||||
|
||||
Replace the pipeline's UNet with the traced version.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import StableDiffusionPipeline
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class UNet2DConditionOutput:
|
||||
sample: torch.Tensor
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
use_safetensors=True,
|
||||
).to("cuda")
|
||||
|
||||
# use jitted unet
|
||||
unet_traced = torch.jit.load("unet_traced.pt")
|
||||
|
||||
# del pipeline.unet
|
||||
class TracedUNet(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.in_channels = pipe.unet.config.in_channels
|
||||
self.device = pipe.unet.device
|
||||
|
||||
def forward(self, latent_model_input, t, encoder_hidden_states):
|
||||
sample = unet_traced(latent_model_input, t, encoder_hidden_states)[0]
|
||||
return UNet2DConditionOutput(sample=sample)
|
||||
|
||||
pipeline.unet = TracedUNet()
|
||||
|
||||
with torch.inference_mode():
|
||||
image = pipe([prompt] * 1, num_inference_steps=50).images[0]
|
||||
```
|
||||
|
||||
## Memory-efficient attention
|
||||
|
||||
> [!TIP]
|
||||
> Memory-efficient attention optimizes for memory usage *and* [inference speed](./fp16#scaled-dot-product-attention)!
|
||||
|
||||
The Transformers attention mechanism is memory-intensive, especially for long sequences, so you can try using different and more memory-efficient attention types.
|
||||
|
||||
By default, if PyTorch >= 2.0 is installed, [scaled dot-product attention (SDPA)](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) is used. You don't need to make any additional changes to your code.
|
||||
|
||||
SDPA supports [FlashAttention](https://github.com/Dao-AILab/flash-attention) and [xFormers](https://github.com/facebookresearch/xformers) as well as a native C++ PyTorch implementation. It automatically selects the most optimal implementation based on your input.
|
||||
|
||||
You can explicitly use xFormers with the [`~ModelMixin.enable_xformers_memory_efficient_attention`] method.
|
||||
|
||||
```py
|
||||
# pip install xformers
|
||||
import torch
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
).to("cuda")
|
||||
pipeline.enable_xformers_memory_efficient_attention()
|
||||
```
|
||||
|
||||
Call [`~ModelMixin.disable_xformers_memory_efficient_attention`] to disable it.
|
||||
|
||||
```py
|
||||
pipeline.disable_xformers_memory_efficient_attention()
|
||||
```
|
||||
Diffusers supports multiple memory-efficient attention backends (FlashAttention, xFormers, SageAttention, and more) through [`~ModelMixin.set_attention_backend`]. Refer to the [Attention backends](./attention_backends) guide to learn how to switch between them.
|
||||
|
||||
@@ -23,7 +23,7 @@ pip install xformers
|
||||
> [!TIP]
|
||||
> The xFormers `pip` package requires the latest version of PyTorch. If you need to use a previous version of PyTorch, then we recommend [installing xFormers from the source](https://github.com/facebookresearch/xformers#installing-xformers).
|
||||
|
||||
After xFormers is installed, you can use `enable_xformers_memory_efficient_attention()` for faster inference and reduced memory consumption as shown in this [section](memory#memory-efficient-attention).
|
||||
After xFormers is installed, you can use it with [`~ModelMixin.set_attention_backend`] as shown in the [Attention backends](./attention_backends) guide.
|
||||
|
||||
> [!WARNING]
|
||||
> According to this [issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training (fine-tune or DreamBooth) in some GPUs. If you observe this problem, please install a development version as indicated in the issue comments.
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
sections:
|
||||
- local: using-diffusers/schedulers
|
||||
title: Load schedulers and models
|
||||
- local: using-diffusers/guiders
|
||||
title: Guiders
|
||||
|
||||
- title: Inference
|
||||
isExpanded: false
|
||||
@@ -80,8 +82,6 @@
|
||||
title: ModularPipeline
|
||||
- local: modular_diffusers/components_manager
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/guiders
|
||||
title: Guiders
|
||||
|
||||
- title: Training
|
||||
isExpanded: false
|
||||
|
||||
@@ -7,7 +7,7 @@ import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Processor
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
@@ -17,7 +17,7 @@ from diffusers import (
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder, LTX2VocoderWithBWE
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -44,12 +44,6 @@ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
**LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT,
|
||||
"audio_prompt_adaln_single": "audio_prompt_adaln",
|
||||
"prompt_adaln_single": "prompt_adaln",
|
||||
}
|
||||
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
@@ -78,13 +72,6 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_3_VIDEO_VAE_RENAME_DICT = {
|
||||
**LTX_2_0_VIDEO_VAE_RENAME_DICT,
|
||||
# Decoder extra blocks
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
@@ -97,34 +84,10 @@ LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_3_VOCODER_RENAME_DICT = {
|
||||
# Handle upsamplers ("ups" --> "upsamplers") due to name clash
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
"act_post": "act_out",
|
||||
"downsample.lowpass": "downsample",
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# LTX-2.3 uses per-modality embedding projections
|
||||
"text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in",
|
||||
"text_embedding_projection.video_aggregate_embed": "video_text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
@@ -166,24 +129,23 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_3_vocoder_upsamplers(key: str, state_dict: dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if ".ups." in key:
|
||||
new_key = key.replace(".ups.", ".upsamplers.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
@@ -193,19 +155,13 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP = {
|
||||
".ups.": convert_ltx2_3_vocoder_upsamplers,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
@@ -269,7 +225,7 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
@@ -282,8 +238,6 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"gated_attn": False,
|
||||
"cross_attn_mod": False,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
@@ -295,8 +249,6 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"audio_gated_attn": False,
|
||||
"audio_cross_attn_mod": False,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
@@ -311,62 +263,10 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
"use_prompt_embeddings": True,
|
||||
"perturbed_attn": False,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"gated_attn": True,
|
||||
"cross_attn_mod": True,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"audio_gated_attn": True,
|
||||
"audio_cross_attn_mod": True,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
"use_prompt_embeddings": False,
|
||||
"perturbed_attn": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -393,7 +293,7 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
@@ -401,52 +301,20 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"video_gated_attn": False,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"audio_gated_attn": False,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
"per_modality_projections": False,
|
||||
"proj_bias": False,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 32,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 8,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"video_gated_attn": True,
|
||||
"audio_connector_num_attention_heads": 32,
|
||||
"audio_connector_attention_head_dim": 64,
|
||||
"audio_connector_num_layers": 8,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"audio_gated_attn": True,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
"per_modality_projections": True,
|
||||
"video_hidden_dim": 4096,
|
||||
"audio_hidden_dim": 2048,
|
||||
"proj_bias": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
@@ -548,7 +416,7 @@ def get_ltx2_video_vae_config(
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -567,7 +435,6 @@ def get_ltx2_video_vae_config(
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_type": ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
@@ -584,44 +451,6 @@ def get_ltx2_video_vae_config(
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 1024),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 512, 1024),
|
||||
"layers_per_block": (4, 6, 4, 2, 2),
|
||||
"decoder_layers_per_block": (4, 6, 4, 2, 2),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (2, 2, 1, 2),
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "zeros",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -656,7 +485,7 @@ def convert_ltx2_video_vae(
|
||||
def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
@@ -679,31 +508,6 @@ def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, A
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
}, # Same config as LTX-2.0
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -736,7 +540,7 @@ def convert_ltx2_audio_vae(original_state_dict: dict[str, Any], version: str) ->
|
||||
def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
@@ -745,71 +549,21 @@ def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"act_fn": "leaky_relu",
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"antialias": False,
|
||||
"final_act_fn": "tanh",
|
||||
"final_bias": True,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1536,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [11, 4, 4, 4, 4, 4],
|
||||
"upsample_factors": [5, 2, 2, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"act_fn": "snakebeta",
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"antialias": True,
|
||||
"antialias_ratio": 2,
|
||||
"antialias_kernel_size": 12,
|
||||
"final_act_fn": None,
|
||||
"final_bias": False,
|
||||
"bwe_in_channels": 128,
|
||||
"bwe_hidden_channels": 512,
|
||||
"bwe_out_channels": 2,
|
||||
"bwe_upsample_kernel_sizes": [12, 11, 4, 4, 4],
|
||||
"bwe_upsample_factors": [6, 5, 2, 2, 2],
|
||||
"bwe_resnet_kernel_sizes": [3, 7, 11],
|
||||
"bwe_resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"bwe_act_fn": "snakebeta",
|
||||
"bwe_leaky_relu_negative_slope": 0.1,
|
||||
"bwe_antialias": True,
|
||||
"bwe_antialias_ratio": 2,
|
||||
"bwe_antialias_kernel_size": 12,
|
||||
"bwe_final_act_fn": None,
|
||||
"bwe_final_bias": False,
|
||||
"filter_length": 512,
|
||||
"hop_length": 80,
|
||||
"window_length": 512,
|
||||
"num_mel_channels": 64,
|
||||
"input_sampling_rate": 16000,
|
||||
"output_sampling_rate": 48000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: dict[str, Any], version: str) -> dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
if version == "2.3":
|
||||
vocoder_cls = LTX2VocoderWithBWE
|
||||
else:
|
||||
vocoder_cls = LTX2Vocoder
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = vocoder_cls.from_config(diffusers_config)
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
@@ -840,18 +594,6 @@ def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
"use_rational_resampler": True,
|
||||
}
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
"use_rational_resampler": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
@@ -909,17 +651,13 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: dict[str, Any], prefi
|
||||
model_state_dict = {}
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.removeprefix(prefix)] = param
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_prefixes = ["text_embedding_projection"]
|
||||
for param_name, param in combined_ckpt.items():
|
||||
for prefix in connector_prefixes:
|
||||
if param_name.startswith(prefix):
|
||||
# Check to make sure we're not overwriting an existing key
|
||||
if param_name not in model_state_dict:
|
||||
model_state_dict[param_name] = combined_ckpt[param_name]
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
|
||||
return model_state_dict
|
||||
|
||||
@@ -948,7 +686,7 @@ def get_args():
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0", "2.3"],
|
||||
choices=["test", "2.0"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
@@ -1010,11 +748,6 @@ def get_args():
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_processor",
|
||||
action="store_true",
|
||||
help="Whether to add a Gemma3Processor to the pipeline for prompt enhancement.",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
@@ -1023,12 +756,6 @@ def get_args():
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument(
|
||||
"--upsample_output_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path where converted upsampling pipeline should be saved",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -1060,7 +787,7 @@ def main(args):
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.connectors,
|
||||
args.text_encoder,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
@@ -1125,12 +852,7 @@ def main(args):
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.add_processor:
|
||||
processor = Gemma3Processor.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
processor.save_pretrained(os.path.join(args.output_path, "processor"))
|
||||
|
||||
if args.latent_upsampler or args.upsample_pipeline:
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
@@ -1144,26 +866,14 @@ def main(args):
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
is_distilled_ckpt = "distilled" in args.combined_filename
|
||||
if is_distilled_ckpt:
|
||||
# Disable dynamic shifting and terminal shift so that distilled sigmas are used as-is
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=False,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=None,
|
||||
)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
@@ -1181,12 +891,10 @@ def main(args):
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# As two diffusers pipelines cannot be in the same directory, save the upsampling pipeline to its own directory
|
||||
if args.upsample_output_path:
|
||||
upsample_output_path = args.upsample_output_path
|
||||
else:
|
||||
upsample_output_path = args.output_path
|
||||
pipe.save_pretrained(upsample_output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -12,6 +12,7 @@ from termcolor import colored
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLWan,
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
@@ -24,7 +25,10 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
|
||||
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
|
||||
]
|
||||
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
|
||||
|
||||
|
||||
@@ -92,12 +96,22 @@ def main(args):
|
||||
if args.video_size == 480:
|
||||
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
|
||||
patch_size = (1, 2, 2)
|
||||
in_channels = 16
|
||||
out_channels = 16
|
||||
elif args.video_size == 720:
|
||||
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
|
||||
sample_size = 22 # DC-AE-V: 32xp1 downsample factor
|
||||
patch_size = (1, 1, 1)
|
||||
in_channels = 32
|
||||
out_channels = 32
|
||||
else:
|
||||
raise ValueError(f"Video size {args.video_size} is not supported.")
|
||||
|
||||
if args.vae_type == "ltx2":
|
||||
sample_size = 22
|
||||
patch_size = (1, 1, 1)
|
||||
in_channels = 128
|
||||
out_channels = 128
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
@@ -182,8 +196,8 @@ def main(args):
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer_kwargs = {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"in_channels": in_channels,
|
||||
"out_channels": out_channels,
|
||||
"num_attention_heads": 20,
|
||||
"attention_head_dim": 112,
|
||||
"num_layers": 20,
|
||||
@@ -235,9 +249,12 @@ def main(args):
|
||||
else:
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
if args.vae_type == "ltx2":
|
||||
vae_path = args.vae_path or "Lightricks/LTX-2"
|
||||
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
|
||||
else:
|
||||
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
@@ -314,7 +331,23 @@ if __name__ == "__main__":
|
||||
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
|
||||
help="Scheduler type to use.",
|
||||
)
|
||||
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
|
||||
parser.add_argument(
|
||||
"--vae_type",
|
||||
default="wan",
|
||||
type=str,
|
||||
choices=["wan", "ltx2"],
|
||||
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
@@ -2156,9 +2156,6 @@ def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
# LTX-2.3
|
||||
"audio_prompt_adaln_single": "audio_prompt_adaln",
|
||||
"prompt_adaln_single": "prompt_adaln",
|
||||
}
|
||||
else:
|
||||
rename_dict = {"aggregate_embed": "text_proj_in"}
|
||||
|
||||
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_VARLEN_HUB = "flash_varlen_hub"
|
||||
FLASH_4_HUB = "flash_4_hub"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-staging/flash-attn4",
|
||||
function_attr="flash_attn_func",
|
||||
version=0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB,
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
@@ -2676,6 +2683,37 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_4_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
return (out[0], out[1]) if return_lse else out[0]
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -237,7 +237,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
|
||||
|
||||
|
||||
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
|
||||
class LTX2VideoDownsampler3d(nn.Module):
|
||||
class LTXVideoDownsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -285,11 +285,10 @@ class LTX2VideoDownsampler3d(nn.Module):
|
||||
|
||||
|
||||
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
|
||||
class LTX2VideoUpsampler3d(nn.Module):
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int | None = None,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
@@ -301,8 +300,7 @@ class LTX2VideoUpsampler3d(nn.Module):
|
||||
self.residual = residual
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
|
||||
self.conv = LTX2VideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -410,7 +408,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "spatial":
|
||||
self.downsamplers.append(
|
||||
LTX2VideoDownsampler3d(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(1, 2, 2),
|
||||
@@ -419,7 +417,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "temporal":
|
||||
self.downsamplers.append(
|
||||
LTX2VideoDownsampler3d(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 1, 1),
|
||||
@@ -428,7 +426,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "spatiotemporal":
|
||||
self.downsamplers.append(
|
||||
LTX2VideoDownsampler3d(
|
||||
LTXVideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 2, 2),
|
||||
@@ -582,7 +580,6 @@ class LTX2VideoUpBlock3d(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
upsample_type: str = "spatiotemporal",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool = False,
|
||||
@@ -612,23 +609,16 @@ class LTX2VideoUpBlock3d(nn.Module):
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList()
|
||||
|
||||
if upsample_type == "spatial":
|
||||
upsample_stride = (1, 2, 2)
|
||||
elif upsample_type == "temporal":
|
||||
upsample_stride = (2, 1, 1)
|
||||
elif upsample_type == "spatiotemporal":
|
||||
upsample_stride = (2, 2, 2)
|
||||
|
||||
self.upsamplers.append(
|
||||
LTX2VideoUpsampler3d(
|
||||
in_channels=out_channels * upscale_factor,
|
||||
stride=upsample_stride,
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoUpsampler3d(
|
||||
out_channels * upscale_factor,
|
||||
stride=(2, 2, 2),
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
resnets = []
|
||||
@@ -726,7 +716,7 @@ class LTX2VideoEncoder3d(nn.Module):
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
|
||||
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
|
||||
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
patch_size: int = 4,
|
||||
@@ -736,9 +726,6 @@ class LTX2VideoEncoder3d(nn.Module):
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
num_encoder_blocks = len(layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
@@ -873,27 +860,19 @@ class LTX2VideoDecoder3d(nn.Module):
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: tuple[int, ...] = (256, 512, 1024),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
|
||||
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
|
||||
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: bool | tuple[bool, ...] = (False, False, False),
|
||||
inject_noise: tuple[bool, ...] = (False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
|
||||
upsample_residual: tuple[bool, ...] = (True, True, True),
|
||||
upsample_factor: tuple[bool, ...] = (2, 2, 2),
|
||||
spatial_padding_mode: str = "reflect",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
num_decoder_blocks = len(layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
|
||||
if isinstance(inject_noise, bool):
|
||||
inject_noise = (inject_noise,) * num_decoder_blocks
|
||||
if isinstance(upsample_residual, bool):
|
||||
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
@@ -938,7 +917,6 @@ class LTX2VideoDecoder3d(nn.Module):
|
||||
num_layers=layers_per_block[i + 1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
upsample_type=upsample_type[i],
|
||||
inject_noise=inject_noise[i + 1],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
upsample_residual=upsample_residual[i],
|
||||
@@ -1080,12 +1058,11 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024),
|
||||
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
|
||||
decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
|
||||
decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
|
||||
decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
|
||||
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
|
||||
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
|
||||
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
|
||||
upsample_residual: tuple[bool, ...] = (True, True, True),
|
||||
upsample_factor: tuple[int, ...] = (2, 2, 2),
|
||||
timestep_conditioning: bool = False,
|
||||
patch_size: int = 4,
|
||||
@@ -1100,16 +1077,6 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
temporal_compression_ratio: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
num_encoder_blocks = len(layers_per_block)
|
||||
num_decoder_blocks = len(decoder_layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
|
||||
if isinstance(decoder_spatio_temporal_scaling, bool):
|
||||
decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1)
|
||||
if isinstance(decoder_inject_noise, bool):
|
||||
decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks
|
||||
if isinstance(upsample_residual, bool):
|
||||
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
|
||||
|
||||
self.encoder = LTX2VideoEncoder3d(
|
||||
in_channels=in_channels,
|
||||
@@ -1131,7 +1098,6 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
upsample_type=upsample_type,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
|
||||
@@ -178,10 +178,6 @@ class LTX2AudioVideoAttnProcessor:
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
@@ -216,112 +212,6 @@ class LTX2AudioVideoAttnProcessor:
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2PerturbedAttnProcessor:
|
||||
r"""
|
||||
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
raise ValueError(
|
||||
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "LTX2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
if all_perturbed is None:
|
||||
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
|
||||
|
||||
if all_perturbed:
|
||||
# Skip attention, use the value projection value
|
||||
hidden_states = value
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if query_rotary_emb is not None:
|
||||
if attn.rope_type == "interleaved":
|
||||
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_interleaved_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
elif attn.rope_type == "split":
|
||||
query = apply_split_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_split_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if perturbation_mask is not None:
|
||||
value = value.flatten(2, 3)
|
||||
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
@@ -334,7 +224,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -350,7 +240,6 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
norm_eps: float = 1e-6,
|
||||
norm_elementwise_affine: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
apply_gated_attention: bool = False,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -377,12 +266,6 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if apply_gated_attention:
|
||||
# Per head gate values
|
||||
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
@@ -438,10 +321,6 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
audio_num_attention_heads: int,
|
||||
audio_attention_head_dim,
|
||||
audio_cross_attention_dim: int,
|
||||
video_gated_attn: bool = False,
|
||||
video_cross_attn_adaln: bool = False,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_adaln: bool = False,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = True,
|
||||
@@ -449,16 +328,9 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
eps: float = 1e-6,
|
||||
elementwise_affine: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
perturbed_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.perturbed_attn = perturbed_attn
|
||||
if perturbed_attn:
|
||||
attn_processor_cls = LTX2PerturbedAttnProcessor
|
||||
else:
|
||||
attn_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
|
||||
# 1. Self-Attention (video and audio)
|
||||
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.attn1 = LTX2Attention(
|
||||
@@ -471,8 +343,6 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -486,8 +356,6 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 2. Prompt Cross-Attention
|
||||
@@ -502,8 +370,6 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -517,8 +383,6 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
@@ -534,8 +398,6 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
|
||||
@@ -550,8 +412,6 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 4. Feedforward layers
|
||||
@@ -562,36 +422,14 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
|
||||
|
||||
# 5. Per-Layer Modulation Parameters
|
||||
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
|
||||
# 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q
|
||||
self.video_cross_attn_adaln = video_cross_attn_adaln
|
||||
self.audio_cross_attn_adaln = audio_cross_attn_adaln
|
||||
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
|
||||
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
|
||||
|
||||
# Prompt cross-attn (attn2) additional modulation params
|
||||
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
|
||||
if self.cross_attn_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim))
|
||||
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
|
||||
|
||||
# Per-layer a2v, v2a Cross-Attention mod params
|
||||
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
|
||||
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
|
||||
|
||||
@staticmethod
|
||||
def get_mod_params(
|
||||
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.shape[1], num_ada_params, -1
|
||||
)
|
||||
ada_params = ada_values.unbind(dim=2)
|
||||
return ada_params
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -604,181 +442,143 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
temb_ca_audio_scale_shift: torch.Tensor,
|
||||
temb_ca_gate: torch.Tensor,
|
||||
temb_ca_audio_gate: torch.Tensor,
|
||||
temb_prompt: torch.Tensor | None = None,
|
||||
temb_prompt_audio: torch.Tensor | None = None,
|
||||
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
self_attention_mask: torch.Tensor | None = None,
|
||||
audio_self_attention_mask: torch.Tensor | None = None,
|
||||
a2v_cross_attention_mask: torch.Tensor | None = None,
|
||||
v2a_cross_attention_mask: torch.Tensor | None = None,
|
||||
use_a2v_cross_attention: bool = True,
|
||||
use_v2a_cross_attention: bool = True,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Video and Audio Self-Attention
|
||||
# 1.1. Video Self-Attention
|
||||
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
|
||||
if self.video_cross_attn_adaln:
|
||||
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
|
||||
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.size(1), num_ada_params, -1
|
||||
)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
video_self_attn_args = {
|
||||
"hidden_states": norm_hidden_states,
|
||||
"encoder_hidden_states": None,
|
||||
"query_rotary_emb": video_rotary_emb,
|
||||
"attention_mask": self_attention_mask,
|
||||
}
|
||||
if self.perturbed_attn:
|
||||
video_self_attn_args["perturbation_mask"] = perturbation_mask
|
||||
video_self_attn_args["all_perturbed"] = all_perturbed
|
||||
|
||||
attn_hidden_states = self.attn1(**video_self_attn_args)
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
query_rotary_emb=video_rotary_emb,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
# 1.2. Audio Self-Attention
|
||||
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
|
||||
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
|
||||
audio_ada_params[:6]
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
|
||||
|
||||
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
|
||||
|
||||
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
|
||||
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
|
||||
batch_size, temb_audio.size(1), num_audio_ada_params, -1
|
||||
)
|
||||
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
|
||||
audio_ada_values.unbind(dim=2)
|
||||
)
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
|
||||
|
||||
audio_self_attn_args = {
|
||||
"hidden_states": norm_audio_hidden_states,
|
||||
"encoder_hidden_states": None,
|
||||
"query_rotary_emb": audio_rotary_emb,
|
||||
"attention_mask": audio_self_attention_mask,
|
||||
}
|
||||
if self.perturbed_attn:
|
||||
audio_self_attn_args["perturbation_mask"] = perturbation_mask
|
||||
audio_self_attn_args["all_perturbed"] = all_perturbed
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args)
|
||||
attn_audio_hidden_states = self.audio_attn1(
|
||||
hidden_states=norm_audio_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
query_rotary_emb=audio_rotary_emb,
|
||||
)
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
|
||||
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
|
||||
if self.cross_attn_adaln:
|
||||
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
|
||||
shift_text_kv, scale_text_kv = video_prompt_ada_params
|
||||
|
||||
audio_prompt_ada_params = self.get_mod_params(
|
||||
self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size
|
||||
)
|
||||
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
|
||||
|
||||
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text)
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
if self.video_cross_attn_adaln:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.video_cross_attn_adaln:
|
||||
attn_hidden_states = attn_hidden_states * gate_text_q
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
# 2.2. Audio-Text Cross-Attention
|
||||
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
|
||||
if self.audio_cross_attn_adaln:
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn2(
|
||||
norm_audio_hidden_states,
|
||||
encoder_hidden_states=audio_encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=audio_encoder_attention_mask,
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
if use_a2v_cross_attention or use_v2a_cross_attention:
|
||||
norm_hidden_states = self.audio_to_video_norm(hidden_states)
|
||||
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
|
||||
norm_hidden_states = self.audio_to_video_norm(hidden_states)
|
||||
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
|
||||
|
||||
# 3.1. Combine global and per-layer cross attention modulation parameters
|
||||
# Video
|
||||
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
# Combine global and per-layer cross attention modulation parameters
|
||||
# Video
|
||||
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
|
||||
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
|
||||
video_ca_scale_shift_table = (
|
||||
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
|
||||
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
video_ca_gate = (
|
||||
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
|
||||
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
|
||||
a2v_gate = video_ca_gate_param[0].squeeze(2)
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
|
||||
a2v_gate = video_ca_gate[0].squeeze(2)
|
||||
|
||||
# Audio
|
||||
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
# Audio
|
||||
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
audio_ca_ada_params = self.get_mod_params(
|
||||
audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size
|
||||
)
|
||||
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
|
||||
audio_ca_scale_shift_table = (
|
||||
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
|
||||
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
audio_ca_gate = (
|
||||
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
|
||||
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
|
||||
v2a_gate = audio_ca_gate_param[0].squeeze(2)
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
|
||||
v2a_gate = audio_ca_gate[0].squeeze(2)
|
||||
|
||||
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
if use_a2v_cross_attention:
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_a2v_ca_scale.squeeze(2)
|
||||
) + video_a2v_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_a2v_ca_scale.squeeze(2)
|
||||
) + audio_a2v_ca_shift.squeeze(2)
|
||||
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_a2v_ca_scale.squeeze(2)
|
||||
) + audio_a2v_ca_shift.squeeze(2)
|
||||
|
||||
a2v_attn_hidden_states = self.audio_to_video_attn(
|
||||
mod_norm_hidden_states,
|
||||
encoder_hidden_states=mod_norm_audio_hidden_states,
|
||||
query_rotary_emb=ca_video_rotary_emb,
|
||||
key_rotary_emb=ca_audio_rotary_emb,
|
||||
attention_mask=a2v_cross_attention_mask,
|
||||
)
|
||||
a2v_attn_hidden_states = self.audio_to_video_attn(
|
||||
mod_norm_hidden_states,
|
||||
encoder_hidden_states=mod_norm_audio_hidden_states,
|
||||
query_rotary_emb=ca_video_rotary_emb,
|
||||
key_rotary_emb=ca_audio_rotary_emb,
|
||||
attention_mask=a2v_cross_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
|
||||
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
|
||||
|
||||
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
if use_v2a_cross_attention:
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_v2a_ca_scale.squeeze(2)
|
||||
) + video_v2a_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_v2a_ca_scale.squeeze(2)
|
||||
) + audio_v2a_ca_shift.squeeze(2)
|
||||
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_v2a_ca_scale.squeeze(2)
|
||||
) + audio_v2a_ca_shift.squeeze(2)
|
||||
|
||||
v2a_attn_hidden_states = self.video_to_audio_attn(
|
||||
mod_norm_audio_hidden_states,
|
||||
encoder_hidden_states=mod_norm_hidden_states,
|
||||
query_rotary_emb=ca_audio_rotary_emb,
|
||||
key_rotary_emb=ca_video_rotary_emb,
|
||||
attention_mask=v2a_cross_attention_mask,
|
||||
)
|
||||
v2a_attn_hidden_states = self.video_to_audio_attn(
|
||||
mod_norm_audio_hidden_states,
|
||||
encoder_hidden_states=mod_norm_hidden_states,
|
||||
query_rotary_emb=ca_audio_rotary_emb,
|
||||
key_rotary_emb=ca_video_rotary_emb,
|
||||
attention_mask=v2a_cross_attention_mask,
|
||||
)
|
||||
|
||||
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
|
||||
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
|
||||
|
||||
# 4. Feedforward
|
||||
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
@@ -1118,8 +918,6 @@ class LTX2VideoTransformer3DModel(
|
||||
pos_embed_max_pos: int = 20,
|
||||
base_height: int = 2048,
|
||||
base_width: int = 2048,
|
||||
gated_attn: bool = False,
|
||||
cross_attn_mod: bool = False,
|
||||
audio_in_channels: int = 128, # Audio Arguments
|
||||
audio_out_channels: int | None = 128,
|
||||
audio_patch_size: int = 1,
|
||||
@@ -1131,8 +929,6 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_pos_embed_max_pos: int = 20,
|
||||
audio_sampling_rate: int = 16000,
|
||||
audio_hop_length: int = 160,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_mod: bool = False,
|
||||
num_layers: int = 48, # Shared arguments
|
||||
activation_fn: str = "gelu-approximate",
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
@@ -1147,8 +943,6 @@ class LTX2VideoTransformer3DModel(
|
||||
timestep_scale_multiplier: int = 1000,
|
||||
cross_attn_timestep_scale_multiplier: int = 1000,
|
||||
rope_type: str = "interleaved",
|
||||
use_prompt_embeddings=True,
|
||||
perturbed_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -1162,25 +956,17 @@ class LTX2VideoTransformer3DModel(
|
||||
self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)
|
||||
|
||||
# 2. Prompt embeddings
|
||||
if use_prompt_embeddings:
|
||||
# LTX-2.0; LTX-2.3 uses per-modality feature projections in the connector instead
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=audio_inner_dim
|
||||
)
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=audio_inner_dim
|
||||
)
|
||||
|
||||
# 3. Timestep Modulation Params and Embedding
|
||||
self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod # used by LTX-2.3
|
||||
|
||||
# 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding
|
||||
# time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters
|
||||
video_time_emb_mod_params = 9 if cross_attn_mod else 6
|
||||
audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6
|
||||
self.time_embed = LTX2AdaLayerNormSingle(
|
||||
inner_dim, num_mod_params=video_time_emb_mod_params, use_additional_conditions=False
|
||||
)
|
||||
self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False)
|
||||
self.audio_time_embed = LTX2AdaLayerNormSingle(
|
||||
audio_inner_dim, num_mod_params=audio_time_emb_mod_params, use_additional_conditions=False
|
||||
audio_inner_dim, num_mod_params=6, use_additional_conditions=False
|
||||
)
|
||||
|
||||
# 3.2. Global Cross Attention Modulation Parameters
|
||||
@@ -1209,13 +995,6 @@ class LTX2VideoTransformer3DModel(
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5)
|
||||
|
||||
# 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3)
|
||||
if self.prompt_modulation:
|
||||
self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False)
|
||||
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
|
||||
audio_inner_dim, num_mod_params=2, use_additional_conditions=False
|
||||
)
|
||||
|
||||
# 4. Rotary Positional Embeddings (RoPE)
|
||||
# Self-Attention
|
||||
self.rope = LTX2AudioVideoRotaryPosEmbed(
|
||||
@@ -1292,10 +1071,6 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_attention_heads=audio_num_attention_heads,
|
||||
audio_attention_head_dim=audio_attention_head_dim,
|
||||
audio_cross_attention_dim=audio_cross_attention_dim,
|
||||
video_gated_attn=gated_attn,
|
||||
video_cross_attn_adaln=cross_attn_mod,
|
||||
audio_gated_attn=audio_gated_attn,
|
||||
audio_cross_attn_adaln=audio_cross_attn_mod,
|
||||
qk_norm=qk_norm,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
@@ -1303,7 +1078,6 @@ class LTX2VideoTransformer3DModel(
|
||||
eps=norm_eps,
|
||||
elementwise_affine=norm_elementwise_affine,
|
||||
rope_type=rope_type,
|
||||
perturbed_attn=perturbed_attn,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -1327,12 +1101,8 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
audio_timestep: torch.LongTensor | None = None,
|
||||
sigma: torch.Tensor | None = None,
|
||||
audio_sigma: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
self_attention_mask: torch.Tensor | None = None,
|
||||
audio_self_attention_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
@@ -1340,10 +1110,6 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_frames: int | None = None,
|
||||
video_coords: torch.Tensor | None = None,
|
||||
audio_coords: torch.Tensor | None = None,
|
||||
isolate_modalities: bool = False,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -1365,19 +1131,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_timestep (`torch.Tensor`, *optional*):
|
||||
Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
|
||||
params. This is only used by certain pipelines such as the I2V pipeline.
|
||||
sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape (batch_size,). Used for video prompt cross attention modulation in
|
||||
models such as LTX-2.3.
|
||||
audio_sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in
|
||||
models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to
|
||||
the provided `sigma` value.
|
||||
encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
|
||||
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
|
||||
self_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`.
|
||||
num_frames (`int`, *optional*):
|
||||
The number of latent video frames. Used if calculating the video coordinates for RoPE.
|
||||
height (`int`, *optional*):
|
||||
@@ -1395,21 +1152,6 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_coords (`torch.Tensor`, *optional*):
|
||||
The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
|
||||
`(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
|
||||
isolate_modalities (`bool`, *optional*, defaults to `False`):
|
||||
Whether to isolate each modality by turning off cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention (for all blocks). Use for modality guidance in LTX-2.3.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The transformer block indices at which to apply spatio-temporal guidance (STG), which shortcuts the
|
||||
self-attention operations by simply using the values rather than the full scaled dot-product attention
|
||||
(SDPA) operation. If `None` or empty, STG will not be applied to any block.
|
||||
perturbation_mask (`torch.Tensor`, *optional*):
|
||||
Perturbation mask for STG of shape `(batch_size,)` or `(batch_size, 1, 1)`. Should be 0 at batch
|
||||
elements where STG should be applied and 1 elsewhere. If STG is being used but `peturbation_mask` is
|
||||
not supplied, will default to applying STG (perturbing) all batch elements.
|
||||
use_cross_timestep (`bool` *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
attention_kwargs (`dict[str, Any]`, *optional*):
|
||||
Optional dict of keyword args to be passed to the attention processor.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -1423,7 +1165,6 @@ class LTX2VideoTransformer3DModel(
|
||||
"""
|
||||
# Determine timestep for audio.
|
||||
audio_timestep = audio_timestep if audio_timestep is not None else timestep
|
||||
audio_sigma = audio_sigma if audio_sigma is not None else sigma
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
@@ -1434,32 +1175,6 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
|
||||
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
if self_attention_mask is not None and self_attention_mask.ndim == 3:
|
||||
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
|
||||
# number and positive values are mapped to their logarithm.
|
||||
dtype_finfo = torch.finfo(hidden_states.dtype)
|
||||
additive_self_attn_mask = torch.full_like(self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype)
|
||||
unmasked_entries = self_attention_mask > 0
|
||||
if torch.any(unmasked_entries):
|
||||
additive_self_attn_mask[unmasked_entries] = torch.log(
|
||||
self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
|
||||
).to(hidden_states.dtype)
|
||||
self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
|
||||
|
||||
if audio_self_attention_mask is not None and audio_self_attention_mask.ndim == 3:
|
||||
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
|
||||
# number and positive values are mapped to their logarithm.
|
||||
dtype_finfo = torch.finfo(audio_hidden_states.dtype)
|
||||
additive_self_attn_mask = torch.full_like(
|
||||
audio_self_attention_mask, dtype_finfo.min, dtype=audio_hidden_states.dtype
|
||||
)
|
||||
unmasked_entries = audio_self_attention_mask > 0
|
||||
if torch.any(unmasked_entries):
|
||||
additive_self_attn_mask[unmasked_entries] = torch.log(
|
||||
audio_self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
|
||||
).to(audio_hidden_states.dtype)
|
||||
audio_self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Prepare RoPE positional embeddings
|
||||
@@ -1508,28 +1223,14 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
|
||||
audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
|
||||
|
||||
if self.prompt_modulation:
|
||||
# LTX-2.3
|
||||
temb_prompt, _ = self.prompt_adaln(
|
||||
sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
temb_prompt_audio, _ = self.audio_prompt_adaln(
|
||||
audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype
|
||||
)
|
||||
temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1))
|
||||
temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1))
|
||||
else:
|
||||
temb_prompt = temb_prompt_audio = None
|
||||
|
||||
# 3.2. Prepare global modality cross attention modulation parameters
|
||||
video_ca_timestep = audio_sigma.flatten() if use_cross_timestep else timestep.flatten()
|
||||
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
|
||||
video_ca_timestep,
|
||||
timestep.flatten(),
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
|
||||
video_ca_timestep * timestep_cross_attn_gate_scale_factor,
|
||||
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
@@ -1538,14 +1239,13 @@ class LTX2VideoTransformer3DModel(
|
||||
)
|
||||
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
|
||||
|
||||
audio_ca_timestep = sigma.flatten() if use_cross_timestep else audio_timestep.flatten()
|
||||
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
|
||||
audio_ca_timestep,
|
||||
audio_timestep.flatten(),
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
|
||||
audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
|
||||
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
@@ -1554,30 +1254,15 @@ class LTX2VideoTransformer3DModel(
|
||||
)
|
||||
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
|
||||
|
||||
# 4. Prepare prompt embeddings (LTX-2.0)
|
||||
if self.config.use_prompt_embeddings:
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
# 4. Prepare prompt embeddings
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
|
||||
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states.view(
|
||||
batch_size, -1, audio_hidden_states.size(-1)
|
||||
)
|
||||
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
|
||||
|
||||
# 5. Run transformer blocks
|
||||
spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or []
|
||||
if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None:
|
||||
# If STG is being used and perturbation_mask is not supplied, default to perturbing all batch elements.
|
||||
perturbation_mask = torch.zeros((batch_size,))
|
||||
if perturbation_mask is not None and perturbation_mask.ndim == 1:
|
||||
perturbation_mask = perturbation_mask[:, None, None] # unsqueeze to 3D to broadcast with hidden_states
|
||||
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
|
||||
stg_blocks = set(spatio_temporal_guidance_blocks)
|
||||
|
||||
for block_idx, block in enumerate(self.transformer_blocks):
|
||||
block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None
|
||||
block_all_perturbed = all_perturbed if block_idx in stg_blocks else False
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -1591,22 +1276,12 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_cross_attn_scale_shift,
|
||||
video_cross_attn_a2v_gate,
|
||||
audio_cross_attn_v2a_gate,
|
||||
temb_prompt,
|
||||
temb_prompt_audio,
|
||||
video_rotary_emb,
|
||||
audio_rotary_emb,
|
||||
video_cross_attn_rotary_emb,
|
||||
audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
audio_encoder_attention_mask,
|
||||
self_attention_mask,
|
||||
audio_self_attention_mask,
|
||||
None, # a2v_cross_attention_mask
|
||||
None, # v2a_cross_attention_mask
|
||||
not isolate_modalities, # use_a2v_cross_attention
|
||||
not isolate_modalities, # use_v2a_cross_attention
|
||||
block_perturbation_mask,
|
||||
block_all_perturbed,
|
||||
)
|
||||
else:
|
||||
hidden_states, audio_hidden_states = block(
|
||||
@@ -1620,22 +1295,12 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
|
||||
temb_ca_gate=video_cross_attn_a2v_gate,
|
||||
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
|
||||
temb_prompt=temb_prompt,
|
||||
temb_prompt_audio=temb_prompt_audio,
|
||||
video_rotary_emb=video_rotary_emb,
|
||||
audio_rotary_emb=audio_rotary_emb,
|
||||
ca_video_rotary_emb=video_cross_attn_rotary_emb,
|
||||
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
audio_encoder_attention_mask=audio_encoder_attention_mask,
|
||||
self_attention_mask=self_attention_mask,
|
||||
audio_self_attention_mask=audio_self_attention_mask,
|
||||
a2v_cross_attention_mask=None,
|
||||
v2a_cross_attention_mask=None,
|
||||
use_a2v_cross_attention=not isolate_modalities,
|
||||
use_v2a_cross_attention=not isolate_modalities,
|
||||
perturbation_mask=block_perturbation_mask,
|
||||
all_perturbed=block_all_perturbed,
|
||||
)
|
||||
|
||||
# 6. Output layers (including unpatchification)
|
||||
|
||||
@@ -309,16 +309,16 @@ class ComponentSpec:
|
||||
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
|
||||
)
|
||||
|
||||
from diffusers import AutoModel
|
||||
|
||||
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
|
||||
# As a result, it gets stored in `init_kwargs`, which are written to the config
|
||||
# during save. This causes JSON serialization to fail when saving the component.
|
||||
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
|
||||
if self.type_hint is not None and not issubclass(self.type_hint, (torch.nn.Module, AutoModel)):
|
||||
kwargs.pop("torch_dtype", None)
|
||||
|
||||
if self.type_hint is None:
|
||||
try:
|
||||
from diffusers import AutoModel
|
||||
|
||||
component = AutoModel.from_pretrained(pretrained_model_name_or_path, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Unable to load {self.name} without `type_hint`: {e}")
|
||||
@@ -332,12 +332,6 @@ class ComponentSpec:
|
||||
else getattr(self.type_hint, "from_pretrained")
|
||||
)
|
||||
|
||||
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
|
||||
# As a result, it gets stored in `init_kwargs`, which are written to the config
|
||||
# during save. This causes JSON serialization to fail when saving the component.
|
||||
if not issubclass(self.type_hint, torch.nn.Module):
|
||||
kwargs.pop("torch_dtype", None)
|
||||
|
||||
try:
|
||||
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
|
||||
@@ -28,7 +28,7 @@ else:
|
||||
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
|
||||
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -44,7 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_ltx2_condition import LTX2ConditionPipeline
|
||||
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
|
||||
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -11,79 +9,6 @@ from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
def per_layer_masked_mean_norm(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
"""
|
||||
Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states.
|
||||
Respects the padding of the hidden states.
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
|
||||
def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
|
||||
variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True)
|
||||
norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps)
|
||||
return norm_text_encoder_hidden_states
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
@@ -181,7 +106,6 @@ class LTX2TransformerBlock1d(nn.Module):
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
rope_type: str = "interleaved",
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -191,9 +115,8 @@ class LTX2TransformerBlock1d(nn.Module):
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
@@ -237,7 +160,6 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -266,7 +188,6 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=gated_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -339,36 +260,24 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size
|
||||
text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B
|
||||
video_connector_num_attention_heads: int = 30,
|
||||
video_connector_attention_head_dim: int = 128,
|
||||
video_connector_num_layers: int = 2,
|
||||
video_connector_num_learnable_registers: int | None = 128,
|
||||
video_gated_attn: bool = False,
|
||||
audio_connector_num_attention_heads: int = 30,
|
||||
audio_connector_attention_head_dim: int = 128,
|
||||
audio_connector_num_layers: int = 2,
|
||||
audio_connector_num_learnable_registers: int | None = 128,
|
||||
audio_gated_attn: bool = False,
|
||||
connector_rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
causal_temporal_positioning: bool = False,
|
||||
caption_channels: int,
|
||||
text_proj_in_factor: int,
|
||||
video_connector_num_attention_heads: int,
|
||||
video_connector_attention_head_dim: int,
|
||||
video_connector_num_layers: int,
|
||||
video_connector_num_learnable_registers: int | None,
|
||||
audio_connector_num_attention_heads: int,
|
||||
audio_connector_attention_head_dim: int,
|
||||
audio_connector_num_layers: int,
|
||||
audio_connector_num_learnable_registers: int | None,
|
||||
connector_rope_base_seq_len: int,
|
||||
rope_theta: float,
|
||||
rope_double_precision: bool,
|
||||
causal_temporal_positioning: bool,
|
||||
rope_type: str = "interleaved",
|
||||
per_modality_projections: bool = False,
|
||||
video_hidden_dim: int = 4096,
|
||||
audio_hidden_dim: int = 2048,
|
||||
proj_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
text_encoder_dim = caption_channels * text_proj_in_factor
|
||||
if per_modality_projections:
|
||||
self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias)
|
||||
self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias)
|
||||
else:
|
||||
self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias)
|
||||
|
||||
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
@@ -379,7 +288,6 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
gated_attention=video_gated_attn,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
@@ -391,86 +299,26 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
gated_attention=audio_gated_attn,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
text_encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text
|
||||
embeddings for the LTX-2.X DiT models.
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
|
||||
):
|
||||
# Convert to additive attention mask, if necessary
|
||||
if not additive_mask:
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
|
||||
Args:
|
||||
text_encoder_hidden_states (`torch.Tensor`)):
|
||||
Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len,
|
||||
caption_channels, text_proj_in_factor) or 3D with the last two dimensions flattened.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
||||
Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked
|
||||
positions.
|
||||
padding_side (`str`, *optional*, defaults to `"left"`):
|
||||
The padding side used by the text encoder's text encoder (either `"left"` or `"right"`). Defaults to
|
||||
`"left"` as this is what the default Gemma3-12B text encoder uses. Only used if
|
||||
`per_modality_projections` is `False` (LTX-2.0 models).
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scale factor for masked mean/range normalization. Only used if `per_modality_projections` is `False`
|
||||
(LTX-2.0 models).
|
||||
"""
|
||||
if text_encoder_hidden_states.ndim == 3:
|
||||
# Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor]
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1))
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
|
||||
if self.config.per_modality_projections:
|
||||
# LTX-2.3
|
||||
norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states)
|
||||
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3)
|
||||
bool_mask = attention_mask.bool().unsqueeze(-1)
|
||||
norm_text_encoder_hidden_states = torch.where(
|
||||
bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states)
|
||||
)
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
|
||||
# Rescale norms with respect to video and audio dims for feature extractors
|
||||
video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels)
|
||||
video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
|
||||
audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels)
|
||||
audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
# Per-Modality Feature extractors
|
||||
video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
|
||||
audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
|
||||
else:
|
||||
# LTX-2.0
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
norm_text_encoder_hidden_states = per_layer_masked_mean_norm(
|
||||
text_hidden_states=text_encoder_hidden_states,
|
||||
sequence_lengths=sequence_lengths,
|
||||
device=text_encoder_hidden_states.device,
|
||||
padding_side=padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
|
||||
text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states)
|
||||
video_text_emb_proj = text_emb_proj
|
||||
audio_text_emb_proj = text_emb_proj
|
||||
|
||||
# Convert to additive attention mask for connectors
|
||||
text_dtype = video_text_emb_proj.dtype
|
||||
attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype)
|
||||
attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
add_attn_mask = attention_mask * torch.finfo(text_dtype).max
|
||||
|
||||
video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask)
|
||||
|
||||
# Convert video attn mask to binary (multiplicative) mask and mask video text embedding
|
||||
binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64)
|
||||
binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * binary_attn_mask
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1)
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
|
||||
@@ -195,8 +195,7 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
rational_spatial_scale: float = 2.0,
|
||||
use_rational_resampler: bool = True,
|
||||
rational_spatial_scale: float | None = 2.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -221,7 +220,7 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if use_rational_resampler:
|
||||
if rational_spatial_scale is not None:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
|
||||
@@ -31,7 +31,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -209,7 +209,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
||||
_optional_components = ["processor"]
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -221,8 +221,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
processor: Gemma3Processor | None = None,
|
||||
vocoder: LTX2Vocoder,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -235,7 +234,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
scheduler=scheduler,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
@@ -270,6 +268,73 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
@@ -322,7 +387,16 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -420,46 +494,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
max_new_tokens: int = 512,
|
||||
seed: int = 10,
|
||||
generation_kwargs: dict[str, Any] | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
):
|
||||
"""
|
||||
Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a
|
||||
`transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
if generation_kwargs is None:
|
||||
# Set to default generation kwargs
|
||||
generation_kwargs = {"do_sample": True, "temperature": 0.7}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"user prompt: {prompt}"},
|
||||
]
|
||||
template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
model_inputs = self.processor(text=template, images=None, return_tensors="pt").to(device)
|
||||
self.text_encoder.to(device)
|
||||
|
||||
# `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
|
||||
# so manually apply a seed for reproducible generation.
|
||||
torch.manual_seed(seed)
|
||||
generated_sequences = self.text_encoder.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**generation_kwargs,
|
||||
) # tensor of shape [batch_size, seq_len]
|
||||
|
||||
generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
|
||||
enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
return enhanced_prompt
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -470,9 +504,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -516,12 +547,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
|
||||
raise ValueError(
|
||||
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
|
||||
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
@@ -732,41 +757,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def stg_scale(self):
|
||||
return self._stg_scale
|
||||
|
||||
@property
|
||||
def modality_scale(self):
|
||||
return self._modality_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_scale(self):
|
||||
return self._audio_guidance_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_rescale(self):
|
||||
return self._audio_guidance_rescale
|
||||
|
||||
@property
|
||||
def audio_stg_scale(self):
|
||||
return self._audio_stg_scale
|
||||
|
||||
@property
|
||||
def audio_modality_scale(self):
|
||||
return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
|
||||
def do_modality_isolation_guidance(self):
|
||||
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -798,14 +791,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -817,11 +803,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
system_prompt: str | None = None,
|
||||
prompt_max_new_tokens: int = 512,
|
||||
prompt_enhancement_kwargs: dict[str, Any] | None = None,
|
||||
prompt_enhancement_seed: int = 10,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -860,47 +841,13 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional* defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
using zero terminal SNR.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continue denoising.
|
||||
@@ -931,24 +878,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool` *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
system_prompt (`str`, *optional*, defaults to `None`):
|
||||
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
|
||||
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
|
||||
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
|
||||
performed.
|
||||
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
|
||||
The maximum number of new tokens to generate when performing prompt enhancement.
|
||||
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
|
||||
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
|
||||
`do_sample=True` and `temperature=0.7` will be used. See
|
||||
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
for more details.
|
||||
prompt_enhancement_seed (`int`, *optional*, default to `10`):
|
||||
Random seed for any random operations during prompt enhancement.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -981,11 +910,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -996,21 +920,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -1026,16 +939,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
if system_prompt is not None and prompt is not None:
|
||||
prompt = self.enhance_prompt(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_new_tokens=prompt_max_new_tokens,
|
||||
seed=prompt_enhancement_seed,
|
||||
generation_kwargs=prompt_enhancement_kwargs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -1057,11 +960,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1168,6 +1069,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1205,11 +1111,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
encoder_hidden_states=connector_prompt_embeds,
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
@@ -1217,10 +1120,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1228,148 +1128,24 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
)
|
||||
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
video_guided_x0 = latents - noise_pred_video_g * self.scheduler.sigmas[i]
|
||||
video_cond_x0 = latents - noise_pred_video * self.scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
video_guided_x0 = rescale_noise_cfg(
|
||||
video_guided_x0, video_cond_x0, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_video = (latents - video_guided_x0) / self.scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
audio_guided_x0 = audio_latents - noise_pred_audio_g * audio_scheduler.sigmas[i]
|
||||
audio_cond_x0 = audio_latents - noise_pred_audio * audio_scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
audio_guided_x0 = rescale_noise_cfg(
|
||||
audio_guided_x0, audio_cond_x0, guidance_rescale=self.audio_guidance_rescale
|
||||
)
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_audio = (audio_latents - audio_guided_x0) / audio_scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -254,7 +254,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
vocoder: LTX2Vocoder,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -300,6 +300,74 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
@@ -353,7 +421,16 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -464,9 +541,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask=None,
|
||||
latents=None,
|
||||
audio_latents=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -523,12 +597,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
f" using the `_unpack_audio_latents` method)."
|
||||
)
|
||||
|
||||
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
|
||||
raise ValueError(
|
||||
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
|
||||
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
@@ -924,41 +992,9 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def stg_scale(self):
|
||||
return self._stg_scale
|
||||
|
||||
@property
|
||||
def modality_scale(self):
|
||||
return self._modality_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_scale(self):
|
||||
return self._audio_guidance_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_rescale(self):
|
||||
return self._audio_guidance_rescale
|
||||
|
||||
@property
|
||||
def audio_stg_scale(self):
|
||||
return self._audio_stg_scale
|
||||
|
||||
@property
|
||||
def audio_modality_scale(self):
|
||||
return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
|
||||
def do_modality_isolation_guidance(self):
|
||||
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -991,14 +1027,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[float] | None = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float | None = None,
|
||||
num_videos_per_prompt: int | None = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -1010,7 +1039,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -1051,47 +1079,13 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional* defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
using zero terminal SNR.
|
||||
noise_scale (`float`, *optional*, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continue denoising. If not set, will be inferred from the
|
||||
@@ -1123,10 +1117,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool` *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -1159,11 +1149,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -1176,21 +1161,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
latents=latents,
|
||||
audio_latents=audio_latents,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -1234,11 +1208,9 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1329,6 +1301,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1367,11 +1344,8 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
@@ -1379,10 +1353,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1390,151 +1361,24 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
)
|
||||
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
video_timestep = video_timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
video_guided_x0 = latents - noise_pred_video_g * self.scheduler.sigmas[i]
|
||||
video_cond_x0 = latents - noise_pred_video * self.scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
video_guided_x0 = rescale_noise_cfg(
|
||||
video_guided_x0, video_cond_x0, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_video = (latents - video_guided_x0) / self.scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
audio_guided_x0 = audio_latents - noise_pred_audio_g * audio_scheduler.sigmas[i]
|
||||
audio_cond_x0 = audio_latents - noise_pred_audio * audio_scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
audio_guided_x0 = rescale_noise_cfg(
|
||||
audio_guided_x0, audio_cond_x0, guidance_rescale=self.audio_guidance_rescale
|
||||
)
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_audio = (audio_latents - audio_guided_x0) / audio_scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG
|
||||
bsz = noise_pred_video.size(0)
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
@@ -32,7 +32,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -212,7 +212,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
||||
_optional_components = ["processor"]
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -224,8 +224,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
processor: Gemma3Processor | None = None,
|
||||
vocoder: LTX2Vocoder,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -238,7 +237,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
scheduler=scheduler,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
@@ -273,6 +271,74 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
@@ -326,7 +392,16 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -425,53 +500,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance_prompt(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
max_new_tokens: int = 512,
|
||||
seed: int = 10,
|
||||
generation_kwargs: dict[str, Any] | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
):
|
||||
"""
|
||||
Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a
|
||||
`transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
if generation_kwargs is None:
|
||||
# Set to default generation kwargs
|
||||
generation_kwargs = {"do_sample": True, "temperature": 0.7}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
|
||||
],
|
||||
},
|
||||
]
|
||||
template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
model_inputs = self.processor(text=template, images=image, return_tensors="pt").to(device)
|
||||
self.text_encoder.to(device)
|
||||
|
||||
# `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
|
||||
# so manually apply a seed for reproducible generation.
|
||||
torch.manual_seed(seed)
|
||||
generated_sequences = self.text_encoder.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**generation_kwargs,
|
||||
) # tensor of shape [batch_size, seq_len]
|
||||
|
||||
generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
|
||||
enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
return enhanced_prompt
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
@@ -483,9 +511,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -529,12 +554,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
|
||||
raise ValueError(
|
||||
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
|
||||
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
@@ -792,41 +811,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def stg_scale(self):
|
||||
return self._stg_scale
|
||||
|
||||
@property
|
||||
def modality_scale(self):
|
||||
return self._modality_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_scale(self):
|
||||
return self._audio_guidance_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_rescale(self):
|
||||
return self._audio_guidance_rescale
|
||||
|
||||
@property
|
||||
def audio_stg_scale(self):
|
||||
return self._audio_stg_scale
|
||||
|
||||
@property
|
||||
def audio_modality_scale(self):
|
||||
return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
|
||||
def do_modality_isolation_guidance(self):
|
||||
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -859,14 +846,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -878,11 +858,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
system_prompt: str | None = None,
|
||||
prompt_max_new_tokens: int = 512,
|
||||
prompt_enhancement_kwargs: dict[str, Any] | None = None,
|
||||
prompt_enhancement_seed: int = 10,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -923,47 +898,13 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional* defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
using zero terminal SNR.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continue denoising.
|
||||
@@ -994,24 +935,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool` *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
system_prompt (`str`, *optional*, defaults to `None`):
|
||||
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
|
||||
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
|
||||
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
|
||||
performed.
|
||||
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
|
||||
The maximum number of new tokens to generate when performing prompt enhancement.
|
||||
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
|
||||
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
|
||||
`do_sample=True` and `temperature=0.7` will be used. See
|
||||
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
for more details.
|
||||
prompt_enhancement_seed (`int`, *optional*, default to `10`):
|
||||
Random seed for any random operations during prompt enhancement.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -1044,11 +967,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -1059,21 +977,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -1089,17 +996,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
if system_prompt is not None and prompt is not None:
|
||||
prompt = self.enhance_prompt(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_new_tokens=prompt_max_new_tokens,
|
||||
seed=prompt_enhancement_seed,
|
||||
generation_kwargs=prompt_enhancement_kwargs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -1121,11 +1017,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1240,6 +1134,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1278,11 +1177,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
@@ -1290,10 +1186,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1301,151 +1194,24 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
)
|
||||
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
video_timestep = video_timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
video_guided_x0 = latents - noise_pred_video_g * self.scheduler.sigmas[i]
|
||||
video_cond_x0 = latents - noise_pred_video * self.scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
video_guided_x0 = rescale_noise_cfg(
|
||||
video_guided_x0, video_cond_x0, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_video = (latents - video_guided_x0) / self.scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
audio_guided_x0 = audio_latents - noise_pred_audio_g * audio_scheduler.sigmas[i]
|
||||
audio_cond_x0 = audio_latents - noise_pred_audio * audio_scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
audio_guided_x0 = rescale_noise_cfg(
|
||||
audio_guided_x0, audio_cond_x0, guidance_rescale=self.audio_guidance_rescale
|
||||
)
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_audio = (audio_latents - audio_guided_x0) / audio_scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred_video = self._unpack_latents(
|
||||
|
||||
@@ -8,209 +8,6 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Creates a Kaiser sinc kernel for low-pass filtering.
|
||||
|
||||
Args:
|
||||
cutoff (`float`):
|
||||
Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist
|
||||
frequency).
|
||||
half_width (`float`):
|
||||
Used to determine the Kaiser window's beta parameter.
|
||||
kernel_size:
|
||||
Size of the Kaiser window (and ultimately the Kaiser sinc kernel).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(kernel_size,)`:
|
||||
The Kaiser sinc kernel.
|
||||
"""
|
||||
delta_f = 4 * half_width
|
||||
half_size = kernel_size // 2
|
||||
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if amplitude > 50.0:
|
||||
beta = 0.1102 * (amplitude - 8.7)
|
||||
elif amplitude >= 21.0:
|
||||
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
even = kernel_size % 2 == 0
|
||||
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
|
||||
|
||||
if cutoff == 0.0:
|
||||
filter = torch.zeros_like(time)
|
||||
else:
|
||||
time = 2 * cutoff * time
|
||||
sinc = torch.where(
|
||||
time == 0,
|
||||
torch.ones_like(time),
|
||||
torch.sin(math.pi * time) / math.pi / time,
|
||||
)
|
||||
filter = 2 * cutoff * window * sinc
|
||||
filter = filter / filter.sum()
|
||||
return filter
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
"""1D low-pass filter for antialias downsampling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
use_padding: bool = True,
|
||||
padding_mode: str = "replicate",
|
||||
persistent: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = kernel_size or int(6 * ratio // 2) * 2
|
||||
self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1
|
||||
self.pad_right = self.kernel_size // 2
|
||||
self.use_padding = use_padding
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
cutoff = 0.5 / ratio
|
||||
half_width = 0.6 / ratio
|
||||
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
|
||||
self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x expected shape: [batch_size, num_channels, hidden_dim]
|
||||
num_channels = x.shape[1]
|
||||
if self.use_padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels)
|
||||
return x_filtered
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
window_type: str = "kaiser",
|
||||
padding_mode: str = "replicate",
|
||||
persistent: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
if window_type == "hann":
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
|
||||
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser sinc filter is BigVGAN default
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2
|
||||
self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2
|
||||
|
||||
sinc_filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x expected shape: [batch_size, num_channels, hidden_dim]
|
||||
num_channels = x.shape[1]
|
||||
x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode)
|
||||
low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1)
|
||||
x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels)
|
||||
return x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
|
||||
class AntiAliasAct1d(nn.Module):
|
||||
"""
|
||||
Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples
|
||||
to avoid aliasing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
act_fn: str | nn.Module,
|
||||
ratio: int = 2,
|
||||
kernel_size: int = 12,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size)
|
||||
if isinstance(act_fn, str):
|
||||
if act_fn == "snakebeta":
|
||||
act_fn = SnakeBeta(**kwargs)
|
||||
elif act_fn == "snake":
|
||||
act_fn = SnakeBeta(**kwargs)
|
||||
else:
|
||||
act_fn = nn.LeakyReLU(**kwargs)
|
||||
self.act = act_fn
|
||||
self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.upsample(x)
|
||||
x = self.act(x)
|
||||
x = self.downsample(x)
|
||||
return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
|
||||
"""
|
||||
Implements the Snake and SnakeBeta activations, which help with learning periodic patterns.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
alpha: float = 1.0,
|
||||
eps: float = 1e-9,
|
||||
trainable_params: bool = True,
|
||||
logscale: bool = True,
|
||||
use_beta: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.logscale = logscale
|
||||
self.use_beta = use_beta
|
||||
|
||||
self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
|
||||
self.alpha.requires_grad = trainable_params
|
||||
if use_beta:
|
||||
self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
|
||||
self.beta.requires_grad = trainable_params
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
|
||||
broadcast_shape = [1] * hidden_states.ndim
|
||||
broadcast_shape[channel_dim] = -1
|
||||
alpha = self.alpha.view(broadcast_shape)
|
||||
if self.use_beta:
|
||||
beta = self.beta.view(broadcast_shape)
|
||||
|
||||
if self.logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
if self.use_beta:
|
||||
beta = torch.exp(beta)
|
||||
|
||||
amplitude = beta if self.use_beta else alpha
|
||||
hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -218,15 +15,12 @@ class ResBlock(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilations: tuple[int, ...] = (1, 3, 5),
|
||||
act_fn: str = "leaky_relu",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = False,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
padding_mode: str = "same",
|
||||
):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
@@ -234,18 +28,6 @@ class ResBlock(nn.Module):
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
self.acts1 = nn.ModuleList()
|
||||
for _ in range(len(self.convs1)):
|
||||
if act_fn == "snakebeta":
|
||||
act = SnakeBeta(channels, use_beta=True)
|
||||
elif act_fn == "snake":
|
||||
act = SnakeBeta(channels, use_beta=False)
|
||||
else:
|
||||
act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
|
||||
|
||||
if antialias:
|
||||
act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
|
||||
self.acts1.append(act)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
@@ -253,24 +35,12 @@ class ResBlock(nn.Module):
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
self.acts2 = nn.ModuleList()
|
||||
for _ in range(len(self.convs2)):
|
||||
if act_fn == "snakebeta":
|
||||
act = SnakeBeta(channels, use_beta=True)
|
||||
elif act_fn == "snake":
|
||||
act = SnakeBeta(channels, use_beta=False)
|
||||
else:
|
||||
act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
|
||||
|
||||
if antialias:
|
||||
act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
|
||||
self.acts2.append(act)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2):
|
||||
xt = act1(x)
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
xt = conv1(xt)
|
||||
xt = act2(xt)
|
||||
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
|
||||
xt = conv2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
@@ -291,13 +61,7 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
upsample_factors: list[int] = [6, 5, 2, 2, 2],
|
||||
resnet_kernel_sizes: list[int] = [3, 7, 11],
|
||||
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
act_fn: str = "leaky_relu",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = False,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
final_act_fn: str | None = "tanh", # tanh, clamp, None
|
||||
final_bias: bool = True,
|
||||
output_sampling_rate: int = 24000,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -305,9 +69,7 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
self.resnets_per_upsample = len(resnet_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.total_upsample_factor = math.prod(upsample_factors)
|
||||
self.act_fn = act_fn
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
self.final_act_fn = final_act_fn
|
||||
|
||||
if self.num_upsample_layers != len(upsample_factors):
|
||||
raise ValueError(
|
||||
@@ -321,13 +83,6 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
|
||||
)
|
||||
|
||||
supported_act_fns = ["snakebeta", "snake", "leaky_relu"]
|
||||
if self.act_fn not in supported_act_fns:
|
||||
raise ValueError(
|
||||
f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are "
|
||||
f"{supported_act_fns}."
|
||||
)
|
||||
|
||||
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsamplers = nn.ModuleList()
|
||||
@@ -348,27 +103,15 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
|
||||
self.resnets.append(
|
||||
ResBlock(
|
||||
channels=output_channels,
|
||||
kernel_size=kernel_size,
|
||||
output_channels,
|
||||
kernel_size,
|
||||
dilations=dilations,
|
||||
act_fn=act_fn,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
antialias=antialias,
|
||||
antialias_ratio=antialias_ratio,
|
||||
antialias_kernel_size=antialias_kernel_size,
|
||||
)
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
if act_fn == "snakebeta" or act_fn == "snake":
|
||||
# Always use antialiasing
|
||||
act_out = SnakeBeta(channels=output_channels, use_beta=True)
|
||||
self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
|
||||
elif act_fn == "leaky_relu":
|
||||
# NOTE: does NOT use self.negative_slope, following the original code
|
||||
self.act_out = nn.LeakyReLU()
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias)
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -396,9 +139,7 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for i in range(self.num_upsample_layers):
|
||||
if self.act_fn == "leaky_relu":
|
||||
# Other activations are inside each upsampling block
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = self.upsamplers[i](hidden_states)
|
||||
|
||||
# Run all resnets in parallel on hidden_states
|
||||
@@ -408,190 +149,10 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_states = torch.mean(resnet_outputs, dim=0)
|
||||
|
||||
hidden_states = self.act_out(hidden_states)
|
||||
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
|
||||
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
if self.final_act_fn == "tanh":
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
elif self.final_act_fn == "clamp":
|
||||
hidden_states = torch.clamp(hidden_states, -1, 1)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CausalSTFT(nn.Module):
|
||||
"""
|
||||
Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases
|
||||
multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact
|
||||
buffers should be loaded from the checkpoint in bfloat16.
|
||||
"""
|
||||
|
||||
def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512):
|
||||
super().__init__()
|
||||
self.hop_length = hop_length
|
||||
self.window_length = window_length
|
||||
n_freqs = filter_length // 2 + 1
|
||||
|
||||
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)
|
||||
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)
|
||||
|
||||
def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if waveform.ndim == 2:
|
||||
waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples]
|
||||
|
||||
left_pad = max(0, self.window_length - self.hop_length) # causal: left-only
|
||||
waveform = F.pad(waveform, (left_pad, 0))
|
||||
|
||||
spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0)
|
||||
n_freqs = spec.shape[1] // 2
|
||||
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
|
||||
magnitude = torch.sqrt(real**2 + imag**2)
|
||||
phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype)
|
||||
return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
|
||||
"""
|
||||
Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be
|
||||
loaded from the checkpoint in bfloat16.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filter_length: int = 512,
|
||||
hop_length: int = 80,
|
||||
window_length: int = 512,
|
||||
num_mel_channels: int = 64,
|
||||
):
|
||||
super().__init__()
|
||||
self.stft_fn = CausalSTFT(filter_length, hop_length, window_length)
|
||||
|
||||
num_freqs = filter_length // 2 + 1
|
||||
self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True)
|
||||
|
||||
def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
magnitude, phase = self.stft_fn(waveform)
|
||||
energy = torch.norm(magnitude, dim=1)
|
||||
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
|
||||
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
||||
return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class LTX2VocoderWithBWE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the
|
||||
BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same
|
||||
architecture as the original vocoder.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
hidden_channels: int = 1536,
|
||||
out_channels: int = 2,
|
||||
upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4],
|
||||
upsample_factors: list[int] = [5, 2, 2, 2, 2, 2],
|
||||
resnet_kernel_sizes: list[int] = [3, 7, 11],
|
||||
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
act_fn: str = "snakebeta",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = True,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
final_act_fn: str | None = None,
|
||||
final_bias: bool = False,
|
||||
bwe_in_channels: int = 128,
|
||||
bwe_hidden_channels: int = 512,
|
||||
bwe_out_channels: int = 2,
|
||||
bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4],
|
||||
bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2],
|
||||
bwe_resnet_kernel_sizes: list[int] = [3, 7, 11],
|
||||
bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
bwe_act_fn: str = "snakebeta",
|
||||
bwe_leaky_relu_negative_slope: float = 0.1,
|
||||
bwe_antialias: bool = True,
|
||||
bwe_antialias_ratio: int = 2,
|
||||
bwe_antialias_kernel_size: int = 12,
|
||||
bwe_final_act_fn: str | None = None,
|
||||
bwe_final_bias: bool = False,
|
||||
filter_length: int = 512,
|
||||
hop_length: int = 80,
|
||||
window_length: int = 512,
|
||||
num_mel_channels: int = 64,
|
||||
input_sampling_rate: int = 16000,
|
||||
output_sampling_rate: int = 48000,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vocoder = LTX2Vocoder(
|
||||
in_channels=in_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
out_channels=out_channels,
|
||||
upsample_kernel_sizes=upsample_kernel_sizes,
|
||||
upsample_factors=upsample_factors,
|
||||
resnet_kernel_sizes=resnet_kernel_sizes,
|
||||
resnet_dilations=resnet_dilations,
|
||||
act_fn=act_fn,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
antialias=antialias,
|
||||
antialias_ratio=antialias_ratio,
|
||||
antialias_kernel_size=antialias_kernel_size,
|
||||
final_act_fn=final_act_fn,
|
||||
final_bias=final_bias,
|
||||
output_sampling_rate=input_sampling_rate,
|
||||
)
|
||||
self.bwe_generator = LTX2Vocoder(
|
||||
in_channels=bwe_in_channels,
|
||||
hidden_channels=bwe_hidden_channels,
|
||||
out_channels=bwe_out_channels,
|
||||
upsample_kernel_sizes=bwe_upsample_kernel_sizes,
|
||||
upsample_factors=bwe_upsample_factors,
|
||||
resnet_kernel_sizes=bwe_resnet_kernel_sizes,
|
||||
resnet_dilations=bwe_resnet_dilations,
|
||||
act_fn=bwe_act_fn,
|
||||
leaky_relu_negative_slope=bwe_leaky_relu_negative_slope,
|
||||
antialias=bwe_antialias,
|
||||
antialias_ratio=bwe_antialias_ratio,
|
||||
antialias_kernel_size=bwe_antialias_kernel_size,
|
||||
final_act_fn=bwe_final_act_fn,
|
||||
final_bias=bwe_final_bias,
|
||||
output_sampling_rate=output_sampling_rate,
|
||||
)
|
||||
|
||||
self.mel_stft = MelSTFT(
|
||||
filter_length=filter_length,
|
||||
hop_length=hop_length,
|
||||
window_length=window_length,
|
||||
num_mel_channels=num_mel_channels,
|
||||
)
|
||||
|
||||
self.resampler = UpSample1d(
|
||||
ratio=output_sampling_rate // input_sampling_rate,
|
||||
window_type="hann",
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
|
||||
# 1. Run stage 1 vocoder to get low sampling rate waveform
|
||||
x = self.vocoder(mel_spec)
|
||||
batch_size, num_channels, num_samples = x.shape
|
||||
|
||||
# Pad to exact multiple of hop_length for exact mel frame count
|
||||
remainder = num_samples % self.config.hop_length
|
||||
if remainder != 0:
|
||||
x = F.pad(x, (0, self.hop_length - remainder))
|
||||
|
||||
# 2. Compute mel spectrogram on vocoder output
|
||||
mel, _, _, _ = self.mel_stft(x.flatten(0, 1))
|
||||
mel = mel.unflatten(0, (-1, num_channels))
|
||||
|
||||
# 3. Run bandwidth extender (BWE) on new mel spectrogram
|
||||
mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins]
|
||||
residual = self.bwe_generator(mel_for_bwe)
|
||||
|
||||
# 4. Residual connection with resampler
|
||||
skip = self.resampler(x)
|
||||
waveform = torch.clamp(residual + skip, -1, 1)
|
||||
output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate
|
||||
waveform = waveform[..., :output_samples]
|
||||
return waveform
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import SanaLoraLoaderMixin
|
||||
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...schedulers import DPMSolverMultistepScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
The tokenizer used to tokenize the prompt.
|
||||
text_encoder ([`Gemma2PreTrainedModel`]):
|
||||
Text encoder model to encode the input prompts.
|
||||
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
|
||||
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer ([`SanaVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
self,
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC | AutoencoderKLWan,
|
||||
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
|
||||
transformer: SanaVideoTransformer3DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
@@ -223,8 +223,19 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
if getattr(self, "vae", None):
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
|
||||
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
|
||||
self.vae_scale_factor = self.vae_scale_factor_spatial
|
||||
|
||||
@@ -985,14 +996,21 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if is_torch_version(">=", "2.5.0")
|
||||
else torch_accelerator_module.OutOfMemoryError
|
||||
)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
latents_mean = self.vae.latents_mean
|
||||
latents_std = self.vae.latents_std
|
||||
z_dim = self.vae.config.latent_channels
|
||||
elif isinstance(self.vae, AutoencoderKLWan):
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std)
|
||||
z_dim = self.vae.config.z_dim
|
||||
else:
|
||||
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
z_dim = latents.shape[1]
|
||||
|
||||
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
try:
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import SanaLoraLoaderMixin
|
||||
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
The tokenizer used to tokenize the prompt.
|
||||
text_encoder ([`Gemma2PreTrainedModel`]):
|
||||
Text encoder model to encode the input prompts.
|
||||
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
|
||||
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer ([`SanaVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
self,
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC | AutoencoderKLWan,
|
||||
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
|
||||
transformer: SanaVideoTransformer3DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
@@ -213,8 +213,19 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
if getattr(self, "vae", None):
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
|
||||
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
|
||||
self.vae_scale_factor = self.vae_scale_factor_spatial
|
||||
|
||||
@@ -687,14 +698,18 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, -1, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
|
||||
image_latents.device, image_latents.dtype
|
||||
)
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
_latents_mean = self.vae.latents_mean
|
||||
_latents_std = self.vae.latents_std
|
||||
elif isinstance(self.vae, AutoencoderKLWan):
|
||||
_latents_mean = torch.tensor(self.vae.config.latents_mean)
|
||||
_latents_std = torch.tensor(self.vae.config.latents_std)
|
||||
else:
|
||||
_latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
|
||||
_latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
|
||||
|
||||
latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
image_latents = (image_latents - latents_mean) * latents_std
|
||||
|
||||
latents[:, :, 0:1] = image_latents.to(dtype)
|
||||
@@ -1034,14 +1049,21 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if is_torch_version(">=", "2.5.0")
|
||||
else torch_accelerator_module.OutOfMemoryError
|
||||
)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
latents_mean = self.vae.latents_mean
|
||||
latents_std = self.vae.latents_std
|
||||
z_dim = self.vae.config.latent_channels
|
||||
elif isinstance(self.vae, AutoencoderKLWan):
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std)
|
||||
z_dim = self.vae.config.z_dim
|
||||
else:
|
||||
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
z_dim = latents.shape[1]
|
||||
|
||||
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
try:
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -26,9 +26,17 @@ from diffusers.models._modeling_parallel import ContextParallelConfig
|
||||
from ...testing_utils import (
|
||||
is_context_parallel,
|
||||
require_torch_multi_accelerator,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
# Device configuration mapping
|
||||
DEVICE_CONFIG = {
|
||||
"cuda": {"backend": "nccl", "module": torch.cuda},
|
||||
"xpu": {"backend": "xccl", "module": torch.xpu},
|
||||
}
|
||||
|
||||
|
||||
def _find_free_port():
|
||||
"""Find a free port on localhost."""
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
@@ -47,12 +55,17 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
# Get device configuration
|
||||
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
|
||||
backend = device_config["backend"]
|
||||
device_module = device_config["module"]
|
||||
|
||||
# Initialize process group
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
device_module.set_device(rank)
|
||||
device = torch.device(f"{torch_device}:{rank}")
|
||||
|
||||
# Create model
|
||||
model = model_class(**init_dict)
|
||||
@@ -103,10 +116,16 @@ def _custom_mesh_worker(
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
|
||||
dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
||||
# Get device configuration
|
||||
device_config = DEVICE_CONFIG.get(torch_device, DEVICE_CONFIG["cuda"])
|
||||
backend = device_config["backend"]
|
||||
device_module = device_config["module"]
|
||||
|
||||
torch.cuda.set_device(rank)
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
|
||||
|
||||
# Set device for this process
|
||||
device_module.set_device(rank)
|
||||
device = torch.device(f"{torch_device}:{rank}")
|
||||
|
||||
model = model_class(**init_dict)
|
||||
model.to(device)
|
||||
@@ -116,7 +135,7 @@ def _custom_mesh_worker(
|
||||
|
||||
# DeviceMesh must be created after init_process_group, inside each worker process.
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
"cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
torch_device, mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
|
||||
)
|
||||
cp_config = ContextParallelConfig(**cp_dict, mesh=mesh)
|
||||
model.enable_parallelism(config=cp_config)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,49 +12,84 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import QwenImageTransformer2DModel
|
||||
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return QwenImageTransformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 16)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 4,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 7
|
||||
height = width = 4
|
||||
sequence_length = 8
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
@@ -70,89 +104,57 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"joint_attention_dim": 16,
|
||||
"guidance_embeds": False,
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_infers_text_seq_len_from_mask(self):
|
||||
"""Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Test 1: Contiguous mask with padding at the end (only first 2 tokens valid)
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask[:, 2:] = 0 # Only first 2 tokens are valid
|
||||
encoder_hidden_states_mask[:, 2:] = 0
|
||||
|
||||
rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
# Verify rope_text_seq_len is returned as an int (for torch.compile compatibility)
|
||||
self.assertIsInstance(rope_text_seq_len, int)
|
||||
assert isinstance(rope_text_seq_len, int)
|
||||
assert isinstance(per_sample_len, torch.Tensor)
|
||||
assert int(per_sample_len.max().item()) == 2
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
assert normalized_mask.sum().item() == 2
|
||||
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
|
||||
|
||||
# Verify per_sample_len is computed correctly (max valid position + 1 = 2)
|
||||
self.assertIsInstance(per_sample_len, torch.Tensor)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 2)
|
||||
|
||||
# Verify mask is normalized to bool dtype
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
self.assertEqual(normalized_mask.sum().item(), 2) # Only 2 True values
|
||||
|
||||
# Verify rope_text_seq_len is at least the sequence length
|
||||
self.assertGreaterEqual(rope_text_seq_len, inputs["encoder_hidden_states"].shape[1])
|
||||
|
||||
# Test 2: Verify model runs successfully with inferred values
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Test 3: Different mask pattern (padding at beginning)
|
||||
encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
|
||||
encoder_hidden_states_mask2[:, :3] = 0 # First 3 tokens are padding
|
||||
encoder_hidden_states_mask2[:, 3:] = 1 # Last 4 tokens are valid
|
||||
encoder_hidden_states_mask2[:, :3] = 0
|
||||
encoder_hidden_states_mask2[:, 3:] = 1
|
||||
|
||||
rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask2
|
||||
)
|
||||
|
||||
# Max valid position is 6 (last token), so per_sample_len should be 7
|
||||
self.assertEqual(int(per_sample_len2.max().item()), 7)
|
||||
self.assertEqual(normalized_mask2.sum().item(), 4) # 4 True values
|
||||
assert int(per_sample_len2.max().item()) == 8
|
||||
assert normalized_mask2.sum().item() == 5
|
||||
|
||||
# Test 4: No mask provided (None case)
|
||||
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], None
|
||||
)
|
||||
self.assertEqual(rope_text_seq_len_none, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(rope_text_seq_len_none, int)
|
||||
self.assertIsNone(per_sample_len_none)
|
||||
self.assertIsNone(normalized_mask_none)
|
||||
assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(rope_text_seq_len_none, int)
|
||||
assert per_sample_len_none is None
|
||||
assert normalized_mask_none is None
|
||||
|
||||
def test_non_contiguous_attention_mask(self):
|
||||
"""Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0])"""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
|
||||
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
|
||||
# Pattern: [True, False, True, False, True, False, False]
|
||||
encoder_hidden_states_mask[:, 1] = 0
|
||||
encoder_hidden_states_mask[:, 3] = 0
|
||||
encoder_hidden_states_mask[:, 5:] = 0
|
||||
@@ -160,95 +162,85 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
|
||||
inputs["encoder_hidden_states"], encoder_hidden_states_mask
|
||||
)
|
||||
self.assertEqual(int(per_sample_len.max().item()), 5)
|
||||
self.assertEqual(inferred_rope_len, inputs["encoder_hidden_states"].shape[1])
|
||||
self.assertIsInstance(inferred_rope_len, int)
|
||||
self.assertTrue(normalized_mask.dtype == torch.bool)
|
||||
assert int(per_sample_len.max().item()) == 5
|
||||
assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
|
||||
assert isinstance(inferred_rope_len, int)
|
||||
assert normalized_mask.dtype == torch.bool
|
||||
|
||||
inputs["encoder_hidden_states_mask"] = normalized_mask
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs)
|
||||
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
def test_txt_seq_lens_deprecation(self):
|
||||
"""Test that passing txt_seq_lens raises a deprecation warning."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Prepare inputs with txt_seq_lens (deprecated parameter)
|
||||
txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
|
||||
|
||||
# Remove encoder_hidden_states_mask to use the deprecated path
|
||||
inputs_with_deprecated = inputs.copy()
|
||||
inputs_with_deprecated.pop("encoder_hidden_states_mask")
|
||||
inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens
|
||||
|
||||
# Test that deprecation warning is raised
|
||||
with self.assertWarns(FutureWarning) as warning_context:
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_with_deprecated)
|
||||
|
||||
# Verify the warning message mentions the deprecation
|
||||
warning_message = str(warning_context.warning)
|
||||
self.assertIn("txt_seq_lens", warning_message)
|
||||
self.assertIn("deprecated", warning_message)
|
||||
self.assertIn("encoder_hidden_states_mask", warning_message)
|
||||
future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
|
||||
assert len(future_warnings) > 0, "Expected FutureWarning to be raised"
|
||||
|
||||
# Verify the model still works correctly despite the deprecation
|
||||
self.assertEqual(output.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
warning_message = str(future_warnings[0].message)
|
||||
assert "txt_seq_lens" in warning_message
|
||||
assert "deprecated" in warning_message
|
||||
|
||||
assert output.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
def test_layered_model_with_mask(self):
|
||||
"""Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
|
||||
# Create layered model config
|
||||
init_dict = {
|
||||
"patch_size": 2,
|
||||
"in_channels": 16,
|
||||
"out_channels": 4,
|
||||
"num_layers": 2,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 3,
|
||||
"num_attention_heads": 4,
|
||||
"joint_attention_dim": 16,
|
||||
"axes_dims_rope": (8, 4, 4), # Must match attention_head_dim (8+4+4=16)
|
||||
"use_layer3d_rope": True, # Enable layered RoPE
|
||||
"use_additional_t_cond": True, # Enable additional time conditioning
|
||||
"axes_dims_rope": (8, 4, 4),
|
||||
"use_layer3d_rope": True,
|
||||
"use_additional_t_cond": True,
|
||||
}
|
||||
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
|
||||
# Verify the model uses QwenEmbedLayer3DRope
|
||||
from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope
|
||||
|
||||
self.assertIsInstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)
|
||||
|
||||
# Test single generation with layered structure
|
||||
batch_size = 1
|
||||
text_seq_len = 7
|
||||
text_seq_len = 8
|
||||
img_h, img_w = 4, 4
|
||||
layers = 4
|
||||
|
||||
# For layered model: (layers + 1) because we have N layers + 1 combined image
|
||||
hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
|
||||
encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)
|
||||
|
||||
# Create mask with some padding
|
||||
encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
|
||||
encoder_hidden_states_mask[0, 5:] = 0 # Only 5 valid tokens
|
||||
encoder_hidden_states_mask[0, 5:] = 0
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device)
|
||||
|
||||
# additional_t_cond for use_additional_t_cond=True (0 or 1 index for embedding)
|
||||
addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)
|
||||
|
||||
# Layer structure: 4 layers + 1 condition image
|
||||
img_shapes = [
|
||||
[
|
||||
(1, img_h, img_w), # layer 0
|
||||
(1, img_h, img_w), # layer 1
|
||||
(1, img_h, img_w), # layer 2
|
||||
(1, img_h, img_w), # layer 3
|
||||
(1, img_h, img_w), # condition image (last one gets special treatment)
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
(1, img_h, img_w),
|
||||
]
|
||||
]
|
||||
|
||||
@@ -262,37 +254,113 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
additional_t_cond=addition_t_cond,
|
||||
)
|
||||
|
||||
self.assertEqual(output.sample.shape[1], hidden_states.shape[1])
|
||||
assert output.sample.shape[1] == hidden_states.shape[1]
|
||||
|
||||
|
||||
class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = QwenImageTransformer2DModel
|
||||
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for QwenImage Transformer."""
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return QwenImageTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
def prepare_dummy_input(self, height, width):
|
||||
return QwenImageTransformerTests().prepare_dummy_input(height=height, width=width)
|
||||
class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for QwenImage Transformer."""
|
||||
|
||||
def test_torch_compile_recompilation_and_graph_break(self):
|
||||
super().test_torch_compile_recompilation_and_graph_break()
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"QwenImageTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for QwenImage Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
|
||||
class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
"""Torch compile tests for QwenImage Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = embedding_dim = 16
|
||||
sequence_length = 8
|
||||
vae_scale_factor = 4
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
orig_height = height * 2 * vae_scale_factor
|
||||
orig_width = width * 2 * vae_scale_factor
|
||||
img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_hidden_states_mask": encoder_hidden_states_mask,
|
||||
"timestep": timestep,
|
||||
"img_shapes": img_shapes,
|
||||
}
|
||||
|
||||
def test_torch_compile_with_and_without_mask(self):
|
||||
"""Test that torch.compile works with both None mask and padding mask."""
|
||||
init_dict, inputs = self.prepare_init_args_and_inputs_for_common()
|
||||
init_dict = self.get_init_dict()
|
||||
inputs = self.get_dummy_inputs()
|
||||
model = self.model_class(**init_dict).to(torch_device)
|
||||
model.eval()
|
||||
model.compile(mode="default", fullgraph=True)
|
||||
|
||||
# Test 1: Run with None mask (no padding, all tokens are valid)
|
||||
inputs_no_mask = inputs.copy()
|
||||
inputs_no_mask["encoder_hidden_states_mask"] = None
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_no_mask = model(**inputs_no_mask)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
@@ -300,19 +368,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
):
|
||||
output_no_mask_2 = model(**inputs_no_mask)
|
||||
|
||||
self.assertEqual(output_no_mask.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_no_mask_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Test 2: Run with all-ones mask (should behave like None)
|
||||
inputs_all_ones = inputs.copy()
|
||||
# Keep the all-ones mask
|
||||
self.assertTrue(inputs_all_ones["encoder_hidden_states_mask"].all().item())
|
||||
assert inputs_all_ones["encoder_hidden_states_mask"].all().item()
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_all_ones = model(**inputs_all_ones)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
@@ -320,21 +384,18 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
):
|
||||
output_all_ones_2 = model(**inputs_all_ones)
|
||||
|
||||
self.assertEqual(output_all_ones.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_all_ones_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Test 3: Run with actual padding mask (has zeros)
|
||||
inputs_with_padding = inputs.copy()
|
||||
mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
|
||||
mask_with_padding[:, 4:] = 0 # Last 3 tokens are padding
|
||||
mask_with_padding[:, 4:] = 0
|
||||
|
||||
inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding
|
||||
|
||||
# First run to allow compilation
|
||||
with torch.no_grad():
|
||||
output_with_padding = model(**inputs_with_padding)
|
||||
|
||||
# Second run to verify no recompilation
|
||||
with (
|
||||
torch._inductor.utils.fresh_inductor_cache(),
|
||||
torch._dynamo.config.patch(error_on_recompile=True),
|
||||
@@ -342,8 +403,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
|
||||
):
|
||||
output_with_padding_2 = model(**inputs_with_padding)
|
||||
|
||||
self.assertEqual(output_with_padding.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
self.assertEqual(output_with_padding_2.sample.shape[1], inputs["hidden_states"].shape[1])
|
||||
assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]
|
||||
|
||||
# Verify that outputs are different (mask should affect results)
|
||||
self.assertFalse(torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3))
|
||||
assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)
|
||||
|
||||
|
||||
class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for QwenImage Transformer."""
|
||||
|
||||
|
||||
class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for QwenImage Transformer."""
|
||||
|
||||
@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
|
||||
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
|
||||
|
||||
# Override to set a more lenient max diff threshold.
|
||||
@unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case")
|
||||
def test_save_load_float16(self):
|
||||
super().test_save_load_float16(expected_max_diff=0.03)
|
||||
pass
|
||||
|
||||
@unittest.skip("Test not supported")
|
||||
def test_attention_slicing_forward_pass(self):
|
||||
|
||||
@@ -139,7 +139,9 @@ class HunyuanVideoImageToVideoPipelineFastTests(
|
||||
num_hidden_layers=2,
|
||||
image_size=224,
|
||||
)
|
||||
llava_text_encoder_config = LlavaConfig(vision_config, text_config, pad_token_id=100, image_token_index=101)
|
||||
llava_text_encoder_config = LlavaConfig(
|
||||
vision_config=vision_config, text_config=text_config, pad_token_id=100, image_token_index=101
|
||||
)
|
||||
|
||||
clip_text_encoder_config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
|
||||
@@ -171,7 +171,6 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
"processor": None,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
@@ -171,7 +171,6 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
"processor": None,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
Reference in New Issue
Block a user