Compare commits

..

9 Commits

Author SHA1 Message Date
Sayak Paul
3b334de68a Merge branch 'main' into device-map-direct 2026-01-08 12:23:39 +05:30
dg845
c10bdd9b73 Add LTX 2.0 Video Pipelines (#12915)
* Initial LTX 2.0 transformer implementation

* Add tests for LTX 2 transformer model

* Get LTX 2 transformer tests working

* Rename LTX 2 compile test class to have LTX2

* Remove RoPE debug print statements

* Get LTX 2 transformer compile tests passing

* Fix LTX 2 transformer shape errors

* Initial script to convert LTX 2 transformer to diffusers

* Add more LTX 2 transformer audio arguments

* Allow LTX 2 transformer to be loaded from local path for conversion

* Improve dummy inputs and add test for LTX 2 transformer consistency

* Fix LTX 2 transformer bugs so consistency test passes

* Initial implementation of LTX 2.0 video VAE

* Explicitly specify temporal and spatial VAE scale factors when converting

* Add initial LTX 2.0 video VAE tests

* Add initial LTX 2.0 video VAE tests (part 2)

* Get diffusers implementation on par with official LTX 2.0 video VAE implementation

* Initial LTX 2.0 vocoder implementation

* Use RMSNorm implementation closer to original for LTX 2.0 video VAE

* start audio decoder.

* init registration.

* up

* simplify and clean up

* up

* Initial LTX 2.0 text encoder implementation

* Rough initial LTX 2.0 pipeline implementation

* up

* up

* up

* up

* Add imports for LTX 2.0 Audio VAE

* Conversion script for LTX 2.0 Audio VAE Decoder

* Add Audio VAE logic to T2V pipeline

* Duplicate scheduler for audio latents

* Support num_videos_per_prompt for prompt embeddings

* LTX 2.0 scheduler and full pipeline conversion

* Add script to test full LTX2Pipeline T2V inference

* Fix pipeline return bugs

* Add LTX 2 text encoder and vocoder to ltx2 subdirectory __init__

* Fix more bugs in LTX2Pipeline.__call__

* Improve CPU offload support

* Fix pipeline audio VAE decoding dtype bug

* Fix video shape error in full pipeline test script

* Get LTX 2 T2V pipeline to produce reasonable outputs

* Make LTX 2.0 scheduler more consistent with original code

* Fix typo when applying scheduler fix in T2V inference script

* Refactor Audio VAE to be simpler and remove helpers (#7)

* remove resolve causality axes stuff.

* remove a bunch of helpers.

* remove adjust output shape helper.

* remove the use of audiolatentshape.

* move normalization and patchify out of pipeline.

* fix

* up

* up

* Remove unpatchify and patchify ops before audio latents denormalization (#9)

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Add support for I2V (#8)

* start i2v.

* up

* up

* up

* up

* up

* remove uniform strategy code.

* remove unneeded code.

* Denormalize audio latents in I2V pipeline (analogous to T2V change) (#11)

* test i2v.

* Move Video and Audio Text Encoder Connectors to Transformer (#12)

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* precompute run_connectors,.

* fixes

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* Make connectors a separate module (#18)

* remove text_encoder.py

* address yiyi's comments.

* up

* up

* up

* up

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>

* up (#19)

* address initial feedback from lightricks team (#16)

* cross_attn_timestep_scale_multiplier to 1000

* implement split rope type.

* up

* propagate rope_type to rope embed classes as well.

* up

* When using split RoPE, make sure that the output dtype is same as input dtype

* Fix apply split RoPE shape error when reshaping x to 4D

* Add export_utils file for exporting LTX 2.0 videos with audio

* Tests for T2V and I2V (#6)

* add ltx2 pipeline tests.

* up

* up

* up

* up

* remove content

* style

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* up

* up

* i2v tests.

* up

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* revert unneded changes.

* up

* up

* update to split style rope.

* up

---------

Co-authored-by: Daniel Gu <dgu8957@gmail.com>

* up

* use export util funcs.

* Point original checkpoint to LTX 2.0 official checkpoint

* Allow the I2V pipeline to accept image URLs

* make style and make quality

* remove function map.

* remove args.

* update docs.

* update doc entries.

* disable ltx2_consistency test

* Simplify LTX 2 RoPE forward by removing coords is None logic

* make style and make quality

* Support LTX 2.0 audio VAE encoder

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Remove print statement in audio VAE

* up

* Fix bug when calculating audio RoPE coords

* Ltx 2 latent upsample pipeline (#12922)

* Initial implementation of LTX 2.0 latent upsampling pipeline

* Add new LTX 2.0 spatial latent upsampler logic

* Add test script for LTX 2.0 latent upsampling

* Add option to enable VAE tiling in upsampling test script

* Get latent upsampler working with video latents

* Fix typo in BlurDownsample

* Add latent upsample pipeline docstring and example

* Remove deprecated pipeline VAE slicing/tiling methods

* make style and make quality

* When returning latents, return unpacked and denormalized latents for T2V and I2V

* Add model_cpu_offload_seq for latent upsampling pipeline

---------

Co-authored-by: Daniel Gu <dgu8957@gmail.com>

* Fix latent upsampler filename in LTX 2 conversion script

* Add latent upsample pipeline to LTX 2 docs

* Add dummy objects for LTX 2 latent upsample pipeline

* Set default FPS to official LTX 2 ckpt default of 24.0

* Set default CFG scale to official LTX 2 ckpt default of 4.0

* Update LTX 2 pipeline example docstrings

* make style and make quality

* Remove LTX 2 test scripts

* Fix LTX 2 upsample pipeline example docstring

* Add logic to convert and save a LTX 2 upsampling pipeline

* Document LTX2VideoTransformer3DModel forward pass

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2026-01-07 21:24:27 -08:00
Álvaro Somoza
dab000e88b [Modular] Video for Mellon (#12924)
num_frames and videos
2026-01-07 12:35:59 -10:00
David El Malih
9fb6b89d49 Improve docstrings and type hints in scheduling_edm_euler.py (#12871)
* docs: add comprehensive docstrings and refine type hints for EDM scheduler methods and config parameters.

* refactor: Add type hints to DPM-Solver scheduler methods.
2026-01-07 11:18:00 -08:00
Sayak Paul
6fb4c99f5a Update wan.md to remove unneeded hfoptions (#12890) 2026-01-07 09:47:19 -08:00
Sayak Paul
961b9b27d3 [docs] fix torchao typo. (#12883)
fix torchao typo.
2026-01-07 09:43:02 -08:00
Sayak Paul
c61e455ce7 Merge branch 'main' into device-map-direct 2025-12-23 13:16:10 +05:30
Sayak Paul
6f5eb0a933 Merge branch 'main' into device-map-direct 2025-12-11 14:47:09 +08:00
sayakpaul
83ec2fb793 support device type device_maps to work with offloading. 2025-12-09 11:10:41 +05:30
47 changed files with 9889 additions and 1685 deletions

View File

@@ -353,8 +353,6 @@
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -369,6 +367,8 @@
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -445,6 +445,10 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -543,8 +547,6 @@
title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/glm_image
title: GLM-Image
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
@@ -682,6 +684,8 @@
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx2
title: LTX-2
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi

View File

@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Audio
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTX2Audio
vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Audio
[[autodoc]] AutoencoderKLLTX2Audio
- encode
- decode
- all

View File

@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Video
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTX2Video
vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Video
[[autodoc]] AutoencoderKLLTX2Video
- decode
- encode
- all

View File

@@ -1,18 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# GlmImageTransformer2DModel
A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel]()
## GlmImageTransformer2DModel
[[autodoc]] GlmImageTransformer2DModel

View File

@@ -0,0 +1,26 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# LTX2VideoTransformer3DModel
A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import LTX2VideoTransformer3DModel
transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTX2VideoTransformer3DModel
[[autodoc]] LTX2VideoTransformer3DModel

View File

@@ -1,31 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# GLM-Image
> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org).
## GlmImagePipeline
[[autodoc]] GlmImagePipeline
- all
- __call__
## GlmImagePipelineOutput
[[autodoc]] pipelines.cogview4.pipeline_output.GlmImagePipelineOutput

View File

@@ -0,0 +1,43 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX-2
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
## LTX2Pipeline
[[autodoc]] LTX2Pipeline
- all
- __call__
## LTX2ImageToVideoPipeline
[[autodoc]] LTX2ImageToVideoPipeline
- all
- __call__
## LTX2LatentUpsamplePipeline
[[autodoc]] LTX2LatentUpsamplePipeline
- all
- __call__
## LTX2PipelineOutput
[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput

View File

@@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.

View File

@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantzation_config=pipeline_quant_config,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantzation_config=pipeline_quant_config,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantzation_config=pipeline_quant_config,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)

View File

@@ -0,0 +1,886 @@
import argparse
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulatin Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
if key.startswith("per_channel_statistics"):
new_key = ".".join(["decoder", key])
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 4,
"audio_out_channels": 4,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 2,
"audio_attention_head_dim": 4,
"audio_cross_attention_dim": 8,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 2,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 16,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"caption_channels": 16,
"text_proj_in_factor": 3,
"video_connector_num_attention_heads": 4,
"video_connector_attention_head_dim": 8,
"video_connector_num_layers": 1,
"video_connector_num_learnable_registers": None,
"audio_connector_num_attention_heads": 4,
"audio_connector_attention_head_dim": 8,
"audio_connector_num_layers": 1,
"audio_connector_num_learnable_registers": None,
"connector_rope_base_seq_len": 32,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_temporal_positioning": False,
},
}
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
},
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def get_ltx2_spatial_latent_upsampler_config(version: str):
if version == "2.0":
config = {
"in_channels": 128,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
def convert_ltx2_spatial_latent_upsampler(
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
with init_empty_weights():
latent_upsampler = LTX2LatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
if repo_id is None and filename is None:
raise ValueError("Please supply at least one of `repo_id` or `filename`")
if repo_id is not None:
if filename is None:
raise ValueError("If repo_id is specified, filename must also be specified.")
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
else:
ckpt_path = filename
_, ext = os.path.splitext(ckpt_path)
if ext in [".safetensors", ".sft"]:
state_dict = safetensors.torch.load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
return state_dict
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
# Ensure that the key prefix ends with a dot (.)
if not prefix.endswith("."):
prefix = prefix + "."
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--original_state_dict_repo_id",
default="Lightricks/LTX-2",
type=str,
help="HF Hub repo id with LTX 2.0 checkpoint",
)
parser.add_argument(
"--checkpoint_path",
default=None,
type=str,
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
)
parser.add_argument(
"--version",
type=str,
default="2.0",
choices=["test", "2.0"],
help="Version of the LTX 2.0 model",
)
parser.add_argument(
"--combined_filename",
default="ltx-2-19b-dev.safetensors",
type=str,
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
)
parser.add_argument("--vae_prefix", default="vae.", type=str)
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
parser.add_argument(
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
)
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
parser.add_argument(
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
)
parser.add_argument(
"--text_encoder_model_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 base text encoder model",
)
parser.add_argument(
"--tokenizer_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename",
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
type=str,
help="Latent upsampler filename",
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
parser.add_argument(
"--full_pipeline",
action="store_true",
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
)
parser.add_argument(
"--upsample_pipeline",
action="store_true",
help="Whether to save a latent upsampling pipeline",
)
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
def main(args):
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
combined_ckpt = None
load_combined_models = any(
[
args.vae,
args.audio_vae,
args.dit,
args.vocoder,
args.text_encoder,
args.full_pipeline,
args.upsample_pipeline,
]
)
if args.combined_filename is not None and load_combined_models:
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
if args.vae or args.full_pipeline or args.upsample_pipeline:
if args.vae_filename is not None:
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
if not args.full_pipeline and not args.upsample_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
if args.audio_vae or args.full_pipeline:
if args.audio_vae_filename is not None:
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
elif combined_ckpt is not None:
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
if not args.full_pipeline:
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
if args.dit or args.full_pipeline:
if args.dit_filename is not None:
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
if not args.full_pipeline:
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.connectors or args.full_pipeline:
if args.dit_filename is not None:
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
if not args.full_pipeline:
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
if args.vocoder or args.full_pipeline:
if args.vocoder_filename is not None:
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
elif combined_ckpt is not None:
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
if not args.full_pipeline:
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.text_encoder or args.full_pipeline:
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
if not args.full_pipeline:
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
original_latent_upsampler_ckpt,
latent_upsampler_config,
dtype=vae_dtype,
)
if not args.full_pipeline and not args.upsample_pipeline:
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
if args.full_pipeline:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTX2Pipeline(
scheduler=scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.upsample_pipeline:
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
pipe.save_pretrained(
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
)
if __name__ == "__main__":
args = get_args()
main(args)

View File

@@ -193,6 +193,8 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -223,7 +225,6 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -237,6 +238,7 @@ else:
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
"LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -488,7 +490,6 @@ else:
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImagePipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -540,6 +541,9 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXI2VLongMultiPromptPipeline",
"LTXImageToVideoPipeline",
@@ -941,6 +945,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -971,7 +977,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -985,6 +990,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1206,7 +1212,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
GlmImagePipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
@@ -1258,6 +1263,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,

View File

@@ -41,6 +41,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -96,7 +98,6 @@ if is_torch_available():
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -105,6 +106,7 @@ if is_torch_available():
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -154,6 +156,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -204,7 +208,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
@@ -214,6 +217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,

View File

@@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,804 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
LATENT_DOWNSAMPLE_FACTOR = 4
class LTX2AudioCausalConv2d(nn.Module):
"""
A causal 2D convolution that pads asymmetrically along the causal axis.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
if self.causality_axis == "none":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis in {"width", "width-compatibility"}:
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis == "height":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
else:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
self.padding = padding
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding)
return self.conv(x)
class LTX2AudioPixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
class LTX2AudioAttnBlock(nn.Module):
def __init__(
self,
in_channels: int,
norm_type: str = "group",
) -> None:
super().__init__()
self.in_channels = in_channels
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = self.norm(x)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
batch, channels, height, width = q.shape
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
k = k.reshape(batch, channels, height * width).contiguous()
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
attn = torch.nn.functional.softmax(attn, dim=2)
v = v.reshape(batch, channels, height * width)
attn = attn.permute(0, 2, 1).contiguous()
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
h_ = self.proj_out(h_)
return x + h_
class LTX2AudioResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
norm_type: str = "group",
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if norm_type == "group":
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.non_linearity = nn.SiLU()
if causality_axis is not None:
self.conv1 = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if norm_type == "group":
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.dropout = nn.Dropout(dropout)
if causality_axis is not None:
self.conv2 = LTX2AudioCausalConv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
if causality_axis is not None:
self.conv_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
if causality_axis is not None:
self.nin_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
h = self.norm1(x)
h = self.non_linearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
h = self.norm2(h)
h = self.non_linearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
return x + h
class LTX2AudioDownsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.with_conv:
# Padding tuple is in the order: (left, right, top, bottom).
if self.causality_axis == "none":
pad = (0, 1, 0, 1)
elif self.causality_axis == "width":
pad = (2, 0, 0, 1)
elif self.causality_axis == "height":
pad = (0, 1, 2, 0)
elif self.causality_axis == "width-compatibility":
pad = (1, 0, 0, 1)
else:
raise ValueError(
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
f" and `width-compatibility`."
)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
# with_conv=False implies that causality_axis is "none"
x = F.avg_pool2d(x, kernel_size=2, stride=2)
return x
class LTX2AudioUpsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
if causality_axis is not None:
self.conv = LTX2AudioCausalConv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
if self.causality_axis is None or self.causality_axis == "none":
pass
elif self.causality_axis == "height":
x = x[:, :, 1:, :]
elif self.causality_axis == "width":
x = x[:, :, :, 1:]
elif self.causality_axis == "width-compatibility":
pass
else:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
class LTX2AudioAudioPatchifier:
"""
Patchifier for spectrogram/audio latents.
"""
def __init__(
self,
patch_size: int,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
is_causal: bool = True,
):
self.hop_length = hop_length
self.sample_rate = sample_rate
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.is_causal = is_causal
self._patch_size = (1, patch_size, patch_size)
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
batch, channels, time, freq = audio_latents.shape
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
batch, time, _ = audio_latents.shape
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
@property
def patch_size(self) -> Tuple[int, int, int]:
return self._patch_size
class LTX2AudioEncoder(nn.Module):
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
):
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels
base_resolution = resolution
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.down = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution
for level in range(self.num_resolutions):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != self.num_resolutions - 1:
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
curr_res = curr_res // 2
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
final_block_channels = block_in
z_channels = 2 * latent_channels if double_z else latent_channels
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
self.non_linearity = nn.SiLU()
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
hidden_states = self.conv_in(hidden_states)
for level in range(self.num_resolutions):
stage = self.down[level]
for block_idx, block in enumerate(stage.block):
hidden_states = block(hidden_states, temb=None)
if stage.attn:
hidden_states = stage.attn[block_idx](hidden_states)
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
hidden_states = stage.downsample(hidden_states)
hidden_states = self.mid.block_1(hidden_states, temb=None)
hidden_states = self.mid.attn_1(hidden_states)
hidden_states = self.mid.block_2(hidden_states, temb=None)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class LTX2AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
convolutions.
"""
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.patchifier = LTX2AudioAudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
)
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.non_linearity = nn.SiLU()
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
self.up = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
for level in reversed(range(self.num_resolutions)):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks + 1):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != 0:
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
curr_res *= 2
self.up.insert(0, stage)
final_block_channels = block_in
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
def forward(
self,
sample: torch.Tensor,
) -> torch.Tensor:
_, _, frames, mel_bins = sample.shape
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis is not None:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_channels = self.out_ch
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
hidden_features = self.conv_in(sample)
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx, block in enumerate(stage.block):
hidden_features = block(hidden_features, temb=None)
if stage.attn:
hidden_features = stage.attn[block_idx](hidden_features)
if level != 0 and hasattr(stage, "upsample"):
hidden_features = stage.upsample(hidden_features)
if self.give_pre_end:
return hidden_features
hidden = self.norm_out(hidden_features)
hidden = self.non_linearity(hidden)
decoded_output = self.conv_out(hidden)
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
_, _, current_time, current_freq = decoded_output.shape
target_time = target_frames
target_freq = target_mel_bins
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
if time_padding_needed > 0 or freq_padding_needed > 0:
padding = (
0,
max(freq_padding_needed, 0),
0,
max(time_padding_needed, 0),
)
decoded_output = F.pad(decoded_output, padding)
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
LTX2 audio VAE for encoding and decoding audio latent representations.
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_channels: int = 128,
output_channels: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
norm_type: str = "pixel",
causality_axis: Optional[str] = "height",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
) -> None:
super().__init__()
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
if causality_axis not in supported_causality_axes:
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
self.encoder = LTX2AudioEncoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
double_z=double_z,
)
self.decoder = LTX2AudioDecoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
)
# Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over
# the entire dataset and stored in model's checkpoint under AudioVAE state_dict
latents_std = torch.zeros((base_channels,))
latents_mean = torch.ones((base_channels,))
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
# TODO: calculate programmatically instead of hardcoding
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
# TODO: confirm whether the mel compression ratio below is correct
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
return self.encoder(x)
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True):
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
if not return_dict:
return (dec.sample,)
return dec

View File

@@ -1658,37 +1658,6 @@ class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
return conditioning
class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
super().__init__()
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
def forward(
self,
timestep: torch.Tensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
hidden_dtype: torch.dtype,
) -> torch.Tensor:
timesteps_proj = self.time_proj(timestep)
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
# (B, 2 * condition_dim)
condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
conditioning = timesteps_emb + condition_emb
return conditioning
class HunyuanDiTAttentionPool(nn.Module):
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

View File

@@ -27,7 +27,6 @@ if is_torch_available():
from .transformer_easyanimate import EasyAnimateTransformer3DModel
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
@@ -36,6 +35,7 @@ if is_torch_available():
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel

View File

@@ -1,567 +0,0 @@
# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import LayerNorm, RMSNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class GlmImageImageProjector(nn.Module):
def __init__(
self,
in_channels: int = 16,
hidden_size: int = 2560,
patch_size: int = 2,
):
super().__init__()
self.patch_size = patch_size
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, channel, height, width = hidden_states.shape
post_patch_height = height // self.patch_size
post_patch_width = width // self.patch_size
hidden_states = hidden_states.reshape(
batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
)
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
hidden_states = self.proj(hidden_states)
return hidden_states
class GlmImageAdaLayerNormZero(nn.Module):
def __init__(self, embedding_dim: int, dim: int) -> None:
super().__init__()
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
def forward(
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = hidden_states.dtype
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
emb = self.linear(temb)
(
shift_msa,
c_shift_msa,
scale_msa,
c_scale_msa,
gate_msa,
c_gate_msa,
shift_mlp,
c_shift_mlp,
scale_mlp,
c_scale_mlp,
gate_mlp,
c_gate_mlp,
) = emb.chunk(12, dim=1)
hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
return (
hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
)
class GlmImageLayerKVCache:
"""KV cache for GlmImage model."""
def __init__(self):
self.k_cache = None
self.v_cache = None
self.mode: Optional[str] = None # "write", "read", "skip"
def store(self, k: torch.Tensor, v: torch.Tensor):
if self.k_cache is None:
self.k_cache = k
self.v_cache = v
else:
self.k_cache = torch.cat([self.k_cache, k], dim=2)
self.v_cache = torch.cat([self.v_cache, v], dim=2)
def get(self):
return self.k_cache, self.v_cache
def clear(self):
self.k_cache = None
self.v_cache = None
self.mode = None
class GlmImageKVCache:
"""Container for all layers' KV caches."""
def __init__(self, num_layers: int):
self.num_layers = num_layers
self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
return self.caches[layer_idx]
def set_mode(self, mode: Optional[str]):
if mode is not None and mode not in ["write", "read", "skip"]:
raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
for cache in self.caches:
cache.mode = mode
def clear(self):
for cache in self.caches:
cache.clear()
class GlmImageAttnProcessor:
"""
Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
query and key vectors, but does not include spatial normalization.
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
"""
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
kv_cache: Optional[GlmImageLayerKVCache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
dtype = encoder_hidden_states.dtype
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
batch_size, image_seq_length, embed_dim = hidden_states.shape
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
# 1. QKV projections
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
# 2. QK normalization
if attn.norm_q is not None:
query = attn.norm_q(query).to(dtype=dtype)
if attn.norm_k is not None:
key = attn.norm_k(key).to(dtype=dtype)
# 3. Rotational positional embeddings applied to latent stream
if image_rotary_emb is not None:
from ..embeddings import apply_rotary_emb
query[:, :, text_seq_length:, :] = apply_rotary_emb(
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
key[:, :, text_seq_length:, :] = apply_rotary_emb(
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
)
if kv_cache is not None:
if kv_cache.mode == "write":
kv_cache.store(key, value)
elif kv_cache.mode == "read":
k_cache, v_cache = kv_cache.get()
key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key
value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value
elif kv_cache.mode == "skip":
pass
# 4. Attention
if attention_mask is not None:
text_attn_mask = attention_mask
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
text_attn_mask = text_attn_mask.float().to(query.device)
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
mix_attn_mask[:, :text_seq_length] = text_attn_mask
mix_attn_mask = mix_attn_mask.unsqueeze(2)
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)
# 5. Output projection
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
encoder_hidden_states, hidden_states = hidden_states.split(
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
)
return hidden_states, encoder_hidden_states
@maybe_allow_in_graph
class GlmImageTransformerBlock(nn.Module):
def __init__(
self,
dim: int = 2560,
num_attention_heads: int = 64,
attention_head_dim: int = 40,
time_embed_dim: int = 512,
) -> None:
super().__init__()
# 1. Attention
self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
out_dim=dim,
bias=True,
qk_norm="layer_norm",
elementwise_affine=False,
eps=1e-5,
processor=GlmImageAttnProcessor(),
)
# 2. Feedforward
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
kv_cache: Optional[GlmImageLayerKVCache] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Timestep conditioning
(
norm_hidden_states,
gate_msa,
shift_mlp,
scale_mlp,
gate_mlp,
norm_encoder_hidden_states,
c_gate_msa,
c_shift_mlp,
c_scale_mlp,
c_gate_mlp,
) = self.norm1(hidden_states, encoder_hidden_states, temb)
# 2. Attention
if attention_kwargs is None:
attention_kwargs = {}
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
attention_mask=attention_mask,
kv_cache=kv_cache,
**attention_kwargs,
)
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
# 3. Feedforward
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
1 + c_scale_mlp.unsqueeze(1)
) + c_shift_mlp.unsqueeze(1)
ff_output = self.ff(norm_hidden_states)
ff_output_context = self.ff(norm_encoder_hidden_states)
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
return hidden_states, encoder_hidden_states
class GlmImageRotaryPosEmbed(nn.Module):
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
super().__init__()
self.dim = dim
self.patch_size = patch_size
self.theta = theta
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size, num_channels, height, width = hidden_states.shape
height, width = height // self.patch_size, width // self.patch_size
dim_h, dim_w = self.dim // 2, self.dim // 2
h_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
)
w_inv_freq = 1.0 / (
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
)
h_seq = torch.arange(height)
w_seq = torch.arange(width)
freqs_h = torch.outer(h_seq, h_inv_freq)
freqs_w = torch.outer(w_seq, w_inv_freq)
# Create position matrices for height and width
# [height, 1, dim//4] and [1, width, dim//4]
freqs_h = freqs_h.unsqueeze(1)
freqs_w = freqs_w.unsqueeze(0)
# Broadcast freqs_h and freqs_w to [height, width, dim//4]
freqs_h = freqs_h.expand(height, width, -1)
freqs_w = freqs_w.expand(height, width, -1)
# Concatenate along last dimension to get [height, width, dim//2]
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
freqs = freqs.reshape(height * width, -1)
return (freqs.cos(), freqs.sin())
class GlmImageAdaLayerNormContinuous(nn.Module):
"""
GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
Linear on conditioning embedding.
"""
def __init__(
self,
embedding_dim: int,
conditioning_embedding_dim: int,
elementwise_affine: bool = True,
eps: float = 1e-5,
bias: bool = True,
norm_type: str = "layer_norm",
):
super().__init__()
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
if norm_type == "layer_norm":
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
elif norm_type == "rms_norm":
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
else:
raise ValueError(f"unknown norm_type {norm_type}")
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
# *** NO SiLU here ***
emb = self.linear(conditioning_embedding.to(x.dtype))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
r"""
Args:
patch_size (`int`, defaults to `2`):
The size of the patches to use in the patch embedding layer.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
num_layers (`int`, defaults to `30`):
The number of layers of Transformer blocks to use.
attention_head_dim (`int`, defaults to `40`):
The number of channels in each head.
num_attention_heads (`int`, defaults to `64`):
The number of heads to use for multi-head attention.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_embed_dim (`int`, defaults to `1472`):
Input dimension of text embeddings from the text encoder.
time_embed_dim (`int`, defaults to `512`):
Output dimension of timestep embeddings.
condition_dim (`int`, defaults to `256`):
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
crop_coords).
pos_embed_max_size (`int`, defaults to `128`):
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
patch_size => 128 * 8 * 2 => 2048`.
sample_size (`int`, defaults to `128`):
The base resolution of input latents. If height/width is not provided during generation, this value is used
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
"""
_supports_gradient_checkpointing = True
_no_split_modules = [
"GlmImageTransformerBlock",
"GlmImageImageProjector",
"GlmImageImageProjector",
]
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
@register_to_config
def __init__(
self,
patch_size: int = 2,
in_channels: int = 16,
out_channels: int = 16,
num_layers: int = 30,
attention_head_dim: int = 40,
num_attention_heads: int = 64,
text_embed_dim: int = 1472,
time_embed_dim: int = 512,
condition_dim: int = 256,
prior_vq_quantizer_codebook_size: int = 16384,
):
super().__init__()
# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
# Each of these are sincos embeddings of shape 2 * condition_dim
pooled_projection_dim = 2 * 2 * condition_dim
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels
# 1. RoPE
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
# 2. Patch & Text-timestep embedding
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
embedding_dim=time_embed_dim,
condition_dim=condition_dim,
pooled_projection_dim=pooled_projection_dim,
timesteps_dim=time_embed_dim,
)
# 3. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
for _ in range(num_layers)
]
)
# 4. Output projection
self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
prior_token_id: torch.Tensor,
prior_token_drop: torch.Tensor,
timestep: torch.LongTensor,
target_size: torch.Tensor,
crop_coords: torch.Tensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
attention_mask: Optional[torch.Tensor] = None,
kv_caches: Optional[GlmImageKVCache] = None,
image_rotary_emb: Optional[
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
] = None,
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
batch_size, num_channels, height, width = hidden_states.shape
# 1. RoPE
if image_rotary_emb is None:
image_rotary_emb = self.rope(hidden_states)
# 2. Patch & Timestep embeddings
p = self.config.patch_size
post_patch_height = height // p
post_patch_width = width // p
hidden_states = self.image_projector(hidden_states)
encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
prior_embedding = self.prior_token_embedding(prior_token_id)
prior_embedding[prior_token_drop] *= 0.0
prior_hidden_states = self.prior_projector(prior_embedding)
hidden_states = hidden_states + prior_hidden_states
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
temb = F.silu(temb)
# 3. Transformer blocks
for idx, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
kv_caches[idx] if kv_caches is not None else None,
)
else:
hidden_states, encoder_hidden_states = block(
hidden_states,
encoder_hidden_states,
temb,
image_rotary_emb,
attention_mask,
attention_kwargs,
kv_cache=kv_caches[idx] if kv_caches is not None else None,
)
# 4. Output norm & projection
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)
# 5. Unpatchify
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

