Compare commits

...

105 Commits

Author SHA1 Message Date
sayakpaul
af1a72ff32 resolve conflicts. 2026-01-07 17:07:57 +05:30
Daniel Gu
040c1188d9 Fix bug when calculating audio RoPE coords 2026-01-07 12:14:25 +01:00
Daniel Gu
39f7d2dda3 Add model_cpu_offload_seq for latent upsampling pipeline 2026-01-07 10:49:29 +01:00
Sayak Paul
249ae1f853 Merge branch 'main' into ltx-2-transformer 2026-01-07 14:31:12 +05:30
Daniel Gu
aa9b65d0fc When returning latents, return unpacked and denormalized latents for T2V and I2V 2026-01-07 09:04:34 +01:00
Daniel Gu
e6e7e7b26f make style and make quality 2026-01-07 08:07:24 +01:00
Daniel Gu
5e48a114b5 Remove deprecated pipeline VAE slicing/tiling methods 2026-01-07 08:06:07 +01:00
Daniel Gu
32df138fef Add latent upsample pipeline docstring and example 2026-01-07 08:03:41 +01:00
sayakpaul
4dfe509916 up 2026-01-07 12:16:52 +05:30
Daniel Gu
964f106802 Remove print statement in audio VAE 2026-01-07 06:37:24 +01:00
dg845
a17f5cb63f Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-06 21:34:57 -08:00
Daniel Gu
5269ee5067 Merge branch 'ltx-2-transformer' of github.com:huggingface/diffusers into ltx-2-transformer 2026-01-07 06:15:38 +01:00
sayakpaul
91ee2dd26a resolve conflicts 2026-01-07 10:12:20 +05:30
Sayak Paul
cc28cf76a7 Merge branch 'main' into ltx-2-transformer 2026-01-07 09:43:08 +05:30
Daniel Gu
79cf6d7ba4 Support LTX 2.0 audio VAE encoder 2026-01-07 04:16:03 +01:00
Daniel Gu
0637b549a0 Fix typo in BlurDownsample 2026-01-07 03:36:19 +01:00
Daniel Gu
8f1ddb1b1e Get latent upsampler working with video latents 2026-01-07 01:58:25 +01:00
Daniel Gu
d01a242cdb make style and make quality 2026-01-06 23:54:23 +01:00
Daniel Gu
5e0cf2b2f0 Simplify LTX 2 RoPE forward by removing coords is None logic 2026-01-06 23:32:59 +01:00
Sayak Paul
64b48c1729 Merge branch 'main' into ltx-2-transformer 2026-01-06 21:31:46 +05:30
sayakpaul
8c5ab1fd6d disable ltx2_consistency test 2026-01-06 21:31:29 +05:30
sayakpaul
61e0fb4bd8 update doc entries. 2026-01-06 21:15:47 +05:30
sayakpaul
bdcf23ec17 update docs. 2026-01-06 21:02:18 +05:30
sayakpaul
c39f1b87a4 remove args. 2026-01-06 20:52:49 +05:30
sayakpaul
57ead0b5e5 remove function map. 2026-01-06 20:48:16 +05:30
Sayak Paul
2fc578941b Merge branch 'main' into ltx-2-transformer 2026-01-06 13:51:36 +05:30
Daniel Gu
245d056c7d Add option to enable VAE tiling in upsampling test script 2026-01-06 08:07:33 +01:00
Daniel Gu
dd81242eba make style and make quality 2026-01-06 06:42:24 +01:00
Daniel Gu
ace2ee93fb Allow the I2V pipeline to accept image URLs 2026-01-06 06:40:42 +01:00
Daniel Gu
ef199118e2 Point original checkpoint to LTX 2.0 official checkpoint 2026-01-06 06:35:51 +01:00
Daniel Gu
a7d6916afc Add test script for LTX 2.0 latent upsampling 2026-01-06 05:58:31 +01:00
Daniel Gu
84c0b2fb84 Merge branch 'ltx-2-transformer' into ltx-2-latent-upsample-pipeline 2026-01-06 04:53:42 +01:00
Daniel Gu
d97fd2dd35 Add new LTX 2.0 spatial latent upsampler logic 2026-01-06 04:47:06 +01:00
sayakpaul
550eca3530 use export util funcs. 2026-01-06 09:14:38 +05:30
sayakpaul
c039c87b99 up 2026-01-06 08:09:59 +05:30
sayakpaul
9b8788cc98 resolve conflicts. 2026-01-06 08:09:37 +05:30
Sayak Paul
93a417f24a Tests for T2V and I2V (#6)
* add ltx2 pipeline tests.

* up

* up

* up

* up

* remove content

* style

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* up

* up

* i2v tests.

* up

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* revert unneded changes.

* up

* up

* update to split style rope.

* up

---------

Co-authored-by: Daniel Gu <dgu8957@gmail.com>
2026-01-06 08:05:30 +05:30
Daniel Gu
084490cd98 Merge branch 'ltx-2-transformer' into ltx-2-latent-upsample-pipeline 2026-01-06 03:29:38 +01:00
dg845
ce9da5d472 Merge pull request #20 from huggingface/video-export-utils-file
Add export_utils file for exporting LTX 2.0 videos with audio
2026-01-05 18:25:29 -08:00
Daniel Gu
90516804e0 Merge branch 'ltx-2-transformer' into ltx-2-latent-upsample-pipeline 2026-01-06 03:18:51 +01:00
Daniel Gu
cb50cacba5 Add export_utils file for exporting LTX 2.0 videos with audio 2026-01-06 02:17:39 +01:00
Daniel Gu
bff989110c Fix apply split RoPE shape error when reshaping x to 4D 2026-01-06 01:22:05 +01:00
Daniel Gu
2fa4f8471f When using split RoPE, make sure that the output dtype is same as input dtype 2026-01-06 00:19:39 +01:00
Sayak Paul
c5b52d6c9f address initial feedback from lightricks team (#16)
* cross_attn_timestep_scale_multiplier to 1000

* implement split rope type.

* up

* propagate rope_type to rope embed classes as well.

* up
2026-01-05 21:13:10 +05:30
Sayak Paul
0be4f31620 up (#19) 2026-01-05 21:13:01 +05:30
dg845
caae16768a Move Video and Audio Text Encoder Connectors to Transformer (#12)
* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* precompute run_connectors,.

* fixes

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* Make connectors a separate module (#18)

* remove text_encoder.py

* address yiyi's comments.

* up

* up

* up

* up

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2026-01-05 20:11:13 +05:30
Daniel Gu
fe3ba3b698 Initial implementation of LTX 2.0 latent upsampling pipeline 2026-01-02 20:18:32 +01:00
dg845
aae70b90db Merge pull request #10 from huggingface/make-scheduler-consistent
Make LTX 2.0 Scheduler `sigmas` Consistent with Original Code
2025-12-31 13:46:47 -08:00
sayakpaul
d3f10fe54e test i2v. 2025-12-31 09:36:48 +05:30
dg845
bd607b97a8 Denormalize audio latents in I2V pipeline (analogous to T2V change) (#11) 2025-12-31 09:23:35 +05:30
Daniel Gu
6a236a27fb Merge branch 'ltx-2-transformer' into make-scheduler-consistent 2025-12-30 20:25:59 +01:00
Sayak Paul
46822c43db Add support for I2V (#8)
* start i2v.

* up

* up

* up

* up

* up

* remove uniform strategy code.

* remove unneeded code.
2025-12-30 09:06:07 +05:30
Sayak Paul
280e347814 Refactor Audio VAE to be simpler and remove helpers (#7)
* remove resolve causality axes stuff.

* remove a bunch of helpers.

* remove adjust output shape helper.

* remove the use of audiolatentshape.

* move normalization and patchify out of pipeline.

* fix

* up

* up

* Remove unpatchify and patchify ops before audio latents denormalization (#9)

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2025-12-30 08:05:56 +05:30
Daniel Gu
e1f0b7e255 Fix typo when applying scheduler fix in T2V inference script 2025-12-30 00:38:51 +01:00
Daniel Gu
581f21c431 Make LTX 2.0 scheduler more consistent with original code 2025-12-29 23:44:52 +01:00
dg845
0c41297453 Merge pull request #4 from huggingface/ltx-2-t2v-pipeline
LTX 2.0 Text-to-Video (T2V) Pipeline
2025-12-23 21:29:25 -08:00
Daniel Gu
b5891b19b1 Get LTX 2 T2V pipeline to produce reasonable outputs 2025-12-24 06:07:38 +01:00
Daniel Gu
e89d9c1951 Fix video shape error in full pipeline test script 2025-12-23 11:14:05 +01:00
Daniel Gu
f9b947651f Fix pipeline audio VAE decoding dtype bug 2025-12-23 11:03:19 +01:00
Daniel Gu
1484c43183 Improve CPU offload support 2025-12-23 10:56:32 +01:00
Daniel Gu
90edc6abc9 Fix more bugs in LTX2Pipeline.__call__ 2025-12-23 10:41:27 +01:00
Daniel Gu
a56cf23483 Add LTX 2 text encoder and vocoder to ltx2 subdirectory __init__ 2025-12-23 10:40:56 +01:00
Daniel Gu
fa7d9f77f1 Fix pipeline return bugs 2025-12-23 08:49:11 +01:00
Daniel Gu
3bf736979f Add script to test full LTX2Pipeline T2V inference 2025-12-23 08:43:37 +01:00
Daniel Gu
595f485ad8 LTX 2.0 scheduler and full pipeline conversion 2025-12-23 07:41:28 +01:00
Daniel Gu
cbb10b8dca Support num_videos_per_prompt for prompt embeddings 2025-12-23 07:01:17 +01:00
Daniel Gu
6e6ce20595 Duplicate scheduler for audio latents 2025-12-23 06:40:35 +01:00
Daniel Gu
54bfc5d617 Add Audio VAE logic to T2V pipeline 2025-12-23 03:51:22 +01:00
Daniel Gu
ae3b6e7cc2 Merge branch 'ltx-2-transformer' into ltx-2-t2v-pipeline 2025-12-23 02:59:33 +01:00
Daniel Gu
d303e2a6ff Conversion script for LTX 2.0 Audio VAE Decoder 2025-12-23 02:48:15 +01:00
Daniel Gu
5f7e43d17f Add imports for LTX 2.0 Audio VAE 2025-12-23 02:08:51 +01:00
dg845
7bb4cf76ce Merge pull request #5 from huggingface/audio-decoder
Audio decoder
2025-12-22 17:00:11 -08:00
sayakpaul
409d651bab resolve conflicts. 2025-12-22 15:59:31 +05:30
sayakpaul
8134da6a56 up 2025-12-22 15:55:29 +05:30
Sayak Paul
059999a3f7 up 2025-12-22 10:24:55 +00:00
sayakpaul
58257eb0e0 up 2025-12-22 15:45:56 +05:30
Sayak Paul
5f0f2a03f7 up 2025-12-22 10:06:39 +00:00
Daniel Gu
d0f9cdaab1 Rough initial LTX 2.0 pipeline implementation 2025-12-22 10:07:20 +01:00
Daniel Gu
0028955c37 Initial LTX 2.0 text encoder implementation 2025-12-22 10:06:01 +01:00
sayakpaul
4904fd6fa5 up 2025-12-22 13:46:58 +05:30
sayakpaul
907896d533 simplify and clean up 2025-12-22 13:41:41 +05:30
sayakpaul
e54cd6bb1d up 2025-12-22 13:03:40 +05:30
sayakpaul
f4c2435d61 init registration. 2025-12-22 12:25:36 +05:30
sayakpaul
b34ddb1736 start audio decoder. 2025-12-22 12:23:31 +05:30
Daniel Gu
6c56954fa8 Use RMSNorm implementation closer to original for LTX 2.0 video VAE 2025-12-20 02:40:38 +01:00
dg845
b1cf6ff8a9 Merge pull request #2 from huggingface/ltx-2-video-vae
LTX 2.0 Video VAE Implementation
2025-12-19 16:36:38 -08:00
dg845
8bfeb4af56 Merge pull request #3 from huggingface/ltx-2-vocoder
LTX 2.0 Vocoder Implementation
2025-12-19 16:21:31 -08:00
Daniel Gu
c6a11a5530 Initial LTX 2.0 vocoder implementation 2025-12-19 12:17:10 +01:00
Daniel Gu
a748975a7c Get diffusers implementation on par with official LTX 2.0 video VAE implementation 2025-12-19 07:02:38 +01:00
Daniel Gu
491aae08d8 Add initial LTX 2.0 video VAE tests (part 2) 2025-12-17 11:39:09 +01:00
Daniel Gu
5b950d6fef Add initial LTX 2.0 video VAE tests 2025-12-17 11:30:15 +01:00
Daniel Gu
baf23e2da3 Explicitly specify temporal and spatial VAE scale factors when converting 2025-12-17 11:14:45 +01:00
Daniel Gu
269cf7b40d Initial implementation of LTX 2.0 video VAE 2025-12-17 10:51:34 +01:00
Daniel Gu
bda3ff13db Fix LTX 2 transformer bugs so consistency test passes 2025-12-16 10:53:43 +01:00
Daniel Gu
a7bc052e89 Improve dummy inputs and add test for LTX 2 transformer consistency 2025-12-16 10:44:02 +01:00
Daniel Gu
57a8b9c330 Allow LTX 2 transformer to be loaded from local path for conversion 2025-12-16 10:38:03 +01:00
Daniel Gu
d86f89ddea Add more LTX 2 transformer audio arguments 2025-12-16 07:58:12 +01:00
Daniel Gu
a5f2d2da6c Initial script to convert LTX 2 transformer to diffusers 2025-12-15 07:09:42 +01:00
Daniel Gu
aeecc4d712 Fix LTX 2 transformer shape errors 2025-12-15 06:38:57 +01:00
Daniel Gu
5765759cd3 Get LTX 2 transformer compile tests passing 2025-12-15 03:38:34 +01:00
Daniel Gu
780fb61d32 Remove RoPE debug print statements 2025-12-13 10:37:24 +01:00
Daniel Gu
e100b8f2a3 Rename LTX 2 compile test class to have LTX2 2025-12-13 10:34:11 +01:00
Daniel Gu
980591de53 Get LTX 2 transformer tests working 2025-12-13 04:57:23 +01:00
Daniel Gu
b3096c3c9e Add tests for LTX 2 transformer model 2025-12-13 04:55:41 +01:00
Daniel Gu
aa602ac483 Initial LTX 2.0 transformer implementation 2025-12-12 07:52:33 +01:00
37 changed files with 9927 additions and 0 deletions

View File

@@ -367,6 +367,8 @@
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -443,6 +445,10 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -678,6 +684,8 @@
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx2
title: LTX-2
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi

View File

@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Audio
The variational autoencoder (VAE) model with KL loss used for audio in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. It encodes and decodes audio latent representations.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTX2Audio
vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="audio_vae", torch_dtype=torch.float32).to("cuda")
```
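Audio latents returned by the LTX-2 pipelines (for example with `output_type="latent"`) can be decoded manually. Below is a minimal sketch of that pattern; `audio_latents` is assumed to have been produced by an LTX-2 pipeline.
```python
import torch

# `audio_latents` is assumed to come from an LTX-2 pipeline run with output_type="latent".
audio_latents = audio_latents.to(vae.dtype)
with torch.no_grad():
    decoded_audio = vae.decode(audio_latents, return_dict=False)[0]
```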
## AutoencoderKLLTX2Audio
[[autodoc]] AutoencoderKLLTX2Audio
- encode
- decode
- all

View File

@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Video
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTX2Video
vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
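Decoding video latents at full resolution can be memory-intensive; tiled decoding lowers peak memory. A minimal sketch, assuming `latents` was produced by an LTX-2 pipeline run with `output_type="latent"`:
```python
import torch

vae.enable_tiling()  # decode in tiles to reduce peak memory usage

# `latents` is assumed to come from an LTX-2 pipeline run with output_type="latent".
with torch.no_grad():
    video = vae.decode(latents.to(vae.dtype), return_dict=False)[0]
```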
## AutoencoderKLLTX2Video
[[autodoc]] AutoencoderKLLTX2Video
- decode
- encode
- all

View File

@@ -0,0 +1,26 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# LTX2VideoTransformer3DModel
A Diffusion Transformer (DiT) model for joint video and audio generation, introduced by Lightricks in [LTX-2](https://huggingface.co/Lightricks/LTX-2).
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import LTX2VideoTransformer3DModel
transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTX2VideoTransformer3DModel
[[autodoc]] LTX2VideoTransformer3DModel

View File

@@ -0,0 +1,37 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX-2
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-2 checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
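The sketch below shows a minimal text-to-video-and-audio generation, adapted from the test script added in this PR. It assumes a diffusers-format LTX-2 checkpoint is available at the repo id shown.
```python
import torch
from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.export_utils import encode_video

# "Lightricks/LTX-2" is used as an illustrative repo id; point this at a diffusers-format LTX-2 checkpoint.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

video, audio = pipe(
    prompt="A video of a dog dancing to energetic electronic dance music",
    height=512,
    width=768,
    num_frames=121,
    frame_rate=25.0,
    num_inference_steps=40,
    guidance_scale=3.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
    output_type="np",
    return_dict=False,
)

# Convert frames to uint8 and mux video + audio into an .mp4 file.
video = torch.from_numpy((video * 255).round().astype("uint8"))
encode_video(
    video[0],
    fps=25.0,
    audio=audio[0].float().cpu(),
    audio_sample_rate=pipe.vocoder.config.output_sampling_rate,
    output_path="ltx2_sample.mp4",
)
```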
## LTX2Pipeline
[[autodoc]] LTX2Pipeline
- all
- __call__
## LTX2ImageToVideoPipeline
[[autodoc]] LTX2ImageToVideoPipeline
- all
- __call__
## LTX2PipelineOutput
[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput

View File

@@ -0,0 +1,863 @@
import argparse
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed and audio_adaln_single --> audio_time_embed separately, as these original
# keys are substrings of the other modulation parameter keys below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulation Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
if key.startswith("per_channel_statistics"):
new_key = ".".join(["decoder", key])
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
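# The original combined checkpoint stores the text-encoder connector weights alongside the DiT weights.
# Split the state dict by key prefix so the transformer and the LTX2TextConnectors model can be
# converted and loaded separately.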
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 4,
"audio_out_channels": 4,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 2,
"audio_attention_head_dim": 4,
"audio_cross_attention_dim": 8,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 2,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 16,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"caption_channels": 16,
"text_proj_in_factor": 3,
"video_connector_num_attention_heads": 4,
"video_connector_attention_head_dim": 8,
"video_connector_num_layers": 1,
"video_connector_num_learnable_registers": None,
"audio_connector_num_attention_heads": 4,
"audio_connector_attention_head_dim": 8,
"audio_connector_num_layers": 1,
"audio_connector_num_learnable_registers": None,
"connector_rope_base_seq_len": 32,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_temporal_positioning": False,
},
}
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> LTX2VideoTransformer3DModel:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> AutoencoderKLLTX2Video:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
},
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> AutoencoderKLLTX2Audio:
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> LTX2Vocoder:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def get_ltx2_spatial_latent_upsampler_config(version: str):
if version == "2.0":
config = {
"in_channels": 128,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
def convert_ltx2_spatial_latent_upsampler(
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
with init_empty_weights():
latent_upsampler = LTX2LatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
if repo_id is None and filename is None:
raise ValueError("Please supply at least one of `repo_id` or `filename`")
if repo_id is not None:
if filename is None:
raise ValueError("If repo_id is specified, filename must also be specified.")
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
else:
ckpt_path = filename
_, ext = os.path.splitext(ckpt_path)
if ext in [".safetensors", ".sft"]:
state_dict = safetensors.torch.load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
return state_dict
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
# Ensure that the key prefix ends with a dot (.)
if not prefix.endswith("."):
prefix = prefix + "."
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--original_state_dict_repo_id",
default="Lightricks/LTX-2",
type=str,
help="HF Hub repo id with LTX 2.0 checkpoint",
)
parser.add_argument(
"--checkpoint_path",
default=None,
type=str,
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
)
parser.add_argument(
"--version",
type=str,
default="2.0",
choices=["test", "2.0"],
help="Version of the LTX 2.0 model",
)
parser.add_argument(
"--combined_filename",
default="ltx-2-19b-dev.safetensors",
type=str,
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
)
parser.add_argument("--vae_prefix", default="vae.", type=str)
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
parser.add_argument(
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
)
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
parser.add_argument(
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
)
parser.add_argument(
"--text_encoder_model_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 base text encoder model",
)
parser.add_argument(
"--tokenizer_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename",
default="rc1/ltx-2-spatial-upscaler-x2-1.0-rc1.safetensors",
type=str,
help="Latent upsampler filename",
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
parser.add_argument(
"--full_pipeline",
action="store_true",
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
)
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
def main(args):
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
combined_ckpt = None
load_combined_models = any(
[args.vae, args.audio_vae, args.dit, args.vocoder, args.text_encoder, args.full_pipeline]
)
if args.combined_filename is not None and load_combined_models:
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
if args.vae or args.full_pipeline:
if args.vae_filename is not None:
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
if not args.full_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
if args.audio_vae or args.full_pipeline:
if args.audio_vae_filename is not None:
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
elif combined_ckpt is not None:
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
if not args.full_pipeline:
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
if args.dit or args.full_pipeline:
if args.dit_filename is not None:
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
if not args.full_pipeline:
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.connectors or args.full_pipeline:
if args.dit_filename is not None:
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
if not args.full_pipeline:
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
if args.vocoder or args.full_pipeline:
if args.vocoder_filename is not None:
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
elif combined_ckpt is not None:
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
if not args.full_pipeline:
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.text_encoder or args.full_pipeline:
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
if not args.full_pipeline:
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.latent_upsampler:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
original_latent_upsampler_ckpt,
latent_upsampler_config,
dtype=vae_dtype,
)
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
if args.full_pipeline:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTX2Pipeline(
scheduler=scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if __name__ == "__main__":
args = get_args()
main(args)

View File

@@ -0,0 +1,108 @@
import argparse
import os
import torch
from diffusers import LTX2Pipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
parser.add_argument("--revision", type=str, default="main")
parser.add_argument(
"--prompt",
type=str,
default="A video of a dog dancing to energetic electronic dance music",
)
parser.add_argument(
"--negative_prompt",
type=str,
default=(
"blurry, out of focus, overexposed, underexposed, low contrast, washed out colors, excessive noise, "
"grainy texture, poor lighting, flickering, motion blur, distorted proportions, unnatural skin tones, "
"deformed facial features, asymmetrical face, missing facial features, extra limbs, disfigured hands, "
"wrong hand count, artifacts around text, inconsistent perspective, camera shake, incorrect depth of "
"field, background too sharp, background clutter, distracting reflections, harsh shadows, inconsistent "
"lighting direction, color banding, cartoonish rendering, 3D CGI look, unrealistic materials, uncanny "
"valley effect, incorrect ethnicity, wrong gender, exaggerated expressions, wrong gaze direction, "
"mismatched lip sync, silent or muted audio, distorted voice, robotic voice, echo, background noise, "
"off-sync audio,incorrect dialogue, added dialogue, repetitive speech, jittery movement, awkward "
"pauses, incorrect timing, unnatural transitions, inconsistent framing, tilted camera, flat lighting, "
"inconsistent tone, cinematic oversaturation, stylized filters, or AI artifacts."
),
)
parser.add_argument("--num_inference_steps", type=int, default=40)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=768)
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--frame_rate", type=float, default=25.0)
parser.add_argument("--guidance_scale", type=float, default=3.0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument(
"--output_dir",
type=str,
default="/home/daniel_gu/samples",
help="Output directory for generated video",
)
parser.add_argument(
"--output_filename",
type=str,
default="ltx2_sample_video.mp4",
help="Filename of the exported generated video",
)
args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
return args
def main(args):
pipeline = LTX2Pipeline.from_pretrained(
args.model_id,
revision=args.revision,
torch_dtype=args.dtype,
)
pipeline.to(device=args.device)
if args.cpu_offload:
pipeline.enable_model_cpu_offload()
video, audio = pipeline(
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator(device=args.device).manual_seed(args.seed),
output_type="np",
return_dict=False,
)
# Convert video to uint8 (but keep as NumPy array)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=args.frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000
output_path=os.path.join(args.output_dir, args.output_filename),
)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,102 @@
import argparse
import os
import torch
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
parser.add_argument("--revision", type=str, default="main")
parser.add_argument("--image_path", required=True, type=str)
parser.add_argument(
"--prompt",
type=str,
default="An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot.",
)
parser.add_argument(
"--negative_prompt",
type=str,
default="shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static.",
)
parser.add_argument("--num_inference_steps", type=int, default=40)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=768)
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--frame_rate", type=float, default=25.0)
parser.add_argument("--guidance_scale", type=float, default=3.0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument(
"--output_dir",
type=str,
default="samples",
help="Output directory for generated video",
)
parser.add_argument(
"--output_filename",
type=str,
default="ltx2_sample_video.mp4",
help="Filename of the exported generated video",
)
args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
return args
def main(args):
pipeline = LTX2ImageToVideoPipeline.from_pretrained(
args.model_id,
revision=args.revision,
torch_dtype=args.dtype,
)
if args.cpu_offload:
pipeline.enable_model_cpu_offload()
else:
pipeline.to(device=args.device)
image = load_image(args.image_path)
video, audio = pipeline(
image=image,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator(device=args.device).manual_seed(args.seed),
output_type="np",
return_dict=False,
)
# Convert video to uint8 (but keep as NumPy array)
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=args.frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=pipeline.vocoder.config.output_sampling_rate, # should be 24000
output_path=os.path.join(args.output_dir, args.output_filename),
)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,174 @@
import argparse
import gc
import os
import torch
from diffusers import AutoencoderKLLTX2Video
from diffusers.pipelines.ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2LatentUpsamplerModel
from diffusers.pipelines.ltx2.export_utils import encode_video
from diffusers.utils import load_image
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_id", type=str, default="diffusers-internal-dev/new-ltx-model")
parser.add_argument("--revision", type=str, default="main")
parser.add_argument("--image_path", required=True, type=str)
parser.add_argument(
"--prompt",
type=str,
default=(
"An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart "
"in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in "
"slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless "
"motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep "
"darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and "
"scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground "
"dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity "
"motion, cinematic lighting, and a breath-taking, movie-like shot."
),
)
parser.add_argument(
"--negative_prompt",
type=str,
default=(
"shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion "
"artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static."
),
)
parser.add_argument("--num_inference_steps", type=int, default=40)
parser.add_argument("--height", type=int, default=512)
parser.add_argument("--width", type=int, default=768)
parser.add_argument("--num_frames", type=int, default=121)
parser.add_argument("--frame_rate", type=float, default=25.0)
parser.add_argument("--guidance_scale", type=float, default=3.0)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--apply_scheduler_fix", action="store_true")
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--dtype", type=str, default="bf16")
parser.add_argument("--cpu_offload", action="store_true")
parser.add_argument("--vae_tiling", action="store_true")
parser.add_argument("--use_video_latents", action="store_true")
parser.add_argument(
"--output_dir",
type=str,
default="samples",
help="Output directory for generated video",
)
parser.add_argument(
"--output_filename",
type=str,
default="ltx2_i2v_video_upsampled.mp4",
help="Filename of the exported generated video",
)
args = parser.parse_args()
args.dtype = torch.bfloat16 if args.dtype == "bf16" else torch.float32
return args
def main(args):
pipeline = LTX2ImageToVideoPipeline.from_pretrained(
args.model_id,
revision=args.revision,
torch_dtype=args.dtype,
)
if args.cpu_offload:
pipeline.enable_model_cpu_offload()
else:
pipeline.to(device=args.device)
image = load_image(args.image_path)
first_stage_output_type = "pil"
if args.use_video_latents:
first_stage_output_type = "latent"
video, audio = pipeline(
image=image,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
frame_rate=args.frame_rate,
num_inference_steps=args.num_inference_steps,
guidance_scale=args.guidance_scale,
generator=torch.Generator(device=args.device).manual_seed(args.seed),
output_type=first_stage_output_type,
return_dict=False,
)
if args.use_video_latents:
# Manually convert the audio latents to a waveform
audio = audio.to(pipeline.audio_vae.dtype)
audio = pipeline.audio_vae.decode(audio, return_dict=False)[0]
audio = pipeline.vocoder(audio)
# Get some pipeline configs for upsampling
spatial_patch_size = pipeline.transformer_spatial_patch_size
temporal_patch_size = pipeline.transformer_temporal_patch_size
# upsample_pipeline = LTX2LatentUpsamplePipeline.from_pretrained(
# args.model_id, revision=args.revision, torch_dtype=args.dtype,
# )
output_sampling_rate = pipeline.vocoder.config.output_sampling_rate
del pipeline  # Free the first-stage pipeline before loading the upsampling components to avoid OOM
torch.cuda.empty_cache()
gc.collect()
vae = AutoencoderKLLTX2Video.from_pretrained(
args.model_id,
subfolder="vae",
revision=args.revision,
torch_dtype=args.dtype,
)
latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
args.model_id,
subfolder="latent_upsampler",
revision=args.revision,
torch_dtype=args.dtype,
)
upsample_pipeline = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
upsample_pipeline.to(device=args.device)
if args.vae_tiling:
upsample_pipeline.vae.enable_tiling()
upsample_kwargs = {
"height": args.height,
"width": args.width,
"output_type": "np",
"return_dict": False,
}
if args.use_video_latents:
upsample_kwargs["latents"] = video
upsample_kwargs["num_frames"] = args.num_frames
upsample_kwargs["spatial_patch_size"] = spatial_patch_size
upsample_kwargs["temporal_patch_size"] = temporal_patch_size
else:
upsample_kwargs["video"] = video
video = upsample_pipeline(**upsample_kwargs)[0]
# Convert the video to uint8 while it is still a NumPy array, then wrap it as a torch tensor for export
video = (video * 255).round().astype("uint8")
video = torch.from_numpy(video)
encode_video(
video[0],
fps=args.frame_rate,
audio=audio[0].float().cpu(),
audio_sample_rate=output_sampling_rate, # should be 24000
output_path=os.path.join(args.output_dir, args.output_filename),
)
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -0,0 +1,119 @@
import argparse
from pathlib import Path
import torch
from huggingface_hub import hf_hub_download
def download_checkpoint(
repo_id="diffusers-internal-dev/new-ltx-model",
filename="ltx-av-step-1932500-interleaved-new-vae.safetensors",
):
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
return ckpt_path
def convert_state_dict(state_dict: dict) -> dict:
converted = {}
for key, value in state_dict.items():
if not isinstance(value, torch.Tensor):
continue
new_key = key
if new_key.startswith("decoder."):
new_key = new_key[len("decoder.") :]
converted[f"decoder.{new_key}"] = value
converted["latents_mean"] = converted.pop("decoder.per_channel_statistics.mean-of-means")
converted["latents_std"] = converted.pop("decoder.per_channel_statistics.std-of-means")
return converted
def load_original_decoder(device: torch.device, dtype: torch.dtype):
from ltx_core.loader.single_gpu_model_builder import SingleGPUModelBuilder as Builder
from ltx_core.model.audio_vae.model_configurator import AUDIO_VAE_DECODER_COMFY_KEYS_FILTER
from ltx_core.model.audio_vae.model_configurator import VAEDecoderConfigurator as AudioDecoderConfigurator
checkpoint_path = download_checkpoint()
# The code below comes from `ltx-pipelines/src/ltx_pipelines/txt2vid.py`
decoder = Builder(
model_path=checkpoint_path,
model_class_configurator=AudioDecoderConfigurator,
model_sd_key_ops=AUDIO_VAE_DECODER_COMFY_KEYS_FILTER,
).build(device=device)
decoder.eval()
return decoder
def build_diffusers_decoder():
from diffusers.models.autoencoders import AutoencoderKLLTX2Audio
with torch.device("meta"):
model = AutoencoderKLLTX2Audio()
model.eval()
return model
@torch.no_grad()
def main() -> None:
parser = argparse.ArgumentParser(description="Validate LTX2 audio decoder conversion.")
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"])
parser.add_argument("--batch", type=int, default=2)
parser.add_argument("--output-path", type=Path, required=True)
args = parser.parse_args()
device = torch.device(args.device)
dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16}
dtype = dtype_map[args.dtype]
original_decoder = load_original_decoder(device, dtype)
diffusers_model = build_diffusers_decoder()
converted_state_dict = convert_state_dict(original_decoder.state_dict())
diffusers_model.load_state_dict(converted_state_dict, assign=True, strict=False)
per_channel_len = original_decoder.per_channel_statistics.get_buffer("std-of-means").numel()
latent_channels = diffusers_model.decoder.latent_channels
mel_bins_for_match = per_channel_len // latent_channels if per_channel_len % latent_channels == 0 else None
levels = len(diffusers_model.decoder.channel_multipliers)
latent_height = diffusers_model.decoder.resolution // (2 ** (levels - 1))
latent_width = mel_bins_for_match or latent_height
dummy = torch.randn(
args.batch,
diffusers_model.decoder.latent_channels,
latent_height,
latent_width,
device=device,
dtype=dtype,
generator=torch.Generator(device).manual_seed(42),
)
original_out = original_decoder(dummy)
from diffusers.pipelines.ltx2.pipeline_ltx2 import LTX2Pipeline
_, a_channels, a_time, a_freq = dummy.shape
dummy = dummy.permute(0, 2, 1, 3).reshape(-1, a_time, a_channels * a_freq)
dummy = LTX2Pipeline._denormalize_audio_latents(
dummy,
diffusers_model.latents_mean,
diffusers_model.latents_std,
)
dummy = dummy.view(-1, a_time, a_channels, a_freq).permute(0, 2, 1, 3)
diffusers_out = diffusers_model.decode(dummy).sample
torch.testing.assert_close(diffusers_out, original_out, rtol=1e-4, atol=1e-4)
max_diff = (diffusers_out - original_out).abs().max().item()
print(f"Conversion successful. Max diff: {max_diff:.6f}")
diffusers_model.to(dtype).save_pretrained(args.output_path)
print(f"Serialized model to {args.output_path}")
if __name__ == "__main__":
main()
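
For reference, here is a minimal sketch of what `convert_state_dict` above does to the original decoder keys. The tensors and key names are toy placeholders (the real checkpoint contains many more entries); only the renaming behavior is illustrated:

import torch

toy_state_dict = {
    "conv_in.weight": torch.zeros(1),  # hypothetical decoder weight without the "decoder." prefix
    "per_channel_statistics.mean-of-means": torch.zeros(8),  # statistics buffers expected by the converter
    "per_channel_statistics.std-of-means": torch.ones(8),
}
converted = convert_state_dict(toy_state_dict)
# sorted(converted) == ["decoder.conv_in.weight", "latents_mean", "latents_std"]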

View File

@@ -193,6 +193,8 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -236,6 +238,7 @@ else:
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
"LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -538,6 +541,9 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXI2VLongMultiPromptPipeline",
"LTXImageToVideoPipeline",
@@ -939,6 +945,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -982,6 +990,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1254,6 +1263,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,

View File

@@ -41,6 +41,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -104,6 +106,7 @@ if is_torch_available():
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -153,6 +156,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -212,6 +217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,

View File

@@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

File diff suppressed because it is too large

View File

@@ -0,0 +1,804 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
LATENT_DOWNSAMPLE_FACTOR = 4
class LTX2AudioCausalConv2d(nn.Module):
"""
A causal 2D convolution that pads asymmetrically along the causal axis.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
if self.causality_axis == "none":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis in {"width", "width-compatibility"}:
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis == "height":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
else:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
self.padding = padding
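# Worked example (assumed values): kernel_size=3 and dilation=1 give pad_h = pad_w = 2. With
# causality_axis="height" the padding tuple is (1, 1, 2, 0): the width axis is padded symmetrically,
# while all height (temporal) padding is placed before the first position so the convolution never
# sees future positions along the causal axis.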
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding)
return self.conv(x)
class LTX2AudioPixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
class LTX2AudioAttnBlock(nn.Module):
def __init__(
self,
in_channels: int,
norm_type: str = "group",
) -> None:
super().__init__()
self.in_channels = in_channels
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = self.norm(x)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
batch, channels, height, width = q.shape
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
k = k.reshape(batch, channels, height * width).contiguous()
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
attn = torch.nn.functional.softmax(attn, dim=2)
v = v.reshape(batch, channels, height * width)
attn = attn.permute(0, 2, 1).contiguous()
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
h_ = self.proj_out(h_)
return x + h_
class LTX2AudioResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
norm_type: str = "group",
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if norm_type == "group":
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.non_linearity = nn.SiLU()
if causality_axis is not None:
self.conv1 = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if norm_type == "group":
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.dropout = nn.Dropout(dropout)
if causality_axis is not None:
self.conv2 = LTX2AudioCausalConv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
if causality_axis is not None:
self.conv_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
if causality_axis is not None:
self.nin_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
h = self.norm1(x)
h = self.non_linearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
h = self.norm2(h)
h = self.non_linearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
return x + h
class LTX2AudioDownsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.with_conv:
# Padding tuple is in the order: (left, right, top, bottom).
if self.causality_axis == "none":
pad = (0, 1, 0, 1)
elif self.causality_axis == "width":
pad = (2, 0, 0, 1)
elif self.causality_axis == "height":
pad = (0, 1, 2, 0)
elif self.causality_axis == "width-compatibility":
pad = (1, 0, 0, 1)
else:
raise ValueError(
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
f" and `width-compatibility`."
)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
# with_conv=False implies that causality_axis is "none"
x = F.avg_pool2d(x, kernel_size=2, stride=2)
return x
class LTX2AudioUpsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
if causality_axis is not None:
self.conv = LTX2AudioCausalConv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
if self.causality_axis is None or self.causality_axis == "none":
pass
elif self.causality_axis == "height":
x = x[:, :, 1:, :]
elif self.causality_axis == "width":
x = x[:, :, :, 1:]
elif self.causality_axis == "width-compatibility":
pass
else:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
class LTX2AudioAudioPatchifier:
"""
Patchifier for spectrogram/audio latents.
"""
def __init__(
self,
patch_size: int,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
is_causal: bool = True,
):
self.hop_length = hop_length
self.sample_rate = sample_rate
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.is_causal = is_causal
self._patch_size = (1, patch_size, patch_size)
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
batch, channels, time, freq = audio_latents.shape
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
batch, time, _ = audio_latents.shape
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
@property
def patch_size(self) -> Tuple[int, int, int]:
return self._patch_size
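# Shape contract of the patchifier above (worked example, assumed sizes): a latent of shape
# (batch, channels=8, time=16, mel_bins=16) is patchified to (batch, 16, 128) by flattening channels
# and mel bins per time step; unpatchify(latents, channels=8, mel_bins=16) inverts this exactly,
# restoring (batch, 8, 16, 16).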
class LTX2AudioEncoder(nn.Module):
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
):
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels
base_resolution = resolution
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.down = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution
for level in range(self.num_resolutions):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != self.num_resolutions - 1:
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
curr_res = curr_res // 2
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
final_block_channels = block_in
z_channels = 2 * latent_channels if double_z else latent_channels
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
self.non_linearity = nn.SiLU()
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
hidden_states = self.conv_in(hidden_states)
for level in range(self.num_resolutions):
stage = self.down[level]
for block_idx, block in enumerate(stage.block):
hidden_states = block(hidden_states, temb=None)
if stage.attn:
hidden_states = stage.attn[block_idx](hidden_states)
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
hidden_states = stage.downsample(hidden_states)
hidden_states = self.mid.block_1(hidden_states, temb=None)
hidden_states = self.mid.attn_1(hidden_states)
hidden_states = self.mid.block_2(hidden_states, temb=None)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class LTX2AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
convolutions.
"""
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.patchifier = LTX2AudioAudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
)
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.non_linearity = nn.SiLU()
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
self.up = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
for level in reversed(range(self.num_resolutions)):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks + 1):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != 0:
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
curr_res *= 2
self.up.insert(0, stage)
final_block_channels = block_in
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
def forward(
self,
sample: torch.Tensor,
) -> torch.Tensor:
_, _, frames, mel_bins = sample.shape
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis is not None:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_channels = self.out_ch
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
hidden_features = self.conv_in(sample)
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx, block in enumerate(stage.block):
hidden_features = block(hidden_features, temb=None)
if stage.attn:
hidden_features = stage.attn[block_idx](hidden_features)
if level != 0 and hasattr(stage, "upsample"):
hidden_features = stage.upsample(hidden_features)
if self.give_pre_end:
return hidden_features
hidden = self.norm_out(hidden_features)
hidden = self.non_linearity(hidden)
decoded_output = self.conv_out(hidden)
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
_, _, current_time, current_freq = decoded_output.shape
target_time = target_frames
target_freq = target_mel_bins
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
if time_padding_needed > 0 or freq_padding_needed > 0:
padding = (
0,
max(freq_padding_needed, 0),
0,
max(time_padding_needed, 0),
)
decoded_output = F.pad(decoded_output, padding)
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
Audio VAE used by LTX 2.0 to encode mel spectrograms into latent representations and decode latents back into spectrograms.
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_channels: int = 128,
output_channels: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
norm_type: str = "pixel",
causality_axis: Optional[str] = "height",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
) -> None:
super().__init__()
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
if causality_axis not in supported_causality_axes:
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
self.encoder = LTX2AudioEncoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
double_z=double_z,
)
self.decoder = LTX2AudioDecoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
)
# Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
# computed over the entire training dataset and stored in the model checkpoint under the audio VAE state dict.
latents_std = torch.zeros((base_channels,))
latents_mean = torch.ones((base_channels,))
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
# TODO: calculate programmatically instead of hardcoding
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
# TODO: confirm whether the mel compression ratio below is correct
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
return self.encoder(x)
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True):
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
if not return_dict:
return (dec.sample,)
return dec
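
A minimal encode/decode shape sketch for the audio VAE above, assuming the default configuration (2-channel mel spectrogram input, 64 mel bins, causal along the time axis) and that this PR branch is installed. The weights are randomly initialized, so this only illustrates the expected tensor shapes:

import torch
from diffusers import AutoencoderKLLTX2Audio

audio_vae = AutoencoderKLLTX2Audio()
# 61 time frames matches the causal 4x temporal compression (4 * 16 - 3 = 61).
spectrogram = torch.randn(1, 2, 61, 64)
latents = audio_vae.encode(spectrogram).latent_dist.sample()
print(latents.shape)  # torch.Size([1, 8, 16, 16])
reconstruction = audio_vae.decode(latents).sample
print(reconstruction.shape)  # torch.Size([1, 2, 61, 64])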

View File

@@ -35,6 +35,7 @@ if is_torch_available():
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel

File diff suppressed because it is too large

View File

@@ -290,6 +290,7 @@ else:
"LTXLatentUpsamplePipeline",
"LTXI2VLongMultiPromptPipeline",
]
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -737,6 +738,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline

View File

@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .connectors import LTX2TextConnectors
from .latent_upsampler import LTX2LatentUpsamplerModel
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,325 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
"""
def __init__(
self,
dim: int,
base_seq_len: int = 4096,
theta: float = 10000.0,
double_precision: bool = True,
rope_type: str = "interleaved",
num_attention_heads: int = 32,
):
super().__init__()
if rope_type not in ["interleaved", "split"]:
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
self.dim = dim
self.base_seq_len = base_seq_len
self.theta = theta
self.double_precision = double_precision
self.rope_type = rope_type
self.num_attention_heads = num_attention_heads
def forward(
self,
batch_size: int,
pos: int,
device: Union[str, torch.device],
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Get 1D position ids
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
# Get fractional indices relative to self.base_seq_len
grid_1d = grid_1d / self.base_seq_len
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
# 2. Calculate 1D RoPE frequencies
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
freqs_dtype = torch.float64 if self.double_precision else torch.float32
pow_indices = torch.pow(
self.theta,
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
)
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
# (self.dim // 2,).
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
if self.rope_type == "interleaved":
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % num_rope_elems != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
elif self.rope_type == "split":
expected_freqs = self.dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq = freqs.cos()
sin_freq = freqs.sin()
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape freqs to be compatible with multi-head attention
b = cos_freq.shape[0]
t = cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
return cos_freqs, sin_freqs
class LTX2TransformerBlock1d(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
activation_fn: str = "gelu-approximate",
eps: float = 1e-6,
rope_type: str = "interleaved",
):
super().__init__()
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
rope_type=rope_type,
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(dim, activation_fn=activation_fn)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_hidden_states = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_hidden_states
return hidden_states
class LTX2ConnectorTransformer1d(nn.Module):
"""
A 1D sequence transformer for modalities such as text.
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: Optional[int] = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
rope_type: str = "interleaved",
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.num_learnable_registers = num_learnable_registers
self.learnable_registers = None
if num_learnable_registers is not None:
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
self.learnable_registers = torch.nn.Parameter(init_registers)
self.rope = LTX2RotaryPosEmbed1d(
self.inner_dim,
base_seq_len=rope_base_seq_len,
theta=rope_theta,
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads,
)
self.transformer_blocks = torch.nn.ModuleList(
[
LTX2TransformerBlock1d(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
rope_type=rope_type,
)
for _ in range(num_layers)
]
)
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attn_mask_binarize_threshold: float = -9000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
# hidden_states shape: [batch_size, seq_len, hidden_dim]
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
batch_size, seq_len, _ = hidden_states.shape
# 1. Replace padding with learned registers, if using
if self.learnable_registers is not None:
if seq_len % self.num_learnable_registers != 0:
raise ValueError(
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
f" of learnable registers {self.num_learnable_registers}"
)
num_register_repeats = seq_len // self.num_learnable_registers
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
padded_hidden_states = [
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
]
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
# 2. Calculate 1D RoPE positional embeddings
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
# 3. Run 1D transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
else:
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
hidden_states = self.norm_out(hidden_states)
return hidden_states, attention_mask
class LTX2TextConnectors(ModelMixin, ConfigMixin):
"""
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
streams.
"""
@register_to_config
def __init__(
self,
caption_channels: int,
text_proj_in_factor: int,
video_connector_num_attention_heads: int,
video_connector_attention_head_dim: int,
video_connector_num_layers: int,
video_connector_num_learnable_registers: Optional[int],
audio_connector_num_attention_heads: int,
audio_connector_attention_head_dim: int,
audio_connector_num_layers: int,
audio_connector_num_learnable_registers: Optional[int],
connector_rope_base_seq_len: int,
rope_theta: float,
rope_double_precision: bool,
causal_temporal_positioning: bool,
rope_type: str = "interleaved",
):
super().__init__()
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
self.video_connector = LTX2ConnectorTransformer1d(
num_attention_heads=video_connector_num_attention_heads,
attention_head_dim=video_connector_attention_head_dim,
num_layers=video_connector_num_layers,
num_learnable_registers=video_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
attention_head_dim=audio_connector_attention_head_dim,
num_layers=audio_connector_num_layers,
num_learnable_registers=audio_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
)
def forward(
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
):
# Convert to additive attention mask, if necessary
if not additive_mask:
text_dtype = text_encoder_hidden_states.dtype
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
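# (attention_mask - 1) is 0 for valid tokens and -1 for padding, so after scaling by
# torch.finfo(text_dtype).max the additive mask is 0 at valid positions and a very large
# negative value at padded positions.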
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * attn_mask
new_attn_mask = attn_mask.squeeze(-1)
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
return video_text_embedding, audio_text_embedding, new_attn_mask
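
A small sketch of the 1D RoPE helper defined above, using assumed dimensions, to show the frequency tensor shapes it returns (the import path assumes this PR branch is installed):

import torch
from diffusers.pipelines.ltx2.connectors import LTX2RotaryPosEmbed1d

rope = LTX2RotaryPosEmbed1d(dim=128, num_attention_heads=4, rope_type="interleaved")
cos_freqs, sin_freqs = rope(batch_size=2, pos=16, device="cpu")
print(cos_freqs.shape, sin_freqs.shape)  # torch.Size([2, 16, 128]) torch.Size([2, 16, 128])

rope_split = LTX2RotaryPosEmbed1d(dim=128, num_attention_heads=4, rope_type="split")
cos_freqs, sin_freqs = rope_split(batch_size=2, pos=16, device="cpu")
print(cos_freqs.shape)  # torch.Size([2, 4, 16, 16]), i.e. a per-head (B, H, T, head_dim // 2) layout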

View File

@@ -0,0 +1,134 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fractions import Fraction
from typing import Optional
import torch
from ...utils import is_av_available
_CAN_USE_AV = is_av_available()
if _CAN_USE_AV:
import av
else:
raise ImportError(
"PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container,
audio_stream: av.audio.AudioStream,
samples: torch.Tensor,
audio_sample_rate: int,
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def encode_video(
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
video_np = video.cpu().numpy()
_, height, width, _ = video_np.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
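
A hypothetical usage sketch for `encode_video` with synthetic frames and audio (PyAV must be installed; the frame and audio shapes follow the conventions handled above):

import torch
from diffusers.pipelines.ltx2.export_utils import encode_video

# 25 synthetic RGB frames of size 64x64 in (frames, height, width, channels) uint8 layout.
frames = torch.randint(0, 256, (25, 64, 64, 3), dtype=torch.uint8)
# One second of synthetic stereo audio in [-1, 1]; _write_audio transposes a (2, N) tensor to (N, 2).
audio = torch.rand(2, 24000) * 2 - 1
encode_video(frames, fps=25, audio=audio, audio_sample_rate=24000, output_path="ltx2_sample.mp4")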

View File

@@ -0,0 +1,285 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
RATIONAL_RESAMPLER_SCALE_MAPPING = {
0.75: (3, 4),
1.5: (3, 2),
2.0: (2, 1),
4.0: (4, 1),
}
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock
class ResBlock(torch.nn.Module):
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.activation(hidden_states + residual)
return hidden_states
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND
class PixelShuffleND(torch.nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
self.dims = dims
self.upscale_factors = upscale_factors
if dims not in [1, 2, 3]:
raise ValueError("dims must be 1, 2, or 3")
def forward(self, x):
if self.dims == 3:
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
return (
x.unflatten(1, (-1, *self.upscale_factors[:3]))
.permute(0, 1, 5, 2, 6, 3, 7, 4)
.flatten(6, 7)
.flatten(4, 5)
.flatten(2, 3)
)
elif self.dims == 2:
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
return (
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
)
elif self.dims == 1:
# temporal: b (c p1) f h w -> b c (f p1) h w
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
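# Shape examples (assumed sizes): dims=2 with upscale_factors=(2, 2) maps (B, 4C, H, W) to
# (B, C, 2H, 2W); dims=3 with upscale_factors=(2, 2, 2) maps (B, 8C, D, H, W) to (B, C, 2D, 2H, 2W);
# dims=1 maps (B, 2C, F, H, W) to (B, C, 2F, H, W).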
class BlurDownsample(torch.nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W.
Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
super().__init__()
if dims not in (2, 3):
raise ValueError(f"`dims` must be either 2 or 3 but is {dims}")
if kernel_size < 3 or kernel_size % 2 != 1:
raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}")
self.dims = dims
self.stride = stride
self.kernel_size = kernel_size
# Separable binomial kernel built from row (kernel_size - 1) of Pascal's triangle; for the default
# kernel_size=5 this is [1, 4, 6, 4, 1]. Binomial filters provide a smooth approximation of a Gaussian
# and are used here for anti-aliasing. The 2D kernel is the outer product of the 1D kernel with itself,
# normalized to sum to 1.
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
if self.dims == 2:
c = x.shape[1]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
else:
# dims == 3: apply per-frame on H,W
b, c, f, _, _ = x.shape
x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
x = x.unflatten(0, (b, f)).transpose(1, 2)  # [B * F, C, H', W'] --> [B, C, F, H', W']
return x
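To make the kernel comment concrete, the depthwise blur used above can be reproduced in a few lines (a sketch with hypothetical sizes, not part of the diff):

```py
import math

import torch
import torch.nn.functional as F

kernel_size = 5
# Row `kernel_size - 1` of Pascal's triangle: [1, 4, 6, 4, 1]; outer product + normalization
# gives the separable binomial (approximately Gaussian) 2D kernel.
k = torch.tensor([math.comb(kernel_size - 1, i) for i in range(kernel_size)], dtype=torch.float32)
k2d = torch.outer(k, k)
k2d = k2d / k2d.sum()

x = torch.randn(1, 8, 32, 32)                                    # [B * F, C, H, W]
weight = k2d[None, None].expand(8, 1, kernel_size, kernel_size)  # one copy per channel (depthwise)
y = F.conv2d(x, weight, stride=2, padding=kernel_size // 2, groups=8)
print(y.shape)                                                   # torch.Size([1, 8, 16, 16])
```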
class SpatialRationalResampler(torch.nn.Module):
"""
Scales the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample
by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the
input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the
(integer) denominator.
"""
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
super().__init__()
self.scale = float(scale)
num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None)
if num_denom is None:
raise ValueError(
f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}"
)
self.num, self.den = num_denom
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expected x shape: [B * F, C, H, W]
# b, _, f, h, w = x.shape
# x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
# x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W]
return x
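As the docstring describes, `scale=0.75` maps to `(num, den) = (3, 4)`: the module first pixel-shuffles up by 3 and then blur-downsamples by 4, so spatial sizes scale by 3/4 overall. A minimal shape check on a standalone module (hypothetical sizes; the full model uses a much larger `mid_channels`):

```py
import torch

resampler = SpatialRationalResampler(mid_channels=16, scale=0.75)  # (num, den) = (3, 4)
x = torch.randn(2, 16, 64, 64)        # [B * F, C, H, W]
with torch.no_grad():
    y = resampler(x)
print(y.shape)                         # torch.Size([2, 16, 48, 48]): 64 -> 192 -> 48
```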
class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`, defaults to `128`):
Number of channels in the input latent
mid_channels (`int`, defaults to `1024`):
Number of channels in the middle layers
num_blocks_per_stage (`int`, defaults to `4`):
Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`, defaults to `3`):
Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`, defaults to `True`):
Whether to spatially upsample the latent
temporal_upsample (`bool`, defaults to `False`):
Whether to temporally upsample the latent
rational_spatial_scale (`float`, defaults to `2.0`):
If set and `spatial_upsample` is `True`, rescale the latent spatially by this rational factor via
`SpatialRationalResampler`; if `None`, a fixed 2x pixel-shuffle upsampler is used
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
rational_spatial_scale: Optional[float] = 2.0,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_spatial_scale is not None:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
else:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.dims == 2:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.upsampler(hidden_states)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
else:
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
if self.temporal_upsample:
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states[:, :, 1:, :, :]
else:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
return hidden_states
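End to end, a spatial-only configuration leaves the channel and frame counts untouched and rescales only the latent H and W. A rough sketch with deliberately tiny, hypothetical sizes (chosen so the `GroupNorm(32, ...)` layers divide evenly):

```py
import torch

upsampler = LTX2LatentUpsamplerModel(
    in_channels=8,
    mid_channels=64,                 # must be divisible by 32 for the GroupNorm layers
    num_blocks_per_stage=1,
    dims=3,
    spatial_upsample=True,
    temporal_upsample=False,
    rational_spatial_scale=2.0,      # maps to (num, den) = (2, 1), i.e. a pure 2x spatial upsample
)
latents = torch.randn(1, 8, 3, 16, 16)   # [B, C, F, H, W] video latents
with torch.no_grad():
    upsampled = upsampler(latents)
print(upsampled.shape)                    # torch.Size([1, 8, 3, 32, 32]); only H and W change
```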

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -0,0 +1,432 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLLTX2Video
from ...utils import get_logger, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..ltx.pipeline_output import LTXPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .latent_upsampler import LTX2LatentUpsamplerModel
logger = get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
>>> from diffusers.utils import load_image
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
>>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> image = load_image(
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
... )
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> video, audio = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=768,
... height=512,
... num_frames=121,
... frame_rate=25.0,
... num_inference_steps=40,
... guidance_scale=3.0,
... output_type="pil",
... return_dict=False,
... )
>>> upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
... "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
... )
>>> upsample_pipe.to("cuda")
>>> video = upsample_pipe(
... video=video,
... width=768,
... height=512,
... output_type="pil",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(video[0], fps=25.0, audio=audio[0].float().cpu(), output_path="output.mp4")
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTX2LatentUpsamplePipeline(DiffusionPipeline):
model_cpu_offload_seq = "vae->latent_upsampler"
def __init__(
self,
vae: AutoencoderKLLTX2Video,
latent_upsampler: LTX2LatentUpsamplerModel,
) -> None:
super().__init__()
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_frames: int = 121,
height: int = 512,
width: int = 768,
spatial_patch_size: int = 1,
temporal_patch_size: int = 1,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
if latents.ndim == 3:
# Convert token seq [B, S, D] to latent video [B, C, F, H, W]
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
latents = self._unpack_latents(
latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
)
return latents.to(device=device, dtype=dtype)
video = video.to(device=device, dtype=self.vae.dtype)
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
# NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
# init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
"""
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
tensor.
Args:
latents (`torch.Tensor`):
The input latents to transform.
reference_latents (`torch.Tensor`):
The reference latents providing style statistics.
factor (`float`):
Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
Returns:
torch.Tensor: The transformed latent tensor
"""
result = latents.clone()
for i in range(latents.size(0)):
for c in range(latents.size(1)):
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
result = torch.lerp(latents, result, factor)
return result
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
"""
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
smooth way using a sigmoid-based compression.
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
when controlling dynamic behavior with a `compression` factor.
Args:
latents (`torch.Tensor`):
Input latent tensor with arbitrary shape. Expected to be roughly in the [-1, 1] or [0, 1] range.
compression (`float`):
Compression strength in the range [0, 1], where 0.0 means no tone-mapping (identity transform) and 1.0
applies the full compression effect.
Returns:
`torch.Tensor`:
The tone-mapped latent tensor with the same shape as the input.
"""
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
scale_factor = compression * 0.75
abs_latents = torch.abs(latents)
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
filtered = latents * scales
return filtered
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` can be provided.")
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
if not (0 <= tone_map_compression_ratio <= 1):
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
video: Optional[List[PipelineImageInput]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 121,
spatial_patch_size: int = 1,
temporal_patch_size: int = 1,
latents: Optional[torch.Tensor] = None,
latents_normalized: bool = False,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
adain_factor: float = 0.0,
tone_map_compression_ratio: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
video (`List[PipelineImageInput]`, *optional*):
The video to be upsampled (such as an LTX 2.0 first-stage output). If not supplied, `latents` should be
supplied.
height (`int`, *optional*, defaults to `512`):
The height in pixels of the input video (not the generated video, which will have a larger resolution).
width (`int`, *optional*, defaults to `768`):
The width in pixels of the input video (not the generated video, which will have a larger resolution).
num_frames (`int`, *optional*, defaults to `121`):
The number of frames in the input video.
spatial_patch_size (`int`, *optional*, defaults to `1`):
The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
temporal_patch_size (`int`, *optional*, defaults to `1`):
The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
necessary.
latents (`torch.Tensor`, *optional*):
Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
latent_channels, latent_frames, latent_height, latent_width)`.
latents_normalized (`bool`, *optional*, defaults to `False`):
If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
`True`, the `latents` will be denormalized before being supplied to the latent upsampler.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
adain_factor (`float`, *optional*, defaults to `0.0`):
Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
This is useful for regularizing high-variance latents or for conditioning outputs during generation.
Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
the full compression effect.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is the upsampled video.
"""
self.check_inputs(
video=video,
height=height,
width=width,
latents=latents,
tone_map_compression_ratio=tone_map_compression_ratio,
)
if video is not None:
# Batched video input is not yet tested/supported. TODO: take a look later
batch_size = 1
else:
batch_size = latents.shape[0]
device = self._execution_device
if video is not None:
num_frames = len(video)
if num_frames % self.vae_temporal_compression_ratio != 1:
num_frames = (
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
)
logger.warning(
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
)
video = video[:num_frames]
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=torch.float32)
latents_supplied = latents is not None
latents = self.prepare_latents(
video=video,
batch_size=batch_size,
num_frames=num_frames,
height=height,
width=width,
spatial_patch_size=spatial_patch_size,
temporal_patch_size=temporal_patch_size,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
if latents_supplied and latents_normalized:
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.latent_upsampler.dtype)
latents_upsampled = self.latent_upsampler(latents)
if adain_factor > 0.0:
latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
else:
latents = latents_upsampled
if tone_map_compression_ratio > 0.0:
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
else:
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
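For the `latents` path, the pipeline accepts either an unpacked `[B, C, F, H, W]` latent video or a packed `[B, S, D]` sequence (unpacked internally using the patch sizes), and it expects unnormalized latents unless `latents_normalized=True` is passed. A hedged sketch of the call, reusing the checkpoint name from the example docstring; the latent tensor below is a random stand-in for a real first-stage output, and its exact shape depends on the checkpoint's compression ratios:

```py
import torch
from diffusers import LTX2LatentUpsamplePipeline

upsample_pipe = LTX2LatentUpsamplePipeline.from_pretrained(
    "Lightricks/LTX-Video-2", torch_dtype=torch.bfloat16
).to("cuda")

# Hypothetical latents for a 768x512, 121-frame video with 32x spatial / 8x temporal compression:
# [B, C, (121 - 1) // 8 + 1, 512 // 32, 768 // 32] = [1, 128, 16, 16, 24]
latents = torch.randn(1, 128, 16, 16, 24, dtype=torch.bfloat16, device="cuda")
video = upsample_pipe(
    latents=latents,
    latents_normalized=False,   # set to True if the latents still carry the VAE normalization
    height=512,
    width=768,
    num_frames=121,
    output_type="np",
    return_dict=False,
)[0]
```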

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class LTX2PipelineOutput(BaseOutput):
r"""
Output class for LTX 2.0 pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
audio (`torch.Tensor`, `np.ndarray`):
The generated audio corresponding to the video. It is a Torch tensor or NumPy array of shape
`(batch_size, channels, audio_length)` containing the waveform produced by the vocoder.
"""
frames: torch.Tensor
audio: torch.Tensor

View File

@@ -0,0 +1,159 @@
import math
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class ResBlock(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int = 3,
stride: int = 1,
dilations: Tuple[int, ...] = (1, 3, 5),
leaky_relu_negative_slope: float = 0.1,
padding_mode: str = "same",
):
super().__init__()
self.dilations = dilations
self.negative_slope = leaky_relu_negative_slope
self.convs1 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
for dilation in dilations
]
)
self.convs2 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
for _ in range(len(dilations))
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
xt = conv1(xt)
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
xt = conv2(xt)
x = x + xt
return x
class LTX2Vocoder(ModelMixin, ConfigMixin):
r"""
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
hidden_channels: int = 1024,
out_channels: int = 2,
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
upsample_factors: List[int] = [6, 5, 2, 2, 2],
resnet_kernel_sizes: List[int] = [3, 7, 11],
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
leaky_relu_negative_slope: float = 0.1,
output_sampling_rate: int = 24000,
):
super().__init__()
self.num_upsample_layers = len(upsample_kernel_sizes)
self.resnets_per_upsample = len(resnet_kernel_sizes)
self.out_channels = out_channels
self.total_upsample_factor = math.prod(upsample_factors)
self.negative_slope = leaky_relu_negative_slope
if self.num_upsample_layers != len(upsample_factors):
raise ValueError(
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
)
if self.resnets_per_upsample != len(resnet_dilations):
raise ValueError(
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
)
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
self.upsamplers = nn.ModuleList()
self.resnets = nn.ModuleList()
input_channels = hidden_channels
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
output_channels = input_channels // 2
self.upsamplers.append(
nn.ConvTranspose1d(
input_channels, # hidden_channels // (2 ** i)
output_channels, # hidden_channels // (2 ** (i + 1))
kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
)
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
self.resnets.append(
ResBlock(
output_channels,
kernel_size,
dilations=dilations,
leaky_relu_negative_slope=leaky_relu_negative_slope,
)
)
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
Forward pass of the vocoder.
Args:
hidden_states (`torch.Tensor`):
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
`True`.
time_last (`bool`, *optional*, defaults to `False`):
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
Returns:
`torch.Tensor`:
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
"""
# Ensure that the time/frame dimension is last
if not time_last:
hidden_states = hidden_states.transpose(2, 3)
# Combine channels and frequency (mel bins) dimensions
hidden_states = hidden_states.flatten(1, 2)
hidden_states = self.conv_in(hidden_states)
for i in range(self.num_upsample_layers):
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
hidden_states = self.upsamplers[i](hidden_states)
# Run all resnets in parallel on hidden_states
start = i * self.resnets_per_upsample
end = (i + 1) * self.resnets_per_upsample
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
hidden_states = torch.mean(resnet_outputs, dim=0)
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
hidden_states = self.conv_out(hidden_states)
hidden_states = torch.tanh(hidden_states)
return hidden_states
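The output length follows directly from the transposed convolutions: each mel frame expands by `math.prod(upsample_factors)` samples (240 with the defaults, i.e. 100 mel frames per second at the default 24 kHz output rate). A small sketch with a reduced, hypothetical configuration (total upsample factor 4):

```py
import torch

vocoder = LTX2Vocoder(
    in_channels=4,                  # must equal channels * mel_bins of the input spectrogram
    hidden_channels=8,
    out_channels=2,
    upsample_kernel_sizes=[4, 4],
    upsample_factors=[2, 2],
    resnet_kernel_sizes=[3],
    resnet_dilations=[[1, 3, 5]],
)
mel = torch.randn(1, 2, 10, 2)      # (batch, channels, time, mel_bins)
with torch.no_grad():
    waveform = vocoder(mel)
print(waveform.shape)               # torch.Size([1, 2, 40]): 10 mel frames * total upsample factor 4
```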

View File

@@ -66,6 +66,7 @@ from .import_utils import (
is_accelerate_version,
is_aiter_available,
is_aiter_version,
is_av_available,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,

View File

@@ -502,6 +502,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Video(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTXVideo(metaclass=DummyObject):
_backends = ["torch"]
@@ -1147,6 +1177,21 @@ class LongCatImageTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class LTX2VideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LTXVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1877,6 +1877,36 @@ class LongCatImagePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTX2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
def is_torch_available():
@@ -420,6 +421,10 @@ def is_kornia_available():
return _kornia_available
def is_av_available():
return _av_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the

View File

@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLLTX2Audio
from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTX2Audio
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_audio_config(self):
return {
"in_channels": 2, # stereo
"output_channels": 2,
"latent_channels": 4,
"base_channels": 16,
"ch_mult": (1, 2, 4),
"resolution": 16,
"attn_resolutions": None,
"num_res_blocks": 2,
"norm_type": "pixel",
"causality_axis": "height",
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"mel_bins": 16,
"is_causal": True,
"double_z": True,
}
@property
def dummy_input(self):
batch_size = 2
num_channels = 2
num_frames = 8
num_mel_bins = 16
spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
input_dict = {"sample": spectrogram}
return input_dict
@property
def input_shape(self):
return (2, 5, 16)
@property
def output_shape(self):
return (2, 5, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_audio_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
# Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
def test_output(self):
super().test_output(expected_output_shape=(2, 2, 5, 16))
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLLTX2Video
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTX2Video
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (16, 32, 64),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
# Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros`
"decoder_spatial_padding_mode": "zeros",
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
input_dict = {"sample": image}
return input_dict
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTX2VideoEncoder3d",
"LTX2VideoDecoder3d",
"LTX2VideoDownBlock3D",
"LTX2VideoMidBlock3d",
"LTX2VideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -0,0 +1,222 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import LTX2VideoTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
# Common
batch_size = 2
# Video
num_frames = 2
num_channels = 4
height = 16
width = 16
# Audio
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
# Text
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
torch_device
)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
return {
"hidden_states": hidden_states,
"audio_hidden_states": audio_hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"audio_encoder_hidden_states": audio_encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"audio_in_channels": 4,
"audio_out_channels": 4,
"audio_num_attention_heads": 2,
"audio_attention_head_dim": 4,
"audio_cross_attention_dim": 8,
"num_layers": 2,
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
"rope_double_precision": False,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTX2VideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
# torch.manual_seed(seed)
# init_dict, _ = self.prepare_init_args_and_inputs_for_common()
# # Calculate dummy inputs in a custom manner to ensure compatibility with original code
# batch_size = 2
# num_frames = 9
# latent_frames = 2
# text_embedding_dim = 16
# text_seq_len = 16
# fps = 25.0
# sampling_rate = 16000.0
# hop_length = 160.0
# sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
# timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
# num_channels = 4
# latent_height = 4
# latent_width = 4
# hidden_states = torch.randn(
# (batch_size, num_channels, latent_frames, latent_height, latent_width),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify video latents (with patch_size (1, 1, 1))
# hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
# hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
# encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# audio_num_channels = 2
# num_mel_bins = 2
# latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
# audio_hidden_states = torch.randn(
# (batch_size, audio_num_channels, latent_length, num_mel_bins),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify audio latents
# audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
# audio_encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# inputs_dict = {
# "hidden_states": hidden_states.to(device=torch_device),
# "audio_hidden_states": audio_hidden_states.to(device=torch_device),
# "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
# "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
# "timestep": timestep,
# "num_frames": latent_frames,
# "height": latent_height,
# "width": latent_width,
# "audio_num_frames": num_frames,
# "fps": 25.0,
# }
# model = self.model_class.from_pretrained(
# "diffusers-internal-dev/dummy-ltx2",
# subfolder="transformer",
# device_map="cpu",
# )
# # torch.manual_seed(seed)
# # model = self.model_class(**init_dict)
# model.to(torch_device)
# model.eval()
# with attention_backend("native"):
# with torch.no_grad():
# output = model(**inputs_dict)
# video_output, audio_output = output.to_tuple()
# self.assertIsNotNone(video_output)
# self.assertIsNotNone(audio_output)
# # input & output have to have the same shape
# video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
# self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
# audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
# self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
# # Check against expected slice
# # fmt: off
# video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
# audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
# # fmt: on
# video_output_flat = video_output.cpu().flatten().float()
# video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
# self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
# audio_output_flat = audio_output.cpu().flatten().float()
# audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
# self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()

View File

@@ -0,0 +1,239 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTX2Pipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"audio_latents",
"output_type",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_attention_slicing = False
test_xformers_attention = False
supports_dduf = False
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
def get_dummy_components(self):
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
torch.manual_seed(0)
transformer = LTX2VideoTransformer3DModel(
in_channels=4,
out_channels=4,
patch_size=1,
patch_size_t=1,
num_attention_heads=2,
attention_head_dim=8,
cross_attention_dim=16,
audio_in_channels=4,
audio_out_channels=4,
audio_num_attention_heads=2,
audio_attention_head_dim=4,
audio_cross_attention_dim=8,
num_layers=2,
qk_norm="rms_norm_across_heads",
caption_channels=text_encoder.config.text_config.hidden_size,
rope_double_precision=False,
rope_type="split",
)
torch.manual_seed(0)
connectors = LTX2TextConnectors(
caption_channels=text_encoder.config.text_config.hidden_size,
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
video_connector_num_attention_heads=4,
video_connector_attention_head_dim=8,
video_connector_num_layers=1,
video_connector_num_learnable_registers=None,
audio_connector_num_attention_heads=4,
audio_connector_attention_head_dim=8,
audio_connector_num_layers=1,
audio_connector_num_learnable_registers=None,
connector_rope_base_seq_len=32,
rope_theta=10000.0,
rope_double_precision=False,
causal_temporal_positioning=False,
rope_type="split",
)
torch.manual_seed(0)
vae = AutoencoderKLLTX2Video(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(8,),
decoder_block_out_channels=(8,),
layers_per_block=(1,),
decoder_layers_per_block=(1, 1),
spatio_temporal_scaling=(True,),
decoder_spatio_temporal_scaling=(True,),
decoder_inject_noise=(False, False),
downsample_type=("spatial",),
upsample_residual=(False,),
upsample_factor=(1,),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
audio_vae = AutoencoderKLLTX2Audio(
base_channels=4,
output_channels=2,
ch_mult=(1,),
num_res_blocks=1,
attn_resolutions=None,
in_channels=2,
resolution=32,
latent_channels=2,
norm_type="pixel",
causality_axis="height",
dropout=0.0,
mid_block_add_attention=False,
sample_rate=16000,
mel_hop_length=160,
is_causal=True,
mel_bins=8,
)
torch.manual_seed(0)
vocoder = LTX2Vocoder(
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
hidden_channels=32,
out_channels=2,
upsample_kernel_sizes=[4, 4],
upsample_factors=[2, 2],
resnet_kernel_sizes=[3],
resnet_dilations=[[1, 3, 5]],
leaky_relu_negative_slope=0.1,
output_sampling_rate=16000,
)
scheduler = FlowMatchEulerDiscreteScheduler()
components = {
"transformer": transformer,
"vae": vae,
"audio_vae": audio_vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "a robot dancing",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"num_frames": 5,
"frame_rate": 25.0,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = pipe(**inputs)
video = output.frames
audio = output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184
]
)
expected_audio_slice = torch.tensor(
[
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)

View File

@@ -0,0 +1,241 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTX2ImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"audio_latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_attention_slicing = False
test_xformers_attention = False
supports_dduf = False
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
def get_dummy_components(self):
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
torch.manual_seed(0)
transformer = LTX2VideoTransformer3DModel(
in_channels=4,
out_channels=4,
patch_size=1,
patch_size_t=1,
num_attention_heads=2,
attention_head_dim=8,
cross_attention_dim=16,
audio_in_channels=4,
audio_out_channels=4,
audio_num_attention_heads=2,
audio_attention_head_dim=4,
audio_cross_attention_dim=8,
num_layers=2,
qk_norm="rms_norm_across_heads",
caption_channels=text_encoder.config.text_config.hidden_size,
rope_double_precision=False,
rope_type="split",
)
torch.manual_seed(0)
connectors = LTX2TextConnectors(
caption_channels=text_encoder.config.text_config.hidden_size,
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
video_connector_num_attention_heads=4,
video_connector_attention_head_dim=8,
video_connector_num_layers=1,
video_connector_num_learnable_registers=None,
audio_connector_num_attention_heads=4,
audio_connector_attention_head_dim=8,
audio_connector_num_layers=1,
audio_connector_num_learnable_registers=None,
connector_rope_base_seq_len=32,
rope_theta=10000.0,
rope_double_precision=False,
causal_temporal_positioning=False,
rope_type="split",
)
torch.manual_seed(0)
vae = AutoencoderKLLTX2Video(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(8,),
decoder_block_out_channels=(8,),
layers_per_block=(1,),
decoder_layers_per_block=(1, 1),
spatio_temporal_scaling=(True,),
decoder_spatio_temporal_scaling=(True,),
decoder_inject_noise=(False, False),
downsample_type=("spatial",),
upsample_residual=(False,),
upsample_factor=(1,),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
audio_vae = AutoencoderKLLTX2Audio(
base_channels=4,
output_channels=2,
ch_mult=(1,),
num_res_blocks=1,
attn_resolutions=None,
in_channels=2,
resolution=32,
latent_channels=2,
norm_type="pixel",
causality_axis="height",
dropout=0.0,
mid_block_add_attention=False,
sample_rate=16000,
mel_hop_length=160,
is_causal=True,
mel_bins=8,
)
torch.manual_seed(0)
vocoder = LTX2Vocoder(
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
hidden_channels=32,
out_channels=2,
upsample_kernel_sizes=[4, 4],
upsample_factors=[2, 2],
resnet_kernel_sizes=[3],
resnet_dilations=[[1, 3, 5]],
leaky_relu_negative_slope=0.1,
output_sampling_rate=16000,
)
scheduler = FlowMatchEulerDiscreteScheduler()
components = {
"transformer": transformer,
"vae": vae,
"audio_vae": audio_vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"image": image,
"prompt": "a robot dancing",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"num_frames": 5,
"frame_rate": 25.0,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = pipe(**inputs)
video = output.frames
audio = output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555
]
)
expected_audio_slice = torch.tensor(
[
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)