File diff suppressed because it is too large Load Diff

View File

@@ -168,6 +168,14 @@ class MellonParam:
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
)
@classmethod
def num_frames(cls, default: int = 81) -> "MellonParam":
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
@classmethod
def videos(cls) -> "MellonParam":
return cls(name="videos", label="Videos", type="video", display="output")
@classmethod
def vae(cls) -> "MellonParam":
"""

View File

@@ -290,6 +290,7 @@ else:
"LTXLatentUpsamplePipeline",
"LTXI2VLongMultiPromptPipeline",
]
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -737,6 +738,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline

View File

@@ -52,7 +52,6 @@ from .flux import (
FluxKontextPipeline,
FluxPipeline,
)
from .glm_image import GlmImagePipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -168,7 +167,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("chroma", ChromaPipeline),
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("glm_image", GlmImagePipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),

View File

@@ -1,882 +0,0 @@
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import re
from math import sqrt
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from transformers import ByT5Tokenizer, GlmImageForConditionalGeneration, GlmImageProcessor, T5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import VaeImageProcessor
from ...loaders import CogView4LoraLoaderMixin
from ...models import AutoencoderKL, GlmImageTransformer2DModel
from ...models.transformers.transformer_glm_image import GlmImageKVCache
from ...pipelines.pipeline_utils import DiffusionPipeline
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from .pipeline_output import GlmImagePipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers import GlmImagePipeline
>>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A photo of an astronaut riding a horse on mars<sop>36 24<eop>"
>>> image = pipe(prompt).images[0]
>>> image.save("output.png")
```
"""
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
base_shift: float = 0.25,
max_shift: float = 0.75,
) -> float:
m = (image_seq_len / base_seq_len) ** 0.5
mu = m * max_shift + base_shift
return mu
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
"""
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if timesteps is not None and sigmas is not None:
if not accepts_timesteps and not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is not None and sigmas is None:
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif timesteps is None and sigmas is not None:
if not accepts_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using GLM-Image.
This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion
transformer) model for image decoding.
Args:
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`T5EncoderModel`]):
Frozen text-encoder for glyph embeddings.
tokenizer (`PreTrainedTokenizer`):
Tokenizer for the text encoder.
processor (`AutoProcessor`):
Processor for the AR model to handle chat templates and tokenization.
vision_language_encoder ([`GlmImageForConditionalGeneration`]):
The AR model that generates image tokens from text prompts.
transformer ([`GlmImageTransformer2DModel`]):
A text conditioned transformer to denoise the encoded image latents (DiT).
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
_optional_components = []
model_cpu_offload_seq = "transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
tokenizer: ByT5Tokenizer,
processor: GlmImageProcessor,
text_encoder: T5EncoderModel,
vision_language_encoder: GlmImageForConditionalGeneration,
vae: AutoencoderKL,
transformer: GlmImageTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
super().__init__()
self.register_modules(
tokenizer=tokenizer,
processor=processor,
text_encoder=text_encoder,
vision_language_encoder=vision_language_encoder,
vae=vae,
transformer=transformer,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.default_sample_size = (
self.transformer.config.sample_size
if hasattr(self, "transformer")
and self.transformer is not None
and hasattr(self.transformer.config, "sample_size")
else 128
)
def _build_image_grid_thw(
self,
token_h: int,
token_w: int,
prev_token_h: int,
prev_token_w: int,
existing_grid: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
) -> torch.Tensor:
if existing_grid is None or existing_grid.numel() == 0:
return torch.tensor(
[
[1, token_h, token_w],
[1, prev_token_h, prev_token_w],
],
device=device,
)
else:
return torch.cat([existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], dim=0)
def _calculate_ar_generation_params(
self, token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool
) -> Tuple[int, int]:
"""
Calculate max_new_tokens and large_image_start_offset for AR generation.
"""
large_image_tokens = token_h * token_w
small_image_tokens = prev_token_h * prev_token_w
if is_text_to_image:
max_new_tokens = small_image_tokens + large_image_tokens + 1
large_image_start_offset = small_image_tokens
else:
max_new_tokens = large_image_tokens + 1
large_image_start_offset = 0
return max_new_tokens, large_image_start_offset
def _extract_large_image_tokens(
self, outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int
) -> torch.Tensor:
generated_tokens = outputs[0][input_length:]
large_image_start = large_image_start_offset
large_image_end = large_image_start + large_image_tokens
return generated_tokens[large_image_start:large_image_end]
def _upsample_d32_to_d16(self, token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
"""
Upsample token IDs from d32 format to d16 format.
AR model generates tokens at d32 resolution (each token = 32x32 pixels). DiT expects tokens at d16 resolution
(each token = 16x16 pixels). This function performs 2x nearest-neighbor upsampling.
Args:
token_ids: Token IDs of shape [N] where N = token_h * token_w
token_h: Height in d32 token units
token_w: Width in d32 token units
Returns:
Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2)
"""
token_ids = token_ids.view(1, 1, token_h, token_w)
token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
dtype=torch.long
)
token_ids = token_ids.view(1, -1)
return token_ids
def _build_prompt_with_shape(
self,
prompt: str,
height: int,
width: int,
is_text_to_image: bool,
factor: int = 32,
) -> Tuple[str, int, int, int, int]:
"""
Build prompt with shape info (<sop>H W<eop>) based on height and width.
Args:
prompt: The raw text prompt without shape info
height: Target image height in pixels
width: Target image width in pixels
is_text_to_image: Whether this is text-to-image (True) or image-to-image (False)
Returns:
Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w)
"""
token_h = height // factor
token_w = width // factor
ratio = token_h / token_w
prev_token_h = int(sqrt(ratio) * (factor // 2))
prev_token_w = int(sqrt(1 / ratio) * (factor // 2))
if is_text_to_image:
expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop><sop>{prev_token_h} {prev_token_w}<eop>"
else:
expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop>"
return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w
def generate_prior_tokens(
self,
prompt: str,
height: int,
width: int,
image: Optional[List[PIL.Image.Image]] = None,
factor: int = 32,
) -> Tuple[torch.Tensor, int, int]:
"""
Generate prior tokens using the AR (vision_language_encoder) model.
Automatically builds the prompt with shape info based on height/width. Users only need to provide the raw text
prompt without <sop>...<eop> tags.
Args:
prompt: The raw text prompt (without shape info)
height: Target image height in pixels (must be divisible by factor)
width: Target image width in pixels (must be divisible by factor)
image: Optional list of condition images for image-to-image generation
factor: Token size factor (32 for d32 tokens)
Returns:
Tuple of (prior_token_ids, pixel_height, pixel_width)
- prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4]
- pixel_height: Image height in pixels (aligned to factor)
- pixel_width: Image width in pixels (aligned to factor)
"""
device = self.vision_language_encoder.device
height = (height // factor) * factor
width = (width // factor) * factor
is_text_to_image = image is None or len(image) == 0
expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape(
prompt, height, width, is_text_to_image
)
content = []
if image is not None:
for img in image:
content.append({"type": "image", "image": img})
content.append({"type": "text", "text": expanded_prompt})
messages = [{"role": "user", "content": content}]
inputs = self.processor.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
)
existing_grid = inputs.get("image_grid_thw")
inputs["image_grid_thw"] = self._build_image_grid_thw(
token_h,
token_w,
prev_h,
prev_w,
existing_grid=existing_grid if not is_text_to_image else None,
device=device,
)
max_new_tokens, large_image_offset = self._calculate_ar_generation_params(
token_h, token_w, prev_h, prev_w, is_text_to_image
)
large_image_tokens = token_h * token_w
inputs = inputs.to(device)
input_length = inputs["input_ids"].shape[-1]
outputs = self.vision_language_encoder.generate(
**inputs,
max_new_tokens=max_new_tokens,
do_sample=True,
)
prior_token_ids_d32 = self._extract_large_image_tokens(
outputs, input_length, large_image_offset, large_image_tokens
)
prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w)
pixel_height = token_h * factor
pixel_width = token_w * factor
return prior_token_ids, pixel_height, pixel_width
def get_glyph_texts(self, prompt):
prompt = prompt[0] if isinstance(prompt, list) else prompt
ocr_texts = (
re.findall(r"'([^']*)'", prompt)
+ re.findall(r"“([^“”]*)”", prompt)
+ re.findall(r'"([^"]*)"', prompt)
+ re.findall(r"「([^「」]*)」", prompt)
)
return ocr_texts
def _get_glyph_embeds(
self,
prompt: Union[str, List[str]] = None,
max_sequence_length: int = 2048,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
):
device = device or self._execution_device
dtype = dtype or self.text_encoder.dtype
glyph_texts = self.get_glyph_texts(prompt)
input_ids = self.tokenizer(
glyph_texts if len(glyph_texts) > 0 else [""],
max_length=max_sequence_length,
truncation=True,
).input_ids
input_ids = [
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
]
max_length = max(len(input_ids_) for input_ids_ in input_ids)
attention_mask = torch.tensor(
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
)
input_ids = torch.tensor(
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
device=device,
)
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
return glyph_embeds.to(device=device, dtype=dtype)
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.Tensor] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
max_sequence_length: int = 2048,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
Whether to use classifier free guidance or not.
num_images_per_prompt (`int`, *optional*, defaults to 1):
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
device: (`torch.device`, *optional*):
torch device
dtype: (`torch.dtype`, *optional*):
torch dtype
max_sequence_length (`int`, defaults to `2048`):
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
"""
device = device or self._execution_device
prompt = [prompt] if isinstance(prompt, str) else prompt
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
if prompt_embeds is None:
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
seq_len = prompt_embeds.size(1)
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_embeds = None
if do_classifier_free_guidance:
negative_prompt = ""
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
seq_len = negative_prompt_embeds.size(1)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
return prompt_embeds, negative_prompt_embeds
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device)
shape = (
batch_size,
num_channels_latents,
int(height) // self.vae_scale_factor,
int(width) // self.vae_scale_factor,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def check_inputs(
self,
prompt,
height,
width,
callback_on_step_end_tensor_inputs,
prompt_embeds=None,
):
if (
height is not None
and height % (self.vae_scale_factor * self.transformer.config.patch_size) != 0
or width is not None
and width % (self.transformer.config.patch_size) != 0
):
logger.warning(
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def num_timesteps(self):
return self._num_timesteps
@property
def attention_kwargs(self):
return self._attention_kwargs
@property
def current_timestep(self):
return self._current_timestep
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Optional[Union[str, List[str]]] = None,
image: Optional[
Union[
torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
]
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.5,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
crops_coords_top_left: Tuple[int, int] = (0, 0),
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 2048,
) -> Union[GlmImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. Must contain shape info in the format '<sop>H
W<eop>' where H and W are token dimensions (d32). Example: "A beautiful sunset<sop>36 24<eop>"
generates a 1152x768 image.
image: Optional condition images for image-to-image generation.
height (`int`, *optional*):
The height in pixels. If not provided, derived from prompt shape info.
width (`int`, *optional*):
The width in pixels. If not provided, derived from prompt shape info.
num_inference_steps (`int`, *optional*, defaults to `50`):
The number of denoising steps for DiT.
guidance_scale (`float`, *optional*, defaults to `1.5`):
Guidance scale for classifier-free guidance.
num_images_per_prompt (`int`, *optional*, defaults to `1`):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
Random generator for reproducibility.
output_type (`str`, *optional*, defaults to `"pil"`):
Output format: "pil", "np", or "latent".
Examples:
Returns:
[`GlmImagePipelineOutput`] or `tuple`: Generated images.
"""
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
# 1. Check inputs
self.check_inputs(
prompt,
height,
width,
callback_on_step_end_tensor_inputs,
prompt_embeds,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._current_timestep = None
self._interrupt = False
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
assert batch_size == 1, "batch_size must be 1"
device = self._execution_device
ar_condition_images = None
if image is not None:
if not isinstance(image, list):
image = [image]
ar_condition_images = []
for img in image:
if isinstance(img, PIL.Image.Image):
ar_condition_images.append(img)
elif isinstance(img, torch.Tensor):
img_np = img.cpu().numpy()
if img_np.ndim == 4:
img_np = img_np[0]
if img_np.shape[0] in [1, 3, 4]:
img_np = img_np.transpose(1, 2, 0)
if img_np.max() <= 1.0:
img_np = (img_np * 255).astype(np.uint8)
ar_condition_images.append(PIL.Image.fromarray(img_np))
else:
ar_condition_images.append(PIL.Image.fromarray(img))
prior_token_id, ar_height, ar_width = self.generate_prior_tokens(
prompt=prompt[0] if isinstance(prompt, list) else prompt,
image=ar_condition_images,
height=height,
width=width,
)
height = height or ar_height
width = width or ar_width
# 3. Encode input prompt
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
self.do_classifier_free_guidance,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
device=device,
dtype=self.dtype,
)
# 4. process images
condition_images_prior_token_id = None
if image is not None:
preprocessed_condition_images = []
condition_images_prior_token_id = []
for img in image:
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
image_height = (image_height // multiple_of) * multiple_of
image_width = (image_width // multiple_of) * multiple_of
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
preprocessed_condition_images.append(img)
image = preprocessed_condition_images
# 5. Prepare latents and (optional) image kv cache
latent_channels = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size=batch_size * num_images_per_prompt,
num_channels_latents=latent_channels,
height=height,
width=width,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
if image is not None and condition_images_prior_token_id is not None:
kv_caches.set_mode("write")
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(self.vae.device, self.vae.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(self.vae.device, self.vae.dtype)
)
empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id):
condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
condition_latent = retrieve_latents(
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
)
condition_latent = (condition_latent - latents_mean) / latents_std
_ = self.transformer(
hidden_states=condition_latent,
encoder_hidden_states=empty_glyph_hiddens,
prior_token_id=condition_image_prior_token_id,
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
timestep=torch.zeros((1,), device=device),
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
crop_coords=torch.zeros((1, 2), device=device),
attention_kwargs=attention_kwargs,
kv_caches=kv_caches,
)
# 6. Prepare additional timestep conditions
target_size = (height, width)
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
# Prepare timesteps
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
self.transformer.config.patch_size**2
)
timesteps = (
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
if timesteps is None
else np.array(timesteps)
)
timesteps = timesteps.astype(np.int64).astype(np.float32)
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("base_shift", 0.25),
self.scheduler.config.get("max_shift", 0.75),
)
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
)
self._num_timesteps = len(timesteps)
# 7. Denoising loop
transformer_dtype = self.transformer.dtype
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool)
prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0]) - 1
if image is not None:
kv_caches.set_mode("read")
noise_pred_cond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_cond,
timestep=timestep,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
kv_caches=kv_caches,
)[0].float()
# perform guidance
if self.do_classifier_free_guidance:
if image is not None:
kv_caches.set_mode("skip")
noise_pred_uncond = self.transformer(
hidden_states=latent_model_input,
encoder_hidden_states=negative_prompt_embeds,
prior_token_id=prior_token_id,
prior_token_drop=prior_token_drop_uncond,
timestep=timestep,
target_size=target_size,
crop_coords=crops_coords_top_left,
attention_kwargs=attention_kwargs,
return_dict=False,
kv_caches=kv_caches,
)[0].float()
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
else:
noise_pred = noise_pred_cond
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
kv_caches.clear()
if not output_type == "latent":
latents = latents.to(self.vae.dtype)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = (
torch.tensor(self.vae.config.latents_std)
.view(1, self.vae.config.latent_channels, 1, 1)
.to(latents.device, latents.dtype)
)
latents = latents * latents_std + latents_mean
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
else:
image = latents
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return GlmImagePipelineOutput(images=image)

View File

@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class GlmImagePipelineOutput(BaseOutput):
"""
Output class for CogView3 pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -11,8 +11,8 @@ from ...utils import (
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
@@ -22,15 +22,28 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_glm_image import GlmImagePipeline
from .connectors import LTX2TextConnectors
from .latent_upsampler import LTX2LatentUpsamplerModel
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder
else:
import sys
@@ -43,5 +56,3 @@ else:
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,325 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
"""
def __init__(
self,
dim: int,
base_seq_len: int = 4096,
theta: float = 10000.0,
double_precision: bool = True,
rope_type: str = "interleaved",
num_attention_heads: int = 32,
):
super().__init__()
if rope_type not in ["interleaved", "split"]:
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
self.dim = dim
self.base_seq_len = base_seq_len
self.theta = theta
self.double_precision = double_precision
self.rope_type = rope_type
self.num_attention_heads = num_attention_heads
def forward(
self,
batch_size: int,
pos: int,
device: Union[str, torch.device],
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Get 1D position ids
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
# Get fractional indices relative to self.base_seq_len
grid_1d = grid_1d / self.base_seq_len
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
# 2. Calculate 1D RoPE frequencies
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
freqs_dtype = torch.float64 if self.double_precision else torch.float32
pow_indices = torch.pow(
self.theta,
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
)
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
# (self.dim // 2,).
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
if self.rope_type == "interleaved":
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % num_rope_elems != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
elif self.rope_type == "split":
expected_freqs = self.dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq = freqs.cos()
sin_freq = freqs.sin()
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape freqs to be compatible with multi-head attention
b = cos_freq.shape[0]
t = cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
return cos_freqs, sin_freqs
class LTX2TransformerBlock1d(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
activation_fn: str = "gelu-approximate",
eps: float = 1e-6,
rope_type: str = "interleaved",
):
super().__init__()
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
rope_type=rope_type,
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(dim, activation_fn=activation_fn)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_hidden_states = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_hidden_states
return hidden_states
class LTX2ConnectorTransformer1d(nn.Module):
"""
A 1D sequence transformer for modalities such as text.
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: int | None = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
rope_type: str = "interleaved",
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.num_learnable_registers = num_learnable_registers
self.learnable_registers = None
if num_learnable_registers is not None:
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
self.learnable_registers = torch.nn.Parameter(init_registers)
self.rope = LTX2RotaryPosEmbed1d(
self.inner_dim,
base_seq_len=rope_base_seq_len,
theta=rope_theta,
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads,
)
self.transformer_blocks = torch.nn.ModuleList(
[
LTX2TransformerBlock1d(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
rope_type=rope_type,
)
for _ in range(num_layers)
]
)
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attn_mask_binarize_threshold: float = -9000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
# hidden_states shape: [batch_size, seq_len, hidden_dim]
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
batch_size, seq_len, _ = hidden_states.shape
# 1. Replace padding with learned registers, if using
if self.learnable_registers is not None:
if seq_len % self.num_learnable_registers != 0:
raise ValueError(
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
f" of learnable registers {self.num_learnable_registers}"
)
num_register_repeats = seq_len // self.num_learnable_registers
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
padded_hidden_states = [
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
]
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
# 2. Calculate 1D RoPE positional embeddings
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
# 3. Run 1D transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
else:
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
hidden_states = self.norm_out(hidden_states)
return hidden_states, attention_mask
class LTX2TextConnectors(ModelMixin, ConfigMixin):
"""
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
streams.
"""
@register_to_config
def __init__(
self,
caption_channels: int,
text_proj_in_factor: int,
video_connector_num_attention_heads: int,
video_connector_attention_head_dim: int,
video_connector_num_layers: int,
video_connector_num_learnable_registers: int | None,
audio_connector_num_attention_heads: int,
audio_connector_attention_head_dim: int,
audio_connector_num_layers: int,
audio_connector_num_learnable_registers: int | None,
connector_rope_base_seq_len: int,
rope_theta: float,
rope_double_precision: bool,
causal_temporal_positioning: bool,
rope_type: str = "interleaved",
):
super().__init__()
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
self.video_connector = LTX2ConnectorTransformer1d(
num_attention_heads=video_connector_num_attention_heads,
attention_head_dim=video_connector_attention_head_dim,
num_layers=video_connector_num_layers,
num_learnable_registers=video_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
attention_head_dim=audio_connector_attention_head_dim,
num_layers=audio_connector_num_layers,
num_learnable_registers=audio_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
)
def forward(
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
):
# Convert to additive attention mask, if necessary
if not additive_mask:
text_dtype = text_encoder_hidden_states.dtype
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * attn_mask
new_attn_mask = attn_mask.squeeze(-1)
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
return video_text_embedding, audio_text_embedding, new_attn_mask

View File

@@ -0,0 +1,134 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fractions import Fraction
from typing import Optional
import torch
from ...utils import is_av_available
_CAN_USE_AV = is_av_available()
if _CAN_USE_AV:
import av
else:
raise ImportError(
"PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container,
audio_stream: av.audio.AudioStream,
samples: torch.Tensor,
audio_sample_rate: int,
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def encode_video(
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
video_np = video.cpu().numpy()
_, height, width, _ = video_np.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()

View File

@@ -0,0 +1,285 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
RATIONAL_RESAMPLER_SCALE_MAPPING = {
0.75: (3, 4),
1.5: (3, 2),
2.0: (2, 1),
4.0: (4, 1),
}
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock
class ResBlock(torch.nn.Module):
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.activation(hidden_states + residual)
return hidden_states
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND
class PixelShuffleND(torch.nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
self.dims = dims
self.upscale_factors = upscale_factors
if dims not in [1, 2, 3]:
raise ValueError("dims must be 1, 2, or 3")
def forward(self, x):
if self.dims == 3:
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
return (
x.unflatten(1, (-1, *self.upscale_factors[:3]))
.permute(0, 1, 5, 2, 6, 3, 7, 4)
.flatten(6, 7)
.flatten(4, 5)
.flatten(2, 3)
)
elif self.dims == 2:
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
return (
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
)
elif self.dims == 1:
# temporal: b (c p1) f h w -> b c (f p1) h w
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
class BlurDownsample(torch.nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W.
Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
super().__init__()
if dims not in (2, 3):
raise ValueError(f"`dims` must be either 2 or 3 but is {dims}")
if kernel_size < 3 or kernel_size % 2 != 1:
raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}")
self.dims = dims
self.stride = stride
self.kernel_size = kernel_size
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
# The 2D kernel is constructed as the outer product and normalized.
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
if self.dims == 2:
c = x.shape[1]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
else:
# dims == 3: apply per-frame on H,W
b, c, f, _, _ = x.shape
x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
h2, w2 = x.shape[-2:]
x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W]
return x
class SpatialRationalResampler(torch.nn.Module):
"""
Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample
by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the
input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the
(integer) denominator.
"""
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
super().__init__()
self.scale = float(scale)
num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None)
if num_denom is None:
raise ValueError(
f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}"
)
self.num, self.den = num_denom
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expected x shape: [B * F, C, H, W]
# b, _, f, h, w = x.shape
# x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
# x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W]
return x
class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`, defaults to `128`):
Number of channels in the input latent
mid_channels (`int`, defaults to `512`):
Number of channels in the middle layers
num_blocks_per_stage (`int`, defaults to `4`):
Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`, defaults to `3`):
Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`, defaults to `True`):
Whether to spatially upsample the latent
temporal_upsample (`bool`, defaults to `False`):
Whether to temporally upsample the latent
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
rational_spatial_scale: Optional[float] = 2.0,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_spatial_scale is not None:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
else:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.dims == 2:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.upsampler(hidden_states)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
else:
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
if self.temporal_upsample:
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states[:, :, 1:, :, :]
else:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
return hidden_states

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,442 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLLTX2Video
from ...utils import get_logger, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..ltx.pipeline_output import LTXPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .latent_upsampler import LTX2LatentUpsamplerModel
logger = get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
>>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
>>> from diffusers.utils import load_image
>>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
>>> pipe.enable_model_cpu_offload()
>>> image = load_image(
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
... )
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> frame_rate = 24.0
>>> video, audio = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=768,
... height=512,
... num_frames=121,
... frame_rate=frame_rate,
... num_inference_steps=40,
... guidance_scale=4.0,
... output_type="pil",
... return_dict=False,
... )
>>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16
... )
>>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
>>> upsample_pipe.vae.enable_tiling()
>>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16)
>>> video = upsample_pipe(
... video=video,
... width=768,
... height=512,
... output_type="np",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],
... fps=frame_rate,
... audio=audio[0].float().cpu(),
... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
... output_path="video.mp4",
... )
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTX2LatentUpsamplePipeline(DiffusionPipeline):
model_cpu_offload_seq = "vae->latent_upsampler"
def __init__(
self,
vae: AutoencoderKLLTX2Video,
latent_upsampler: LTX2LatentUpsamplerModel,
) -> None:
super().__init__()
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_frames: int = 121,
height: int = 512,
width: int = 768,
spatial_patch_size: int = 1,
temporal_patch_size: int = 1,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
if latents.ndim == 3:
# Convert token seq [B, S, D] to latent video [B, C, F, H, W]
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
latents = self._unpack_latents(
latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
)
return latents.to(device=device, dtype=dtype)
video = video.to(device=device, dtype=self.vae.dtype)
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
# NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
# init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
"""
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
tensor.
Args:
latent (`torch.Tensor`):
Input latents to normalize
reference_latents (`torch.Tensor`):
The reference latents providing style statistics.
factor (`float`):
Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
Returns:
torch.Tensor: The transformed latent tensor
"""
result = latents.clone()
for i in range(latents.size(0)):
for c in range(latents.size(1)):
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
result = torch.lerp(latents, result, factor)
return result
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
"""
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
smooth way using a sigmoid-based compression.
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
when controlling dynamic behavior with a `compression` factor.
Args:
latents : torch.Tensor
Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
compression : float
Compression strength in the range [0, 1].
- 0.0: No tone-mapping (identity transform)
- 1.0: Full compression effect
Returns:
torch.Tensor
The tone-mapped latent tensor of the same shape as input.
"""
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
scale_factor = compression * 0.75
abs_latents = torch.abs(latents)
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
filtered = latents * scales
return filtered
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` can be provided.")
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
if not (0 <= tone_map_compression_ratio <= 1):
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
video: Optional[List[PipelineImageInput]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 121,
spatial_patch_size: int = 1,
temporal_patch_size: int = 1,
latents: Optional[torch.Tensor] = None,
latents_normalized: bool = False,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
adain_factor: float = 0.0,
tone_map_compression_ratio: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
video (`List[PipelineImageInput]`, *optional*)
The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be
supplied.
height (`int`, *optional*, defaults to `512`):
The height in pixels of the input video (not the generated video, which will have a larger resolution).
width (`int`, *optional*, defaults to `768`):
The width in pixels of the input video (not the generated video, which will have a larger resolution).
num_frames (`int`, *optional*, defaults to `121`):
The number of frames in the input video.
spatial_patch_size (`int`, *optional*, defaults to `1`):
The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
temporal_patch_size (`int`, *optional*, defaults to `1`):
The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
necessary.
latents (`torch.Tensor`, *optional*):
Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
latent_channels, latent_frames, latent_height, latent_width)`.
latents_normalized (`bool`, *optional*, defaults to `False`)
If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
`True`, the `latents` will be denormalized before being supplied to the latent upsampler.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
adain_factor (`float`, *optional*, defaults to `0.0`):
Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
This is useful for regularizing high-variance latents or for conditioning outputs during generation.
Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
the full compression effect.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is the upsampled video.
"""
self.check_inputs(
video=video,
height=height,
width=width,
latents=latents,
tone_map_compression_ratio=tone_map_compression_ratio,
)
if video is not None:
# Batched video input is not yet tested/supported. TODO: take a look later
batch_size = 1
else:
batch_size = latents.shape[0]
device = self._execution_device
if video is not None:
num_frames = len(video)
if num_frames % self.vae_temporal_compression_ratio != 1:
num_frames = (
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
)
video = video[:num_frames]
logger.warning(
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
)
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=torch.float32)
latents_supplied = latents is not None
latents = self.prepare_latents(
video=video,
batch_size=batch_size,
num_frames=num_frames,
height=height,
width=width,
spatial_patch_size=spatial_patch_size,
temporal_patch_size=temporal_patch_size,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
if latents_supplied and latents_normalized:
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.latent_upsampler.dtype)
latents_upsampled = self.latent_upsampler(latents)
if adain_factor > 0.0:
latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
else:
latents = latents_upsampled
if tone_map_compression_ratio > 0.0:
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
else:
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class LTX2PipelineOutput(BaseOutput):
r"""
Output class for LTX pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
audio (`torch.Tensor`, `np.ndarray`):
TODO
"""
frames: torch.Tensor
audio: torch.Tensor

View File

@@ -0,0 +1,159 @@
import math
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class ResBlock(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int = 3,
stride: int = 1,
dilations: Tuple[int, ...] = (1, 3, 5),
leaky_relu_negative_slope: float = 0.1,
padding_mode: str = "same",
):
super().__init__()
self.dilations = dilations
self.negative_slope = leaky_relu_negative_slope
self.convs1 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
for dilation in dilations
]
)
self.convs2 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
for _ in range(len(dilations))
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
xt = conv1(xt)
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
xt = conv2(xt)
x = x + xt
return x
class LTX2Vocoder(ModelMixin, ConfigMixin):
r"""
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
hidden_channels: int = 1024,
out_channels: int = 2,
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
upsample_factors: List[int] = [6, 5, 2, 2, 2],
resnet_kernel_sizes: List[int] = [3, 7, 11],
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
leaky_relu_negative_slope: float = 0.1,
output_sampling_rate: int = 24000,
):
super().__init__()
self.num_upsample_layers = len(upsample_kernel_sizes)
self.resnets_per_upsample = len(resnet_kernel_sizes)
self.out_channels = out_channels
self.total_upsample_factor = math.prod(upsample_factors)
self.negative_slope = leaky_relu_negative_slope
if self.num_upsample_layers != len(upsample_factors):
raise ValueError(
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
)
if self.resnets_per_upsample != len(resnet_dilations):
raise ValueError(
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
)
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
self.upsamplers = nn.ModuleList()
self.resnets = nn.ModuleList()
input_channels = hidden_channels
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
output_channels = input_channels // 2
self.upsamplers.append(
nn.ConvTranspose1d(
input_channels, # hidden_channels // (2 ** i)
output_channels, # hidden_channels // (2 ** (i + 1))
kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
)
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
self.resnets.append(
ResBlock(
output_channels,
kernel_size,
dilations=dilations,
leaky_relu_negative_slope=leaky_relu_negative_slope,
)
)
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
Forward pass of the vocoder.
Args:
hidden_states (`torch.Tensor`):
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
`True`.
time_last (`bool`, *optional*, defaults to `False`):
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
Returns:
`torch.Tensor`:
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
"""
# Ensure that the time/frame dimension is last
if not time_last:
hidden_states = hidden_states.transpose(2, 3)
# Combine channels and frequency (mel bins) dimensions
hidden_states = hidden_states.flatten(1, 2)
hidden_states = self.conv_in(hidden_states)
for i in range(self.num_upsample_layers):
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
hidden_states = self.upsamplers[i](hidden_states)
# Run all resnets in parallel on hidden_states
start = i * self.resnets_per_upsample
end = (i + 1) * self.resnets_per_upsample
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
hidden_states = torch.mean(resnet_outputs, dim=0)
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
hidden_states = self.conv_out(hidden_states)
hidden_states = torch.tanh(hidden_states)
return hidden_states

View File

@@ -109,7 +109,7 @@ LIBRARIES = []
for library in LOADABLE_CLASSES:
LIBRARIES.append(library)
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
logger = logging.get_logger(__name__)
@@ -462,8 +462,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
@@ -1164,7 +1163,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
"""
self._maybe_raise_error_if_group_offload_active(raise_error=True)
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
@@ -1286,7 +1285,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
self.remove_all_hooks()
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
if is_pipeline_device_mapped:
raise ValueError(
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
@@ -2171,6 +2170,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
return True
return False
def _is_pipeline_device_mapped(self):
# We support passing `device_map="cuda"`, for example. This is helpful, in case
# users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
# in limited VRAM environments because quantized models often initialize directly on the accelerator.
device_map = self.hf_device_map
is_device_type_map = False
if isinstance(device_map, str):
try:
torch.device(device_map)
is_device_type_map = True
except RuntimeError:
pass
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
class StableDiffusionMixin:
r"""

View File

@@ -143,7 +143,20 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the input sample by scaling it according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor to precondition.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The scaled input sample.
"""
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
@@ -155,7 +168,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigma.atan() / math.pi * 2
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
def precondition_outputs(
self,
sample: torch.Tensor,
model_output: torch.Tensor,
sigma: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Precondition the model outputs according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor.
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The denoised sample computed by combining the skip connection and output scaling.
"""
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
@@ -173,13 +206,13 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The input sample tensor.
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -242,8 +275,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.noise_sampler = None
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
def _compute_karras_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
Args:
ramp (`torch.Tensor`):
A tensor of values in [0, 1] representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed Karras sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -254,10 +306,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
def _compute_exponential_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
Args:
ramp (`torch.Tensor`):
A tensor of values representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -354,7 +423,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -540,7 +612,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
)
self.noise_sampler = BrownianTreeNoiseSampler(
model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
model_output,
sigma_min=self.config.sigma_min,
sigma_max=self.config.sigma_max,
seed=seed,
)
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
model_output.device
@@ -612,7 +687,18 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
"""
Compute the input conditioning factor for the EDM formulation.
Args:
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`float` or `torch.Tensor`:
The input conditioning factor `c_in`.
"""
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in

View File

@@ -175,13 +175,37 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = begin_index
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
def precondition_inputs(self, sample, sigma):
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the input sample by scaling it according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor to precondition.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The scaled input sample.
"""
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
def precondition_noise(self, sigma):
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the noise level by applying a logarithmic transformation.
Args:
sigma (`float` or `torch.Tensor`):
The sigma (noise level) value to precondition.
Returns:
`torch.Tensor`:
The preconditioned noise value computed as `0.25 * log(sigma)`.
"""
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
@@ -190,7 +214,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return c_noise
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
def precondition_outputs(self, sample, model_output, sigma):
def precondition_outputs(
self,
sample: torch.Tensor,
model_output: torch.Tensor,
sigma: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Precondition the model outputs according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor.
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The denoised sample computed by combining the skip connection and output scaling.
"""
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
@@ -208,13 +252,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The input sample tensor.
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -274,8 +318,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
def _compute_karras_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
Args:
ramp (`torch.Tensor`):
A tensor of values in [0, 1] representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed Karras sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -286,10 +349,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return sigmas
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
def _compute_exponential_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
Args:
ramp (`torch.Tensor`):
A tensor of values representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -433,7 +513,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
`torch.Tensor`:
The sample tensor at the previous timestep.
"""
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
sigma_t, sigma_s = (
self.sigmas[self.step_index + 1],
self.sigmas[self.step_index],
)
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
@@ -684,7 +767,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
if self.config.algorithm_type == "sde-dpmsolver++":
noise = randn_tensor(
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
model_output.shape,
generator=generator,
device=model_output.device,
dtype=model_output.dtype,
)
else:
noise = None
@@ -757,7 +843,18 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
return noisy_samples
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
def _get_conditioning_c_in(self, sigma):
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
"""
Compute the input conditioning factor for the EDM formulation.
Args:
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`float` or `torch.Tensor`:
The input conditioning factor `c_in`.
"""
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in

View File

@@ -14,7 +14,7 @@
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union
import torch
@@ -57,29 +57,28 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.
Args:
sigma_min (`float`, *optional*, defaults to 0.002):
sigma_min (`float`, *optional*, defaults to `0.002`):
Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
range is [0, 10].
sigma_max (`float`, *optional*, defaults to 80.0):
sigma_max (`float`, *optional*, defaults to `80.0`):
Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
range is [0.2, 80.0].
sigma_data (`float`, *optional*, defaults to 0.5):
sigma_data (`float`, *optional*, defaults to `0.5`):
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
sigma_schedule (`str`, *optional*, defaults to `karras`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`):
Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
(https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model:
https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, *optional*, defaults to `1000`):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `epsilon`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to 7.0):
prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`):
Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and
`"v_prediction"` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
rho (`float`, *optional*, defaults to `7.0`):
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
final_sigmas_type (`str`, defaults to `"zero"`):
final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
"""
_compatibles = []
@@ -91,12 +90,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
sigma_min: float = 0.002,
sigma_max: float = 80.0,
sigma_data: float = 0.5,
sigma_schedule: str = "karras",
sigma_schedule: Literal["karras", "exponential"] = "karras",
num_train_timesteps: int = 1000,
prediction_type: str = "epsilon",
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
rho: float = 7.0,
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
):
final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
) -> None:
if sigma_schedule not in ["karras", "exponential"]:
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
@@ -131,26 +130,41 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
def init_noise_sigma(self) -> float:
"""
Return the standard deviation of the initial noise distribution.
Returns:
`float`:
The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`.
"""
return (self.config.sigma_max**2 + 1) ** 0.5
@property
def step_index(self):
def step_index(self) -> Optional[int]:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
Return the index counter for the current timestep. The index will increase by 1 after each scheduler step.
Returns:
`int` or `None`:
The current step index, or `None` if not yet initialized.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index`
method.
Returns:
`int` or `None`:
The begin index, or `None` if not yet set.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -160,12 +174,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
"""
self._begin_index = begin_index
def precondition_inputs(self, sample, sigma):
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the input sample by scaling it according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor to precondition.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The scaled input sample.
"""
c_in = self._get_conditioning_c_in(sigma)
scaled_sample = sample * c_in
return scaled_sample
def precondition_noise(self, sigma):
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the noise level by applying a logarithmic transformation.
Args:
sigma (`float` or `torch.Tensor`):
The sigma (noise level) value to precondition.
Returns:
`torch.Tensor`:
The preconditioned noise value computed as `0.25 * log(sigma)`.
"""
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
@@ -173,7 +211,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
return c_noise
def precondition_outputs(self, sample, model_output, sigma):
def precondition_outputs(
self,
sample: torch.Tensor,
model_output: torch.Tensor,
sigma: Union[float, torch.Tensor],
) -> torch.Tensor:
"""
Precondition the model outputs according to the EDM formulation.
Args:
sample (`torch.Tensor`):
The input sample tensor.
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`torch.Tensor`:
The denoised sample computed by combining the skip connection and output scaling.
"""
sigma_data = self.config.sigma_data
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
@@ -190,13 +248,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
need to scale the denoising model input depending on the current timestep.
Args:
sample (`torch.Tensor`):
The input sample.
timestep (`int`, *optional*):
The input sample tensor.
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
Returns:
@@ -214,19 +272,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: int = None,
device: Union[str, torch.device] = None,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
):
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
sigmas (`torch.Tensor` or `List[float]`, *optional*):
Custom sigmas to use for the denoising process. If not defined, the default behavior when
`num_inference_steps` is passed will be used.
"""
@@ -262,8 +320,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Constructs the noise schedule of Karras et al. (2022)."""
def _compute_karras_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
Args:
ramp (`torch.Tensor`):
A tensor of values in [0, 1] representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed Karras sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -273,10 +350,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
return sigmas
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
"""Implementation closely follows k-diffusion.
def _compute_exponential_sigmas(
self,
ramp: torch.Tensor,
sigma_min: Optional[float] = None,
sigma_max: Optional[float] = None,
) -> torch.Tensor:
"""
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
Args:
ramp (`torch.Tensor`):
A tensor of values representing the interpolation positions.
sigma_min (`float`, *optional*):
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
sigma_max (`float`, *optional*):
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
Returns:
`torch.Tensor`:
The computed exponential sigma schedule.
"""
sigma_min = sigma_min or self.config.sigma_min
sigma_max = sigma_max or self.config.sigma_max
@@ -342,32 +436,38 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
pred_original_sample: Optional[torch.Tensor] = None,
) -> Union[EDMEulerSchedulerOutput, Tuple]:
) -> Union[EDMEulerSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
"""
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`float`):
The direct output from the learned diffusion model.
timestep (`float` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
s_churn (`float`):
s_tmin (`float`):
s_tmax (`float`):
s_noise (`float`, defaults to 1.0):
s_churn (`float`, *optional*, defaults to `0.0`):
The amount of stochasticity to add at each step. Higher values add more noise.
s_tmin (`float`, *optional*, defaults to `0.0`):
The minimum sigma threshold below which no noise is added.
s_tmax (`float`, *optional*, defaults to `float("inf")`):
The maximum sigma threshold above which no noise is added.
s_noise (`float`, *optional*, defaults to `1.0`):
Scaling factor for noise added to the sample.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
A random number generator for reproducibility.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.
pred_original_sample (`torch.Tensor`, *optional*):
The predicted denoised sample from a previous step. If provided, skips recomputation.
Returns:
[`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the sample tensor.
[`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is
returned, otherwise a tuple is returned where the first element is the previous sample tensor and the
second element is the predicted original sample tensor.
"""
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
@@ -399,7 +499,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
if gamma > 0:
noise = randn_tensor(
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
model_output.shape,
dtype=model_output.dtype,
device=model_output.device,
generator=generator,
)
eps = noise * s_noise
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
@@ -478,9 +581,20 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
noisy_samples = original_samples + noise * sigma
return noisy_samples
def _get_conditioning_c_in(self, sigma):
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
"""
Compute the input conditioning factor for the EDM formulation.
Args:
sigma (`float` or `torch.Tensor`):
The current sigma (noise level) value.
Returns:
`float` or `torch.Tensor`:
The input conditioning factor `c_in`.
"""
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -66,6 +66,7 @@ from .import_utils import (
is_accelerate_version,
is_aiter_available,
is_aiter_version,
is_av_available,
is_better_profanity_available,
is_bitsandbytes_available,
is_bitsandbytes_version,

View File

@@ -502,6 +502,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Video(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTXVideo(metaclass=DummyObject):
_backends = ["torch"]
@@ -967,21 +997,6 @@ class HiDreamImageTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class GlmImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -1162,6 +1177,21 @@ class LongCatImageTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class LTX2VideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LTXVideoTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1877,6 +1877,51 @@ class LongCatImagePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTX2LatentUpsamplePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTX2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LTXConditionPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
_aiter_available, _aiter_version = _is_package_available("aiter")
_kornia_available, _kornia_version = _is_package_available("kornia")
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
_av_available, _av_version = _is_package_available("av")
def is_torch_available():
@@ -420,6 +421,10 @@ def is_kornia_available():
return _kornia_available
def is_av_available():
return _av_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the

View File

@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLLTX2Audio
from ...testing_utils import (
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTX2Audio
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 2, # stereo,
"output_channels": 2,
"latent_channels": 4,
"base_channels": 16,
"ch_mult": (1, 2, 4),
"resolution": 16,
"attn_resolutions": None,
"num_res_blocks": 2,
"norm_type": "pixel",
"causality_axis": "height",
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"mel_bins": 16,
"is_causal": True,
"double_z": True,
}
@property
def dummy_input(self):
batch_size = 2
num_channels = 2
num_frames = 8
num_mel_bins = 16
spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
input_dict = {"sample": spectrogram}
return input_dict
@property
def input_shape(self):
return (2, 5, 16)
@property
def output_shape(self):
return (2, 5, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
# Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
def test_output(self):
super().test_output(expected_output_shape=(2, 2, 5, 16))
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -0,0 +1,103 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLLTX2Video
from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLLTX2Video
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_ltx_video_config(self):
return {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 8,
"block_out_channels": (8, 8, 8, 8),
"decoder_block_out_channels": (16, 32, 64),
"layers_per_block": (1, 1, 1, 1, 1),
"decoder_layers_per_block": (1, 1, 1, 1),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 1,
"patch_size_t": 1,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
# Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros`
"decoder_spatial_padding_mode": "zeros",
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
input_dict = {"sample": image}
return input_dict
@property
def input_shape(self):
return (3, 9, 16, 16)
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_ltx_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"LTX2VideoEncoder3d",
"LTX2VideoDecoder3d",
"LTX2VideoDownBlock3D",
"LTX2VideoMidBlock3d",
"LTX2VideoUpBlock3d",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
def test_forward_with_norm_groups(self):
pass

View File

@@ -0,0 +1,222 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import LTX2VideoTransformer3DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
enable_full_determinism()
class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
# Common
batch_size = 2
# Video
num_frames = 2
num_channels = 4
height = 16
width = 16
# Audio
audio_num_frames = 9
audio_num_channels = 2
num_mel_bins = 2
# Text
embedding_dim = 16
sequence_length = 16
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
torch_device
)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
return {
"hidden_states": hidden_states,
"audio_hidden_states": audio_hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"audio_encoder_hidden_states": audio_encoder_hidden_states,
"timestep": timestep,
"encoder_attention_mask": encoder_attention_mask,
"num_frames": num_frames,
"height": height,
"width": width,
"audio_num_frames": audio_num_frames,
"fps": 25.0,
}
@property
def input_shape(self):
return (512, 4)
@property
def output_shape(self):
return (512, 4)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"audio_in_channels": 4,
"audio_out_channels": 4,
"audio_num_attention_heads": 2,
"audio_attention_head_dim": 4,
"audio_cross_attention_dim": 8,
"num_layers": 2,
"qk_norm": "rms_norm_across_heads",
"caption_channels": 16,
"rope_double_precision": False,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"LTX2VideoTransformer3DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
# def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
# torch.manual_seed(seed)
# init_dict, _ = self.prepare_init_args_and_inputs_for_common()
# # Calculate dummy inputs in a custom manner to ensure compatibility with original code
# batch_size = 2
# num_frames = 9
# latent_frames = 2
# text_embedding_dim = 16
# text_seq_len = 16
# fps = 25.0
# sampling_rate = 16000.0
# hop_length = 160.0
# sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
# timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
# num_channels = 4
# latent_height = 4
# latent_width = 4
# hidden_states = torch.randn(
# (batch_size, num_channels, latent_frames, latent_height, latent_width),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify video latents (with patch_size (1, 1, 1))
# hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
# hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
# encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# audio_num_channels = 2
# num_mel_bins = 2
# latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
# audio_hidden_states = torch.randn(
# (batch_size, audio_num_channels, latent_length, num_mel_bins),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# # Patchify audio latents
# audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
# audio_encoder_hidden_states = torch.randn(
# (batch_size, text_seq_len, text_embedding_dim),
# generator=torch.manual_seed(seed),
# dtype=dtype,
# device="cpu",
# )
# inputs_dict = {
# "hidden_states": hidden_states.to(device=torch_device),
# "audio_hidden_states": audio_hidden_states.to(device=torch_device),
# "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
# "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
# "timestep": timestep,
# "num_frames": latent_frames,
# "height": latent_height,
# "width": latent_width,
# "audio_num_frames": num_frames,
# "fps": 25.0,
# }
# model = self.model_class.from_pretrained(
# "diffusers-internal-dev/dummy-ltx2",
# subfolder="transformer",
# device_map="cpu",
# )
# # torch.manual_seed(seed)
# # model = self.model_class(**init_dict)
# model.to(torch_device)
# model.eval()
# with attention_backend("native"):
# with torch.no_grad():
# output = model(**inputs_dict)
# video_output, audio_output = output.to_tuple()
# self.assertIsNotNone(video_output)
# self.assertIsNotNone(audio_output)
# # input & output have to have the same shape
# video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
# self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
# audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
# self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
# # Check against expected slice
# # fmt: off
# video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
# audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
# # fmt: on
# video_output_flat = video_output.cpu().flatten().float()
# video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
# self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
# audio_output_flat = audio_output.cpu().flatten().float()
# audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
# self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = LTX2VideoTransformer3DModel
def prepare_init_args_and_inputs_for_common(self):
return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()

View File

View File

@@ -0,0 +1,239 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTX2Pipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"audio_latents",
"output_type",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_attention_slicing = False
test_xformers_attention = False
supports_dduf = False
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
def get_dummy_components(self):
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
torch.manual_seed(0)
transformer = LTX2VideoTransformer3DModel(
in_channels=4,
out_channels=4,
patch_size=1,
patch_size_t=1,
num_attention_heads=2,
attention_head_dim=8,
cross_attention_dim=16,
audio_in_channels=4,
audio_out_channels=4,
audio_num_attention_heads=2,
audio_attention_head_dim=4,
audio_cross_attention_dim=8,
num_layers=2,
qk_norm="rms_norm_across_heads",
caption_channels=text_encoder.config.text_config.hidden_size,
rope_double_precision=False,
rope_type="split",
)
torch.manual_seed(0)
connectors = LTX2TextConnectors(
caption_channels=text_encoder.config.text_config.hidden_size,
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
video_connector_num_attention_heads=4,
video_connector_attention_head_dim=8,
video_connector_num_layers=1,
video_connector_num_learnable_registers=None,
audio_connector_num_attention_heads=4,
audio_connector_attention_head_dim=8,
audio_connector_num_layers=1,
audio_connector_num_learnable_registers=None,
connector_rope_base_seq_len=32,
rope_theta=10000.0,
rope_double_precision=False,
causal_temporal_positioning=False,
rope_type="split",
)
torch.manual_seed(0)
vae = AutoencoderKLLTX2Video(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(8,),
decoder_block_out_channels=(8,),
layers_per_block=(1,),
decoder_layers_per_block=(1, 1),
spatio_temporal_scaling=(True,),
decoder_spatio_temporal_scaling=(True,),
decoder_inject_noise=(False, False),
downsample_type=("spatial",),
upsample_residual=(False,),
upsample_factor=(1,),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
audio_vae = AutoencoderKLLTX2Audio(
base_channels=4,
output_channels=2,
ch_mult=(1,),
num_res_blocks=1,
attn_resolutions=None,
in_channels=2,
resolution=32,
latent_channels=2,
norm_type="pixel",
causality_axis="height",
dropout=0.0,
mid_block_add_attention=False,
sample_rate=16000,
mel_hop_length=160,
is_causal=True,
mel_bins=8,
)
torch.manual_seed(0)
vocoder = LTX2Vocoder(
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
hidden_channels=32,
out_channels=2,
upsample_kernel_sizes=[4, 4],
upsample_factors=[2, 2],
resnet_kernel_sizes=[3],
resnet_dilations=[[1, 3, 5]],
leaky_relu_negative_slope=0.1,
output_sampling_rate=16000,
)
scheduler = FlowMatchEulerDiscreteScheduler()
components = {
"transformer": transformer,
"vae": vae,
"audio_vae": audio_vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
inputs = {
"prompt": "a robot dancing",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"num_frames": 5,
"frame_rate": 25.0,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = pipe(**inputs)
video = output.frames
audio = output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184
]
)
expected_audio_slice = torch.tensor(
[
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)

View File

@@ -0,0 +1,241 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2ImageToVideoPipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2TextConnectors
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
from ...testing_utils import enable_full_determinism
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = LTX2ImageToVideoPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
required_optional_params = frozenset(
[
"num_inference_steps",
"generator",
"latents",
"audio_latents",
"return_dict",
"callback_on_step_end",
"callback_on_step_end_tensor_inputs",
]
)
test_attention_slicing = False
test_xformers_attention = False
supports_dduf = False
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
def get_dummy_components(self):
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
torch.manual_seed(0)
transformer = LTX2VideoTransformer3DModel(
in_channels=4,
out_channels=4,
patch_size=1,
patch_size_t=1,
num_attention_heads=2,
attention_head_dim=8,
cross_attention_dim=16,
audio_in_channels=4,
audio_out_channels=4,
audio_num_attention_heads=2,
audio_attention_head_dim=4,
audio_cross_attention_dim=8,
num_layers=2,
qk_norm="rms_norm_across_heads",
caption_channels=text_encoder.config.text_config.hidden_size,
rope_double_precision=False,
rope_type="split",
)
torch.manual_seed(0)
connectors = LTX2TextConnectors(
caption_channels=text_encoder.config.text_config.hidden_size,
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
video_connector_num_attention_heads=4,
video_connector_attention_head_dim=8,
video_connector_num_layers=1,
video_connector_num_learnable_registers=None,
audio_connector_num_attention_heads=4,
audio_connector_attention_head_dim=8,
audio_connector_num_layers=1,
audio_connector_num_learnable_registers=None,
connector_rope_base_seq_len=32,
rope_theta=10000.0,
rope_double_precision=False,
causal_temporal_positioning=False,
rope_type="split",
)
torch.manual_seed(0)
vae = AutoencoderKLLTX2Video(
in_channels=3,
out_channels=3,
latent_channels=4,
block_out_channels=(8,),
decoder_block_out_channels=(8,),
layers_per_block=(1,),
decoder_layers_per_block=(1, 1),
spatio_temporal_scaling=(True,),
decoder_spatio_temporal_scaling=(True,),
decoder_inject_noise=(False, False),
downsample_type=("spatial",),
upsample_residual=(False,),
upsample_factor=(1,),
timestep_conditioning=False,
patch_size=1,
patch_size_t=1,
encoder_causal=True,
decoder_causal=False,
)
vae.use_framewise_encoding = False
vae.use_framewise_decoding = False
torch.manual_seed(0)
audio_vae = AutoencoderKLLTX2Audio(
base_channels=4,
output_channels=2,
ch_mult=(1,),
num_res_blocks=1,
attn_resolutions=None,
in_channels=2,
resolution=32,
latent_channels=2,
norm_type="pixel",
causality_axis="height",
dropout=0.0,
mid_block_add_attention=False,
sample_rate=16000,
mel_hop_length=160,
is_causal=True,
mel_bins=8,
)
torch.manual_seed(0)
vocoder = LTX2Vocoder(
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
hidden_channels=32,
out_channels=2,
upsample_kernel_sizes=[4, 4],
upsample_factors=[2, 2],
resnet_kernel_sizes=[3],
resnet_dilations=[[1, 3, 5]],
leaky_relu_negative_slope=0.1,
output_sampling_rate=16000,
)
scheduler = FlowMatchEulerDiscreteScheduler()
components = {
"transformer": transformer,
"vae": vae,
"audio_vae": audio_vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
}
return components
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"image": image,
"prompt": "a robot dancing",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"num_frames": 5,
"frame_rate": 25.0,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = pipe(**inputs)
video = output.frames
audio = output.audio
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
self.assertEqual(audio.shape[0], 1)
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
# fmt: off
expected_video_slice = torch.tensor(
[
0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555
]
)
expected_audio_slice = torch.tensor(
[
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
]
)
# fmt: on
video = video.flatten()
audio = audio.flatten()
generated_video_slice = torch.cat([video[:8], video[-8:]])
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
def test_inference_batch_single_identical(self):
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)