Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-11 06:05:38 +08:00)

Compare commits — 9 commits
yiyi-test-...device-map
| Author | SHA1 | Date |
|---|---|---|
| | 3b334de68a | |
| | c10bdd9b73 | |
| | dab000e88b | |
| | 9fb6b89d49 | |
| | 6fb4c99f5a | |
| | 961b9b27d3 | |
| | c61e455ce7 | |
| | 6f5eb0a933 | |
| | 83ec2fb793 | |
@@ -353,8 +353,6 @@
  title: Flux2Transformer2DModel
- local: api/models/flux_transformer
  title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
  title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
  title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -369,6 +367,8 @@
  title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
  title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
  title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
  title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -445,6 +445,10 @@
  title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
  title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoderkl_audio_ltx_2
  title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
  title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
  title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -543,8 +547,6 @@
  title: Flux2
- local: api/pipelines/control_flux_inpaint
  title: FluxControlInpaint
- local: api/pipelines/glm_image
  title: GLM-Image
- local: api/pipelines/hidream
  title: HiDream-I1
- local: api/pipelines/hunyuandit
@@ -682,6 +684,8 @@
  title: Kandinsky 5.0 Video
- local: api/pipelines/latte
  title: Latte
- local: api/pipelines/ltx2
  title: LTX-2
- local: api/pipelines/ltx_video
  title: LTXVideo
- local: api/pipelines/mochi
docs/source/en/api/models/autoencoderkl_audio_ltx_2.md (new file, 29 lines)
@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTX2Audio

The variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. It encodes and decodes audio latent representations.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLLTX2Audio

vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLLTX2Audio

[[autodoc]] AutoencoderKLLTX2Audio
  - encode
  - decode
  - all
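As a rough illustration of the standard diffusers KL-VAE API on this model, the sketch below encodes and decodes a dummy spectrogram-like tensor. The input layout and shape are assumptions rather than documented values.

```python
import torch
from diffusers import AutoencoderKLLTX2Audio

vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")

# Dummy stereo spectrogram-style input; the (batch, channels, time, mel bins) layout is an assumption.
audio_features = torch.randn(1, 2, 64, 64, device="cuda")

with torch.no_grad():
    latents = vae.encode(audio_features).latent_dist.sample()  # standard diffusers KL-VAE output
    reconstruction = vae.decode(latents).sample
print(latents.shape, reconstruction.shape)
```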
docs/source/en/api/models/autoencoderkl_ltx_2.md (new file, 29 lines)
@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTX2Video

The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import AutoencoderKLLTX2Video

vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLLTX2Video

[[autodoc]] AutoencoderKLLTX2Video
  - decode
  - encode
  - all
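For a rough end-to-end check, the sketch below runs a dummy clip through encode and decode. The frame count and resolution are chosen to match the 8x temporal and 32x spatial compression factors used elsewhere in this PR and are illustrative assumptions.

```python
import torch
from diffusers import AutoencoderKLLTX2Video

vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")

# Dummy clip: (batch, channels, frames, height, width); 9 frames and 256x256 fit the assumed
# 8x temporal / 32x spatial compression ratios.
video = torch.randn(1, 3, 9, 256, 256, device="cuda")

with torch.no_grad():
    latents = vae.encode(video).latent_dist.sample()
    reconstruction = vae.decode(latents).sample
print(latents.shape, reconstruction.shape)
```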
@@ -1,18 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# GlmImageTransformer2DModel

A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel]()

## GlmImageTransformer2DModel

[[autodoc]] GlmImageTransformer2DModel
docs/source/en/api/models/ltx2_video_transformer3d.md (new file, 26 lines)
@@ -0,0 +1,26 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# LTX2VideoTransformer3DModel

A Diffusion Transformer model for 3D data, introduced by Lightricks as part of [LTX-2](https://huggingface.co/Lightricks/LTX-2).

The model can be loaded with the following code snippet.

```python
import torch

from diffusers import LTX2VideoTransformer3DModel

transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## LTX2VideoTransformer3DModel

[[autodoc]] LTX2VideoTransformer3DModel
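As a quick sanity check after downloading the transformer, the sketch below inspects a few configuration values. The config key names are taken from the conversion script later in this diff and are assumptions for the released checkpoint.

```python
import torch
from diffusers import LTX2VideoTransformer3DModel

transformer = LTX2VideoTransformer3DModel.from_pretrained(
    "Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Config keys mirror get_ltx2_transformer_config() in scripts/convert_ltx2_to_diffusers.py below.
print(transformer.config.num_layers, transformer.config.num_attention_heads, transformer.config.attention_head_dim)
print(f"{transformer.num_parameters() / 1e9:.1f}B parameters")
```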
@@ -1,31 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->

# GLM-Image

> [!TIP]
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.

This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The original codebase can be found [here](https://huggingface.co/zai-org). The original weights can be found under [hf.co/zai-org](https://huggingface.co/zai-org).

## GlmImagePipeline

[[autodoc]] GlmImagePipeline
  - all
  - __call__

## GlmImagePipelineOutput

[[autodoc]] pipelines.cogview4.pipeline_output.GlmImagePipelineOutput
docs/source/en/api/pipelines/ltx2.md (new file, 43 lines)
@@ -0,0 +1,43 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX-2

LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.

You can find all the original LTX-2 checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.

The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).

## LTX2Pipeline

[[autodoc]] LTX2Pipeline
  - all
  - __call__

## LTX2ImageToVideoPipeline

[[autodoc]] LTX2ImageToVideoPipeline
  - all
  - __call__

## LTX2LatentUpsamplePipeline

[[autodoc]] LTX2LatentUpsamplePipeline
  - all
  - __call__

## LTX2PipelineOutput

[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput
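Below is a minimal usage sketch for the new pipeline. It assumes the Hub repository ships a diffusers-format `LTX2Pipeline`, and the call follows the usual diffusers text-to-video pattern; the argument names are assumptions, so check the `LTX2Pipeline.__call__` autodoc for the real signature.

```python
import torch
from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")

# Argument names are assumptions, not the confirmed LTX2Pipeline signature.
output = pipe(prompt="A drummer performing on a rainy rooftop at night", num_inference_steps=40)
```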
@@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p

The general rule of thumb when preparing inputs for the VACE pipeline is that input images, or the frames of a video you want to use for conditioning, should have a corresponding black mask. A black mask tells the model not to generate new content in that area and to use it only for conditioning; parts or frames the model should generate must have a white mask.

</hfoption>
</hfoptions>

### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication

[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
-    quantzation_config=pipeline_quant_config,
+    quantization_config=pipeline_quant_config,
     torch_dtype=torch.bfloat16,
     device_map="cuda"
 )
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
-    quantzation_config=pipeline_quant_config,
+    quantization_config=pipeline_quant_config,
     torch_dtype=torch.bfloat16,
     device_map="cuda"
 )
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
 )
 pipeline = DiffusionPipeline.from_pretrained(
     "black-forest-labs/FLUX.1-dev",
-    quantzation_config=pipeline_quant_config,
+    quantization_config=pipeline_quant_config,
     torch_dtype=torch.bfloat16,
     device_map="cuda"
 )
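The hunks above only correct the misspelled keyword (`quantzation_config` → `quantization_config`). For context, a self-contained version of the corrected call might look like the sketch below; the `bitsandbytes_4bit` backend and its kwargs are assumptions, since the guide's actual backend is not visible in this diff.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Backend and kwargs are illustrative assumptions; only the keyword fix comes from the diff above.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer"],
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,  # correctly spelled keyword
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```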
scripts/convert_ltx2_to_diffusers.py (new file, 886 lines)
@@ -0,0 +1,886 @@
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Input Patchify Projections
|
||||
"patchify_proj": "proj_in",
|
||||
"audio_patchify_proj": "audio_proj_in",
|
||||
# Modulation Parameters
|
||||
# Handle adaln_single --> time_embed and audio_adaln_single --> audio_time_embed separately, as the original keys are
|
||||
# substrings of the other modulation parameters below
|
||||
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
|
||||
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
|
||||
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
|
||||
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
|
||||
# Transformer Blocks
|
||||
# Per-Block Cross Attention Modulation Parameters
|
||||
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# Decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
# Common
|
||||
# For all 3D ResNets
|
||||
"res_blocks": "resnets",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"ups": "upsamplers",
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if key.startswith("adaln_single."):
|
||||
new_key = key.replace("adaln_single.", "time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
if key.startswith("audio_adaln_single."):
|
||||
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
if key.startswith("per_channel_statistics"):
|
||||
new_key = ".".join(["decoder", key])
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
"text_proj_in",
|
||||
)
|
||||
|
||||
transformer_state_dict, connector_state_dict = {}, {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith(connector_prefixes):
|
||||
connector_state_dict[key] = value
|
||||
else:
|
||||
transformer_state_dict[key] = value
|
||||
|
||||
return transformer_state_dict, connector_state_dict
|
||||
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 4,
|
||||
"audio_out_channels": 4,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 2,
|
||||
"audio_attention_head_dim": 4,
|
||||
"audio_cross_attention_dim": 8,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 2,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 16,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 16,
|
||||
"text_proj_in_factor": 3,
|
||||
"video_connector_num_attention_heads": 4,
|
||||
"video_connector_attention_head_dim": 8,
|
||||
"video_connector_num_layers": 1,
|
||||
"video_connector_num_learnable_registers": None,
|
||||
"audio_connector_num_attention_heads": 4,
|
||||
"audio_connector_attention_head_dim": 8,
|
||||
"audio_connector_num_layers": 1,
|
||||
"audio_connector_num_learnable_registers": None,
|
||||
"connector_rope_base_seq_len": 32,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(transformer_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, transformer_state_dict)
|
||||
|
||||
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
if len(connector_state_dict) == 0:
|
||||
raise ValueError("No connector weights found in the provided state dict.")
|
||||
|
||||
with init_empty_weights():
|
||||
connectors = LTX2TextConnectors.from_config(diffusers_config)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(connector_state_dict, key, new_key)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, connector_state_dict)
|
||||
|
||||
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vocoder
|
||||
|
||||
|
||||
def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
return config
|
||||
|
||||
|
||||
def convert_ltx2_spatial_latent_upsampler(
|
||||
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
|
||||
):
|
||||
with init_empty_weights():
|
||||
latent_upsampler = LTX2LatentUpsamplerModel(**config)
|
||||
|
||||
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
latent_upsampler.to(dtype)
|
||||
return latent_upsampler
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
if repo_id is None and filename is None:
|
||||
raise ValueError("Please supply at least one of `repo_id` or `filename`")
|
||||
|
||||
if repo_id is not None:
|
||||
if filename is None:
|
||||
raise ValueError("If repo_id is specified, filename must also be specified.")
|
||||
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
else:
|
||||
ckpt_path = filename
|
||||
|
||||
_, ext = os.path.splitext(ckpt_path)
|
||||
if ext in [".safetensors", ".sft"]:
|
||||
state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||
# Ensure that the key prefix ends with a dot (.)
|
||||
if not prefix.endswith("."):
|
||||
prefix = prefix + "."
|
||||
|
||||
model_state_dict = {}
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id",
|
||||
default="Lightricks/LTX-2",
|
||||
type=str,
|
||||
help="HF Hub repo id with LTX 2.0 checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--combined_filename",
|
||||
default="ltx-2-19b-dev.safetensors",
|
||||
type=str,
|
||||
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
|
||||
)
|
||||
parser.add_argument("--vae_prefix", default="vae.", type=str)
|
||||
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
|
||||
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
|
||||
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
|
||||
|
||||
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_model_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 base text encoder model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_upsampler_filename",
|
||||
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
|
||||
type=str,
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
|
||||
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
|
||||
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
|
||||
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
|
||||
parser.add_argument(
|
||||
"--full_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upsample_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
|
||||
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
|
||||
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
|
||||
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
|
||||
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
|
||||
|
||||
combined_ckpt = None
|
||||
load_combined_models = any(
|
||||
[
|
||||
args.vae,
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.text_encoder,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
)
|
||||
if args.combined_filename is not None and load_combined_models:
|
||||
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
|
||||
|
||||
if args.vae or args.full_pipeline or args.upsample_pipeline:
|
||||
if args.vae_filename is not None:
|
||||
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
|
||||
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
if args.audio_vae or args.full_pipeline:
|
||||
if args.audio_vae_filename is not None:
|
||||
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
|
||||
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
|
||||
|
||||
if args.dit or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
if args.connectors or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
|
||||
|
||||
if args.vocoder or args.full_pipeline:
|
||||
if args.vocoder_filename is not None:
|
||||
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
|
||||
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.text_encoder or args.full_pipeline:
|
||||
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
|
||||
original_latent_upsampler_ckpt,
|
||||
latent_upsampler_config,
|
||||
dtype=vae_dtype,
|
||||
)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
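After running the script with `--full_pipeline`, the saved directory can be loaded back like any diffusers pipeline. A minimal sketch, assuming the conversion above completed and using a placeholder output path:

```python
import torch
from diffusers import LTX2Pipeline

# "./ltx2-diffusers" is a placeholder for whatever was passed to --output_path.
pipe = LTX2Pipeline.from_pretrained("./ltx2-diffusers", torch_dtype=torch.bfloat16).to("cuda")
```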
@@ -193,6 +193,8 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
@@ -223,7 +225,6 @@ else:
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"GlmImageTransformer2DModel",
|
||||
"HiDreamImageTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -237,6 +238,7 @@ else:
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LongCatImageTransformer2DModel",
|
||||
"LTX2VideoTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"Lumina2Transformer2DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
@@ -488,7 +490,6 @@ else:
|
||||
"FluxKontextPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"GlmImagePipeline",
|
||||
"HiDreamImagePipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -540,6 +541,9 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LTX2LatentUpsamplePipeline",
|
||||
"LTX2Pipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
@@ -941,6 +945,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -971,7 +977,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -985,6 +990,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
@@ -1206,7 +1212,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
GlmImagePipeline,
|
||||
HiDreamImagePipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
@@ -1258,6 +1263,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTXConditionPipeline,
|
||||
LTXI2VLongMultiPromptPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
|
||||
@@ -41,6 +41,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
@@ -96,7 +98,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
|
||||
@@ -105,6 +106,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
@@ -154,6 +156,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -204,7 +208,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
@@ -214,6 +217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
|
||||
@@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
|
||||
src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py (new file, 1521 lines)
File diff suppressed because it is too large. Load Diff
src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py (new file, 804 lines)
@@ -0,0 +1,804 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
class LTX2AudioCausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution that pads asymmetrically along the causal axis.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
if self.causality_axis == "none":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis in {"width", "width-compatibility"}:
|
||||
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis == "height":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
self.padding = padding
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
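# A minimal usage sketch for the causal convolution above (shapes are illustrative
# assumptions, not taken from a checkpoint). With `causality_axis="height"` the input is
# padded only on the "past" side of the time axis, so each output frame depends only on
# current and earlier frames while the spatial extent is preserved:
#
#     conv = LTX2AudioCausalConv2d(in_channels=2, out_channels=4, kernel_size=3, causality_axis="height")
#     x = torch.randn(1, 2, 16, 64)   # (batch, channels, time, mel_bins)
#     y = conv(x)
#     assert y.shape == (1, 4, 16, 64)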
|
||||
|
||||
|
||||
class LTX2AudioPixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
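# LTX2AudioPixelNorm divides each spatial location by its RMS over the channel dimension:
# y[b, :, h, w] = x[b, :, h, w] / sqrt(mean_c(x[b, :, h, w] ** 2) + eps). A small numerical
# check (illustrative only):
#
#     norm = LTX2AudioPixelNorm(dim=1, eps=0.0)
#     x = torch.full((1, 4, 2, 2), 3.0)
#     assert torch.allclose(norm(x), torch.ones_like(x))   # RMS of a constant 3.0 vector is 3.0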
|
||||
|
||||
|
||||
class LTX2AudioAttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
norm_type: str = "group",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(x)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
batch, channels, height, width = q.shape
|
||||
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
|
||||
k = k.reshape(batch, channels, height * width).contiguous()
|
||||
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
|
||||
attn = torch.nn.functional.softmax(attn, dim=2)
|
||||
|
||||
v = v.reshape(batch, channels, height * width)
|
||||
attn = attn.permute(0, 2, 1).contiguous()
|
||||
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
return x + h_
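# The attention above is single-head self-attention over the flattened height * width
# positions with a 1 / sqrt(channels) scale. An equivalent formulation (a sketch, not part
# of the original implementation), starting from the (batch, channels, height, width)
# outputs of the q/k/v projections, would be:
#
#     q, k, v = (t.reshape(batch, channels, height * width).transpose(1, 2) for t in (q, k, v))
#     out = F.scaled_dot_product_attention(q, k, v)               # (batch, height * width, channels)
#     out = out.transpose(1, 2).reshape(batch, channels, height, width)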
|
||||
|
||||
|
||||
class LTX2AudioResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
norm_type: str = "group",
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
if causality_axis is not None:
|
||||
self.conv1 = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
if norm_type == "group":
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
if causality_axis is not None:
|
||||
self.conv2 = LTX2AudioCausalConv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
if causality_axis is not None:
|
||||
self.conv_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
if causality_axis is not None:
|
||||
self.nin_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
h = self.norm1(x)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class LTX2AudioDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.with_conv:
|
||||
# Padding tuple is in the order: (left, right, top, bottom).
|
||||
if self.causality_axis == "none":
|
||||
pad = (0, 1, 0, 1)
|
||||
elif self.causality_axis == "width":
|
||||
pad = (2, 0, 0, 1)
|
||||
elif self.causality_axis == "height":
|
||||
pad = (0, 1, 2, 0)
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pad = (1, 0, 0, 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
|
||||
f" and `width-compatibility`."
|
||||
)
|
||||
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# with_conv=False implies that causality_axis is "none"
|
||||
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
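# Downsampling sketch (illustrative shapes): the asymmetric pad above keeps the stride-2
# convolution causal along the configured axis while roughly halving both axes, e.g. an
# input of shape (1, 128, 16, 64) becomes (1, 128, 8, 32) for `causality_axis="height"`.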
|
||||
|
||||
|
||||
class LTX2AudioUpsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
if causality_axis is not None:
|
||||
self.conv = LTX2AudioCausalConv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
if self.causality_axis is None or self.causality_axis == "none":
|
||||
pass
|
||||
elif self.causality_axis == "height":
|
||||
x = x[:, :, 1:, :]
|
||||
elif self.causality_axis == "width":
|
||||
x = x[:, :, :, 1:]
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
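# Upsampling sketch: nearest-neighbour interpolation doubles both axes, the causal
# convolution re-mixes channels, and the leading row (or column) introduced by the causal
# padding is cropped, so e.g. (1, 128, 8, 32) becomes (1, 128, 15, 64) when
# `causality_axis="height"`.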
|
||||
|
||||
|
||||
class LTX2AudioAudioPatchifier:
|
||||
"""
|
||||
Patchifier for spectrogram/audio latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
is_causal: bool = True,
|
||||
):
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
|
||||
batch, channels, time, freq = audio_latents.shape
|
||||
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
|
||||
batch, time, _ = audio_latents.shape
|
||||
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
return self._patch_size
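# Patchify/unpatchify round-trip sketch (shapes are assumptions for illustration): the
# patchifier flattens the channel and mel-bin axes into a per-frame token dimension and
# restores them afterwards.
#
#     patchifier = LTX2AudioAudioPatchifier(patch_size=1)
#     latents = torch.randn(1, 8, 32, 16)                       # (batch, channels, time, mel_bins)
#     tokens = patchifier.patchify(latents)                     # (1, 32, 8 * 16)
#     restored = patchifier.unpatchify(tokens, channels=8, mel_bins=16)
#     assert torch.equal(restored, latents)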
|
||||
|
||||
|
||||
class LTX2AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels
|
||||
base_resolution = resolution
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != self.num_resolutions - 1:
|
||||
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
final_block_channels = block_in
|
||||
z_channels = 2 * latent_channels if double_z else latent_channels
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_states = block(hidden_states, temb=None)
|
||||
if stage.attn:
|
||||
hidden_states = stage.attn[block_idx](hidden_states)
|
||||
|
||||
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
|
||||
hidden_states = stage.downsample(hidden_states)
|
||||
|
||||
hidden_states = self.mid.block_1(hidden_states, temb=None)
|
||||
hidden_states = self.mid.attn_1(hidden_states)
|
||||
hidden_states = self.mid.block_2(hidden_states, temb=None)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
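# Rough shape walkthrough for the default configuration (an illustration, not a guarantee
# for every input length): a stereo mel spectrogram of shape (batch, 2, time, 64) passes
# through len(ch_mult) - 1 = 2 causal downsampling stages, so the time and mel axes shrink
# by roughly 4x, and `conv_out` emits 2 * latent_channels = 16 channels (mean and
# log-variance) when `double_z=True`.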
|
||||
|
||||
|
||||
class LTX2AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
|
||||
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
|
||||
convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
self.patchifier = LTX2AudioAudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels * self.channel_multipliers[-1]
|
||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != 0:
|
||||
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res *= 2
|
||||
|
||||
self.up.insert(0, stage)
|
||||
|
||||
final_block_channels = block_in
|
||||
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_, _, frames, mel_bins = sample.shape
|
||||
|
||||
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
if self.causality_axis is not None:
|
||||
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||
|
||||
target_channels = self.out_ch
|
||||
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
|
||||
|
||||
hidden_features = self.conv_in(sample)
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = self.up[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_features = block(hidden_features, temb=None)
|
||||
if stage.attn:
|
||||
hidden_features = stage.attn[block_idx](hidden_features)
|
||||
|
||||
if level != 0 and hasattr(stage, "upsample"):
|
||||
hidden_features = stage.upsample(hidden_features)
|
||||
|
||||
if self.give_pre_end:
|
||||
return hidden_features
|
||||
|
||||
hidden = self.norm_out(hidden_features)
|
||||
hidden = self.non_linearity(hidden)
|
||||
decoded_output = self.conv_out(hidden)
|
||||
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
|
||||
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
target_time = target_frames
|
||||
target_freq = target_mel_bins
|
||||
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0),
|
||||
0,
|
||||
max(time_padding_needed, 0),
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
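# Worked example of the temporal cropping above: with the causal axis enabled, a latent
# clip of 8 frames is upsampled to 8 * LATENT_DOWNSAMPLE_FACTOR = 32 frames and then
# trimmed by LATENT_DOWNSAMPLE_FACTOR - 1 = 3, giving 29 output spectrogram frames.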
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX2 audio VAE for encoding mel-spectrogram audio inputs into latent representations and decoding them back.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 2,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
norm_type: str = "pixel",
|
||||
causality_axis: Optional[str] = "height",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
|
||||
if causality_axis not in supported_causality_axes:
|
||||
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
|
||||
|
||||
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
|
||||
|
||||
self.encoder = LTX2AudioEncoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
double_z=double_z,
|
||||
)
|
||||
|
||||
self.decoder = LTX2AudioDecoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
# Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are
# computed over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict; the
# values below are identity placeholders that are overwritten when the checkpoint is loaded.
latents_mean = torch.zeros((base_channels,))
latents_std = torch.ones((base_channels,))
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
# TODO: calculate programmatically instead of hardcoding
|
||||
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
|
||||
# TODO: confirm whether the mel compression ratio below is correct
|
||||
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.encoder(x)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(z)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
posterior = self.encode(sample).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
if not return_dict:
|
||||
return (dec.sample,)
|
||||
return dec
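# End-to-end usage sketch for the audio VAE (the input shape is an assumption chosen to match
# the default `in_channels=2` and `mel_bins=64`; real inputs come from the LTX-2
# mel-spectrogram preprocessing):
#
#     vae = AutoencoderKLLTX2Audio()
#     mel = torch.randn(1, 2, 64, 64)                  # (batch, channels, time, mel_bins)
#     posterior = vae.encode(mel).latent_dist          # latents have shape (1, 8, 16, 16)
#     latents = posterior.sample()
#     recon = vae.decode(latents).sample               # (1, 2, 61, 64): 4 * 16 - 3 frames on the causal axis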
|
||||
@@ -1658,37 +1658,6 @@ class CogView3CombinedTimestepSizeEmbeddings(nn.Module):
|
||||
return conditioning
|
||||
|
||||
|
||||
class GlmImageCombinedTimestepSizeEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim: int, condition_dim: int, pooled_projection_dim: int, timesteps_dim: int = 256):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=timesteps_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.condition_proj = Timesteps(num_channels=condition_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=timesteps_dim, time_embed_dim=embedding_dim)
|
||||
self.condition_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
hidden_dtype: torch.dtype,
|
||||
) -> torch.Tensor:
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
|
||||
crop_coords_proj = self.condition_proj(crop_coords.flatten()).view(crop_coords.size(0), -1)
|
||||
target_size_proj = self.condition_proj(target_size.flatten()).view(target_size.size(0), -1)
|
||||
|
||||
# (B, 2 * condition_dim)
|
||||
condition_proj = torch.cat([crop_coords_proj, target_size_proj], dim=1)
|
||||
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
condition_emb = self.condition_embedder(condition_proj.to(dtype=hidden_dtype)) # (B, embedding_dim)
|
||||
|
||||
conditioning = timesteps_emb + condition_emb
|
||||
return conditioning
|
||||
|
||||
|
||||
class HunyuanDiTAttentionPool(nn.Module):
|
||||
# Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ if is_torch_available():
|
||||
from .transformer_easyanimate import EasyAnimateTransformer3DModel
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_glm_image import GlmImageTransformer2DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
|
||||
@@ -36,6 +35,7 @@ if is_torch_available():
|
||||
from .transformer_kandinsky import Kandinsky5Transformer3DModel
|
||||
from .transformer_longcat_image import LongCatImageTransformer2DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_ltx2 import LTX2VideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
|
||||
@@ -1,567 +0,0 @@
|
||||
# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import FeedForward
|
||||
from ..attention_processor import Attention
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import GlmImageCombinedTimestepSizeEmbeddings
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import LayerNorm, RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class GlmImageImageProjector(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
hidden_size: int = 2560,
|
||||
patch_size: int = 2,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Linear(in_channels * patch_size**2, hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
post_patch_height = height // self.patch_size
|
||||
post_patch_width = width // self.patch_size
|
||||
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, channel, post_patch_height, self.patch_size, post_patch_width, self.patch_size
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
|
||||
hidden_states = self.proj(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class GlmImageAdaLayerNormZero(nn.Module):
|
||||
def __init__(self, embedding_dim: int, dim: int) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype = hidden_states.dtype
|
||||
norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
|
||||
norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
|
||||
|
||||
emb = self.linear(temb)
|
||||
(
|
||||
shift_msa,
|
||||
c_shift_msa,
|
||||
scale_msa,
|
||||
c_scale_msa,
|
||||
gate_msa,
|
||||
c_gate_msa,
|
||||
shift_mlp,
|
||||
c_shift_mlp,
|
||||
scale_mlp,
|
||||
c_scale_mlp,
|
||||
gate_mlp,
|
||||
c_gate_mlp,
|
||||
) = emb.chunk(12, dim=1)
|
||||
|
||||
hidden_states = norm_hidden_states * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
|
||||
encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_msa.unsqueeze(1)) + c_shift_msa.unsqueeze(1)
|
||||
|
||||
return (
|
||||
hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
encoder_hidden_states,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
)
|
||||
|
||||
|
||||
class GlmImageLayerKVCache:
|
||||
"""KV cache for GlmImage model."""
|
||||
def __init__(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode: Optional[str] = None # "write", "read", "skip"
|
||||
|
||||
def store(self, k: torch.Tensor, v: torch.Tensor):
|
||||
if self.k_cache is None:
|
||||
self.k_cache = k
|
||||
self.v_cache = v
|
||||
else:
|
||||
self.k_cache = torch.cat([self.k_cache, k], dim=2)
|
||||
self.v_cache = torch.cat([self.v_cache, v], dim=2)
|
||||
|
||||
def get(self):
|
||||
return self.k_cache, self.v_cache
|
||||
|
||||
def clear(self):
|
||||
self.k_cache = None
|
||||
self.v_cache = None
|
||||
self.mode = None
|
||||
|
||||
|
||||
class GlmImageKVCache:
|
||||
"""Container for all layers' KV caches."""
|
||||
|
||||
def __init__(self, num_layers: int):
|
||||
self.num_layers = num_layers
|
||||
self.caches = [GlmImageLayerKVCache() for _ in range(num_layers)]
|
||||
|
||||
def __getitem__(self, layer_idx: int) -> GlmImageLayerKVCache:
|
||||
return self.caches[layer_idx]
|
||||
|
||||
def set_mode(self, mode: Optional[str]):
|
||||
if mode is not None and mode not in ["write", "read", "skip"]:
|
||||
raise ValueError(f"Invalid mode: {mode}, must be one of 'write', 'read', 'skip'")
|
||||
for cache in self.caches:
|
||||
cache.mode = mode
|
||||
|
||||
def clear(self):
|
||||
for cache in self.caches:
|
||||
cache.clear()
|
||||
|
||||
class GlmImageAttnProcessor:
|
||||
"""
|
||||
Processor for implementing scaled dot-product attention for the GlmImage model. It applies a rotary embedding on
|
||||
query and key vectors, but does not include spatial normalization.
|
||||
|
||||
The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
|
||||
text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("GlmImageAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
kv_cache: Optional[GlmImageLayerKVCache] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
dtype = encoder_hidden_states.dtype
|
||||
|
||||
batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
|
||||
batch_size, image_seq_length, embed_dim = hidden_states.shape
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 1. QKV projections
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
|
||||
|
||||
# 2. QK normalization
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query).to(dtype=dtype)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key).to(dtype=dtype)
|
||||
|
||||
# 3. Rotational positional embeddings applied to latent stream
|
||||
if image_rotary_emb is not None:
|
||||
from ..embeddings import apply_rotary_emb
|
||||
|
||||
query[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
key[:, :, text_seq_length:, :] = apply_rotary_emb(
|
||||
key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
|
||||
)
|
||||
|
||||
if kv_cache is not None:
|
||||
if kv_cache.mode == "write":
|
||||
kv_cache.store(key, value)
|
||||
elif kv_cache.mode == "read":
|
||||
k_cache, v_cache = kv_cache.get()
|
||||
key = torch.cat([k_cache, key], dim=2) if k_cache is not None else key
|
||||
value = torch.cat([v_cache, value], dim=2) if v_cache is not None else value
|
||||
elif kv_cache.mode == "skip":
|
||||
pass
|
||||
|
||||
# 4. Attention
|
||||
if attention_mask is not None:
|
||||
text_attn_mask = attention_mask
|
||||
assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
|
||||
text_attn_mask = text_attn_mask.float().to(query.device)
|
||||
mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
|
||||
mix_attn_mask[:, :text_seq_length] = text_attn_mask
|
||||
mix_attn_mask = mix_attn_mask.unsqueeze(2)
|
||||
attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
|
||||
attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
|
||||
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
# 5. Output projection
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
encoder_hidden_states, hidden_states = hidden_states.split(
|
||||
[text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
|
||||
)
|
||||
return hidden_states, encoder_hidden_states
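# Sketch of the expected text attention mask (a hypothetical example; the token counts are
# made up): for a batch of two prompts padded to 5 text tokens, pass a (2, 5) tensor of ones
# and zeros and the processor expands it into the joint (text + image) attention mask.
#
#     attention_mask = torch.tensor([[1, 1, 1, 0, 0],
#                                    [1, 1, 1, 1, 1]])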
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class GlmImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int = 2560,
|
||||
num_attention_heads: int = 64,
|
||||
attention_head_dim: int = 40,
|
||||
time_embed_dim: int = 512,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
# 1. Attention
|
||||
self.norm1 = GlmImageAdaLayerNormZero(time_embed_dim, dim)
|
||||
self.attn1 = Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
out_dim=dim,
|
||||
bias=True,
|
||||
qk_norm="layer_norm",
|
||||
elementwise_affine=False,
|
||||
eps=1e-5,
|
||||
processor=GlmImageAttnProcessor(),
|
||||
)
|
||||
|
||||
# 2. Feedforward
|
||||
self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
|
||||
self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
attention_mask: Optional[Dict[str, torch.Tensor]] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
kv_cache: Optional[GlmImageLayerKVCache] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Timestep conditioning
|
||||
(
|
||||
norm_hidden_states,
|
||||
gate_msa,
|
||||
shift_mlp,
|
||||
scale_mlp,
|
||||
gate_mlp,
|
||||
norm_encoder_hidden_states,
|
||||
c_gate_msa,
|
||||
c_shift_mlp,
|
||||
c_scale_mlp,
|
||||
c_gate_mlp,
|
||||
) = self.norm1(hidden_states, encoder_hidden_states, temb)
|
||||
|
||||
# 2. Attention
|
||||
if attention_kwargs is None:
|
||||
attention_kwargs = {}
|
||||
|
||||
attn_hidden_states, attn_encoder_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=norm_encoder_hidden_states,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
attention_mask=attention_mask,
|
||||
kv_cache=kv_cache,
|
||||
**attention_kwargs,
|
||||
)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
|
||||
|
||||
# 3. Feedforward
|
||||
norm_hidden_states = self.norm2(hidden_states) * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
|
||||
norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states) * (
|
||||
1 + c_scale_mlp.unsqueeze(1)
|
||||
) + c_shift_mlp.unsqueeze(1)
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
ff_output_context = self.ff(norm_encoder_hidden_states)
|
||||
hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
|
||||
encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)
|
||||
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class GlmImageRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, dim: int, patch_size: int, theta: float = 10000.0) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.theta = theta
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
height, width = height // self.patch_size, width // self.patch_size
|
||||
|
||||
dim_h, dim_w = self.dim // 2, self.dim // 2
|
||||
h_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
|
||||
)
|
||||
w_inv_freq = 1.0 / (
|
||||
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
|
||||
)
|
||||
h_seq = torch.arange(height)
|
||||
w_seq = torch.arange(width)
|
||||
freqs_h = torch.outer(h_seq, h_inv_freq)
|
||||
freqs_w = torch.outer(w_seq, w_inv_freq)
|
||||
|
||||
# Create position matrices for height and width
|
||||
# [height, 1, dim//4] and [1, width, dim//4]
|
||||
freqs_h = freqs_h.unsqueeze(1)
|
||||
freqs_w = freqs_w.unsqueeze(0)
|
||||
# Broadcast freqs_h and freqs_w to [height, width, dim//4]
|
||||
freqs_h = freqs_h.expand(height, width, -1)
|
||||
freqs_w = freqs_w.expand(height, width, -1)
|
||||
|
||||
# Concatenate along last dimension to get [height, width, dim//2]
|
||||
freqs = torch.cat([freqs_h, freqs_w], dim=-1)
|
||||
freqs = torch.cat([freqs, freqs], dim=-1) # [height, width, dim]
|
||||
freqs = freqs.reshape(height * width, -1)
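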
|
||||
return (freqs.cos(), freqs.sin())
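# Shape walkthrough (illustrative): for a latent of spatial size 64 x 64 with patch_size=2,
# the grid is 32 x 32 patches and each returned tensor has shape (32 * 32, attention_head_dim),
# with half of the head dimension driven by the row index and half by the column index.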
|
||||
|
||||
|
||||
class GlmImageAdaLayerNormContinuous(nn.Module):
|
||||
"""
|
||||
GlmImage-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
|
||||
Linear on conditioning embedding.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedding_dim: int,
|
||||
conditioning_embedding_dim: int,
|
||||
elementwise_affine: bool = True,
|
||||
eps: float = 1e-5,
|
||||
bias: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
):
|
||||
super().__init__()
|
||||
self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
|
||||
if norm_type == "layer_norm":
|
||||
self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
|
||||
elif norm_type == "rms_norm":
|
||||
self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
|
||||
else:
|
||||
raise ValueError(f"unknown norm_type {norm_type}")
|
||||
|
||||
def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
|
||||
# *** NO SiLU here ***
|
||||
emb = self.linear(conditioning_embedding.to(x.dtype))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
class GlmImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
|
||||
r"""
|
||||
Args:
|
||||
patch_size (`int`, defaults to `2`):
|
||||
The size of the patches to use in the patch embedding layer.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
num_layers (`int`, defaults to `30`):
|
||||
The number of layers of Transformer blocks to use.
|
||||
attention_head_dim (`int`, defaults to `40`):
|
||||
The number of channels in each head.
|
||||
num_attention_heads (`int`, defaults to `64`):
|
||||
The number of heads to use for multi-head attention.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_embed_dim (`int`, defaults to `1472`):
|
||||
Input dimension of text embeddings from the text encoder.
|
||||
time_embed_dim (`int`, defaults to `512`):
|
||||
Output dimension of timestep embeddings.
|
||||
condition_dim (`int`, defaults to `256`):
|
||||
The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
|
||||
crop_coords).
|
||||
pos_embed_max_size (`int`, defaults to `128`):
|
||||
The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
|
||||
to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
|
||||
means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
|
||||
patch_size => 128 * 8 * 2 => 2048`.
|
||||
sample_size (`int`, defaults to `128`):
|
||||
The base resolution of input latents. If height/width is not provided during generation, this value is used
|
||||
to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_no_split_modules = [
|
||||
"GlmImageTransformerBlock",
|
||||
"GlmImageImageProjector",
|
||||
"GlmImageImageProjector",
|
||||
]
|
||||
_skip_layerwise_casting_patterns = ["patch_embed", "norm", "proj_out"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
num_layers: int = 30,
|
||||
attention_head_dim: int = 40,
|
||||
num_attention_heads: int = 64,
|
||||
text_embed_dim: int = 1472,
|
||||
time_embed_dim: int = 512,
|
||||
condition_dim: int = 256,
|
||||
prior_vq_quantizer_codebook_size: int = 16384,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# GlmImage uses 2 additional SDXL-like conditions - target_size, crop_coords
|
||||
# Each of these are sincos embeddings of shape 2 * condition_dim
|
||||
pooled_projection_dim = 2 * 2 * condition_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels
|
||||
|
||||
# 1. RoPE
|
||||
self.rope = GlmImageRotaryPosEmbed(attention_head_dim, patch_size, theta=10000.0)
|
||||
|
||||
# 2. Patch & Text-timestep embedding
|
||||
self.image_projector = GlmImageImageProjector(in_channels, inner_dim, patch_size)
|
||||
self.glyph_projector = FeedForward(text_embed_dim, inner_dim, inner_dim=inner_dim, activation_fn="gelu")
|
||||
self.prior_token_embedding = nn.Embedding(prior_vq_quantizer_codebook_size, inner_dim)
|
||||
self.prior_projector = FeedForward(inner_dim, inner_dim, inner_dim=inner_dim, activation_fn="linear-silu")
|
||||
|
||||
self.time_condition_embed = GlmImageCombinedTimestepSizeEmbeddings(
|
||||
embedding_dim=time_embed_dim,
|
||||
condition_dim=condition_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
timesteps_dim=time_embed_dim,
|
||||
)
|
||||
|
||||
# 3. Transformer blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
GlmImageTransformerBlock(inner_dim, num_attention_heads, attention_head_dim, time_embed_dim)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Output projection
|
||||
self.norm_out = GlmImageAdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
prior_token_id: torch.Tensor,
|
||||
prior_token_drop: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
target_size: torch.Tensor,
|
||||
crop_coords: torch.Tensor,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
kv_caches: Optional[GlmImageKVCache] = None,
|
||||
image_rotary_emb: Optional[
|
||||
Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
|
||||
] = None,
|
||||
) -> Union[Tuple[torch.Tensor], Transformer2DModelOutput]:
|
||||
batch_size, num_channels, height, width = hidden_states.shape
|
||||
|
||||
# 1. RoPE
|
||||
if image_rotary_emb is None:
|
||||
image_rotary_emb = self.rope(hidden_states)
|
||||
|
||||
# 2. Patch & Timestep embeddings
|
||||
p = self.config.patch_size
|
||||
post_patch_height = height // p
|
||||
post_patch_width = width // p
|
||||
|
||||
hidden_states = self.image_projector(hidden_states)
|
||||
encoder_hidden_states = self.glyph_projector(encoder_hidden_states)
|
||||
prior_embedding = self.prior_token_embedding(prior_token_id)
|
||||
prior_embedding[prior_token_drop] *= 0.0
|
||||
prior_hidden_states = self.prior_projector(prior_embedding)
|
||||
|
||||
hidden_states = hidden_states + prior_hidden_states
|
||||
|
||||
temb = self.time_condition_embed(timestep, target_size, crop_coords, hidden_states.dtype)
|
||||
temb = F.silu(temb)
|
||||
|
||||
# 3. Transformer blocks
|
||||
for idx, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
kv_caches[idx] if kv_caches is not None else None,
|
||||
)
|
||||
else:
|
||||
hidden_states, encoder_hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
attention_mask,
|
||||
attention_kwargs,
|
||||
kv_cache=kv_caches[idx] if kv_caches is not None else None,
|
||||
)
|
||||
|
||||
# 4. Output norm & projection
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 5. Unpatchify
|
||||
hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, -1, p, p)
|
||||
output = hidden_states.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
src/diffusers/models/transformers/transformer_ltx2.py (new file, 1350 lines; diff suppressed because it is too large)
@@ -168,6 +168,14 @@ class MellonParam:
|
||||
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
return cls(name="videos", label="Videos", type="video", display="output")
|
||||
|
||||
@classmethod
|
||||
def vae(cls) -> "MellonParam":
|
||||
"""
|
||||
|
||||
@@ -290,6 +290,7 @@ else:
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
]
|
||||
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
@@ -737,6 +738,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
)
|
||||
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
|
||||
@@ -52,7 +52,6 @@ from .flux import (
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
)
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
@@ -168,7 +167,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("chroma", ChromaPipeline),
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("glm_image", GlmImagePipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
|
||||
@@ -1,882 +0,0 @@
|
||||
# Copyright 2025 The CogVideoX team, Tsinghua University & ZhipuAI and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from math import sqrt
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import ByT5Tokenizer, GlmImageForConditionalGeneration, GlmImageProcessor, T5EncoderModel
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import VaeImageProcessor
|
||||
from ...loaders import CogView4LoraLoaderMixin
|
||||
from ...models import AutoencoderKL, GlmImageTransformer2DModel
|
||||
from ...models.transformers.transformer_glm_image import GlmImageKVCache
|
||||
from ...pipelines.pipeline_utils import DiffusionPipeline
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from .pipeline_output import GlmImagePipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from diffusers import GlmImagePipeline
|
||||
|
||||
>>> pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A photo of an astronaut riding a horse on mars<sop>36 24<eop>"
|
||||
>>> image = pipe(prompt).images[0]
|
||||
>>> image.save("output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
base_shift: float = 0.25,
|
||||
max_shift: float = 0.75,
|
||||
) -> float:
|
||||
m = (image_seq_len / base_seq_len) ** 0.5
|
||||
mu = m * max_shift + base_shift
|
||||
return mu
|
||||
|
||||
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
"""
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
accepts_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
|
||||
if timesteps is not None and sigmas is not None:
|
||||
if not accepts_timesteps and not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep or sigma schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif timesteps is not None and sigmas is None:
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif timesteps is None and sigmas is not None:
|
||||
if not accepts_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class GlmImagePipeline(DiffusionPipeline, CogView4LoraLoaderMixin):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using GLM-Image.
|
||||
|
||||
This pipeline integrates both the AR (autoregressive) model for token generation and the DiT (diffusion
|
||||
transformer) model for image decoding.
|
||||
|
||||
Args:
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`T5EncoderModel`]):
|
||||
Frozen text-encoder for glyph embeddings.
|
||||
tokenizer (`PreTrainedTokenizer`):
|
||||
Tokenizer for the text encoder.
|
||||
processor (`AutoProcessor`):
|
||||
Processor for the AR model to handle chat templates and tokenization.
|
||||
vision_language_encoder ([`GlmImageForConditionalGeneration`]):
|
||||
The AR model that generates image tokens from text prompts.
|
||||
transformer ([`GlmImageTransformer2DModel`]):
|
||||
A text conditioned transformer to denoise the encoded image latents (DiT).
|
||||
scheduler ([`SchedulerMixin`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
_optional_components = []
|
||||
model_cpu_offload_seq = "transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tokenizer: ByT5Tokenizer,
|
||||
processor: GlmImageProcessor,
|
||||
text_encoder: T5EncoderModel,
|
||||
vision_language_encoder: GlmImageForConditionalGeneration,
|
||||
vae: AutoencoderKL,
|
||||
transformer: GlmImageTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
tokenizer=tokenizer,
|
||||
processor=processor,
|
||||
text_encoder=text_encoder,
|
||||
vision_language_encoder=vision_language_encoder,
|
||||
vae=vae,
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
|
||||
self.default_sample_size = (
|
||||
self.transformer.config.sample_size
|
||||
if hasattr(self, "transformer")
|
||||
and self.transformer is not None
|
||||
and hasattr(self.transformer.config, "sample_size")
|
||||
else 128
|
||||
)
|
||||
|
||||
def _build_image_grid_thw(
|
||||
self,
|
||||
token_h: int,
|
||||
token_w: int,
|
||||
prev_token_h: int,
|
||||
prev_token_w: int,
|
||||
existing_grid: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> torch.Tensor:
|
||||
if existing_grid is None or existing_grid.numel() == 0:
|
||||
return torch.tensor(
|
||||
[
|
||||
[1, token_h, token_w],
|
||||
[1, prev_token_h, prev_token_w],
|
||||
],
|
||||
device=device,
|
||||
)
|
||||
else:
|
||||
return torch.cat([existing_grid.to(device), torch.tensor([[1, token_h, token_w]], device=device)], dim=0)
|
||||
|
||||
def _calculate_ar_generation_params(
|
||||
self, token_h: int, token_w: int, prev_token_h: int, prev_token_w: int, is_text_to_image: bool
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Calculate max_new_tokens and large_image_start_offset for AR generation.
|
||||
"""
|
||||
large_image_tokens = token_h * token_w
|
||||
small_image_tokens = prev_token_h * prev_token_w
|
||||
|
||||
if is_text_to_image:
|
||||
max_new_tokens = small_image_tokens + large_image_tokens + 1
|
||||
large_image_start_offset = small_image_tokens
|
||||
else:
|
||||
max_new_tokens = large_image_tokens + 1
|
||||
large_image_start_offset = 0
|
||||
|
||||
return max_new_tokens, large_image_start_offset
|
||||
|
||||
def _extract_large_image_tokens(
|
||||
self, outputs: torch.Tensor, input_length: int, large_image_start_offset: int, large_image_tokens: int
|
||||
) -> torch.Tensor:
|
||||
generated_tokens = outputs[0][input_length:]
|
||||
large_image_start = large_image_start_offset
|
||||
large_image_end = large_image_start + large_image_tokens
|
||||
return generated_tokens[large_image_start:large_image_end]
|
||||
|
||||
def _upsample_d32_to_d16(self, token_ids: torch.Tensor, token_h: int, token_w: int) -> torch.Tensor:
|
||||
"""
|
||||
Upsample token IDs from d32 format to d16 format.
|
||||
|
||||
AR model generates tokens at d32 resolution (each token = 32x32 pixels). DiT expects tokens at d16 resolution
|
||||
(each token = 16x16 pixels). This function performs 2x nearest-neighbor upsampling.
|
||||
|
||||
Args:
|
||||
token_ids: Token IDs of shape [N] where N = token_h * token_w
|
||||
token_h: Height in d32 token units
|
||||
token_w: Width in d32 token units
|
||||
|
||||
Returns:
|
||||
Upsampled token IDs of shape [1, N*4] where N*4 = (token_h*2) * (token_w*2)
|
||||
"""
|
||||
token_ids = token_ids.view(1, 1, token_h, token_w)
|
||||
token_ids = torch.nn.functional.interpolate(token_ids.float(), scale_factor=2, mode="nearest").to(
|
||||
dtype=torch.long
|
||||
)
|
||||
|
||||
token_ids = token_ids.view(1, -1)
|
||||
return token_ids
|
||||
|
||||
def _build_prompt_with_shape(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
is_text_to_image: bool,
|
||||
factor: int = 32,
|
||||
) -> Tuple[str, int, int, int, int]:
|
||||
"""
|
||||
Build prompt with shape info (<sop>H W<eop>) based on height and width.
|
||||
|
||||
Args:
|
||||
prompt: The raw text prompt without shape info
|
||||
height: Target image height in pixels
|
||||
width: Target image width in pixels
|
||||
is_text_to_image: Whether this is text-to-image (True) or image-to-image (False)
|
||||
|
||||
Returns:
|
||||
Tuple of (expanded_prompt, token_h, token_w, prev_token_h, prev_token_w)
|
||||
"""
|
||||
token_h = height // factor
|
||||
token_w = width // factor
|
||||
ratio = token_h / token_w
|
||||
prev_token_h = int(sqrt(ratio) * (factor // 2))
|
||||
prev_token_w = int(sqrt(1 / ratio) * (factor // 2))
|
||||
|
||||
if is_text_to_image:
|
||||
expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop><sop>{prev_token_h} {prev_token_w}<eop>"
|
||||
else:
|
||||
expanded_prompt = f"{prompt}<sop>{token_h} {token_w}<eop>"
|
||||
|
||||
return expanded_prompt, token_h, token_w, prev_token_h, prev_token_w
|
||||
|
||||
def generate_prior_tokens(
|
||||
self,
|
||||
prompt: str,
|
||||
height: int,
|
||||
width: int,
|
||||
image: Optional[List[PIL.Image.Image]] = None,
|
||||
factor: int = 32,
|
||||
) -> Tuple[torch.Tensor, int, int]:
|
||||
"""
|
||||
Generate prior tokens using the AR (vision_language_encoder) model.
|
||||
|
||||
Automatically builds the prompt with shape info based on height/width. Users only need to provide the raw text
|
||||
prompt without <sop>...<eop> tags.
|
||||
|
||||
Args:
|
||||
prompt: The raw text prompt (without shape info)
|
||||
height: Target image height in pixels (must be divisible by factor)
|
||||
width: Target image width in pixels (must be divisible by factor)
|
||||
image: Optional list of condition images for image-to-image generation
|
||||
factor: Token size factor (32 for d32 tokens)
|
||||
|
||||
Returns:
|
||||
Tuple of (prior_token_ids, pixel_height, pixel_width)
|
||||
- prior_token_ids: Upsampled to d16 format, shape [1, token_h*token_w*4]
|
||||
- pixel_height: Image height in pixels (aligned to factor)
|
||||
- pixel_width: Image width in pixels (aligned to factor)
|
||||
|
||||
"""
|
||||
device = self.vision_language_encoder.device
|
||||
height = (height // factor) * factor
|
||||
width = (width // factor) * factor
|
||||
is_text_to_image = image is None or len(image) == 0
|
||||
expanded_prompt, token_h, token_w, prev_h, prev_w = self._build_prompt_with_shape(
|
||||
prompt, height, width, is_text_to_image
|
||||
)
|
||||
content = []
|
||||
if image is not None:
|
||||
for img in image:
|
||||
content.append({"type": "image", "image": img})
|
||||
content.append({"type": "text", "text": expanded_prompt})
|
||||
messages = [{"role": "user", "content": content}]
|
||||
inputs = self.processor.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
existing_grid = inputs.get("image_grid_thw")
|
||||
inputs["image_grid_thw"] = self._build_image_grid_thw(
|
||||
token_h,
|
||||
token_w,
|
||||
prev_h,
|
||||
prev_w,
|
||||
existing_grid=existing_grid if not is_text_to_image else None,
|
||||
device=device,
|
||||
)
|
||||
|
||||
max_new_tokens, large_image_offset = self._calculate_ar_generation_params(
|
||||
token_h, token_w, prev_h, prev_w, is_text_to_image
|
||||
)
|
||||
large_image_tokens = token_h * token_w
|
||||
|
||||
inputs = inputs.to(device)
|
||||
input_length = inputs["input_ids"].shape[-1]
|
||||
|
||||
outputs = self.vision_language_encoder.generate(
|
||||
**inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
)
|
||||
|
||||
prior_token_ids_d32 = self._extract_large_image_tokens(
|
||||
outputs, input_length, large_image_offset, large_image_tokens
|
||||
)
|
||||
prior_token_ids = self._upsample_d32_to_d16(prior_token_ids_d32, token_h, token_w)
|
||||
|
||||
pixel_height = token_h * factor
|
||||
pixel_width = token_w * factor
|
||||
|
||||
return prior_token_ids, pixel_height, pixel_width
|
||||
|
||||
def get_glyph_texts(self, prompt):
|
||||
prompt = prompt[0] if isinstance(prompt, list) else prompt
|
||||
ocr_texts = (
|
||||
re.findall(r"'([^']*)'", prompt)
|
||||
+ re.findall(r"“([^“”]*)”", prompt)
|
||||
+ re.findall(r'"([^"]*)"', prompt)
|
||||
+ re.findall(r"「([^「」]*)」", prompt)
|
||||
)
|
||||
return ocr_texts
|
||||
|
||||
def _get_glyph_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
max_sequence_length: int = 2048,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
):
|
||||
device = device or self._execution_device
|
||||
dtype = dtype or self.text_encoder.dtype
|
||||
|
||||
glyph_texts = self.get_glyph_texts(prompt)
|
||||
input_ids = self.tokenizer(
|
||||
glyph_texts if len(glyph_texts) > 0 else [""],
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
).input_ids
|
||||
input_ids = [
|
||||
[self.tokenizer.pad_token_id] * ((len(input_ids) + 1) % 2) + input_ids_ for input_ids_ in input_ids
|
||||
]
|
||||
max_length = max(len(input_ids_) for input_ids_ in input_ids)
|
||||
attention_mask = torch.tensor(
|
||||
[[1] * len(input_ids_) + [0] * (max_length - len(input_ids_)) for input_ids_ in input_ids], device=device
|
||||
)
|
||||
input_ids = torch.tensor(
|
||||
[input_ids_ + [self.tokenizer.pad_token_id] * (max_length - len(input_ids_)) for input_ids_ in input_ids],
|
||||
device=device,
|
||||
)
|
||||
outputs = self.text_encoder(input_ids, attention_mask=attention_mask)
|
||||
glyph_embeds = outputs.last_hidden_state[attention_mask.bool()].unsqueeze(0)
|
||||
|
||||
return glyph_embeds.to(device=device, dtype=dtype)
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
max_sequence_length: int = 2048,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use classifier free guidance or not.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
Number of images that should be generated per prompt. torch device to place the resulting embeddings on
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device
|
||||
dtype: (`torch.dtype`, *optional*):
|
||||
torch dtype
|
||||
max_sequence_length (`int`, defaults to `2048`):
|
||||
Maximum sequence length in encoded prompt. Can be set to other values but may lead to poorer results.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
if prompt is not None:
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_glyph_embeds(prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = prompt_embeds.size(1)
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
negative_prompt_embeds = None
|
||||
if do_classifier_free_guidance:
|
||||
negative_prompt = ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds = self._get_glyph_embeds(negative_prompt, max_sequence_length, device, dtype)
|
||||
|
||||
seq_len = negative_prompt_embeds.size(1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
if latents is not None:
|
||||
return latents.to(device)
|
||||
|
||||
shape = (
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
int(height) // self.vae_scale_factor,
|
||||
int(width) // self.vae_scale_factor,
|
||||
)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
return latents
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * self.transformer.config.patch_size) != 0
|
||||
or width is not None
|
||||
and width % (self.transformer.config.patch_size) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def current_timestep(self):
|
||||
return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Optional[Union[str, List[str]]] = None,
|
||||
image: Optional[
|
||||
Union[
|
||||
torch.Tensor, PIL.Image.Image, np.ndarray, List[torch.Tensor], List[PIL.Image.Image], List[np.ndarray]
|
||||
]
|
||||
] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 1.5,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
crops_coords_top_left: Tuple[int, int] = (0, 0),
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 2048,
|
||||
) -> Union[GlmImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. Must contain shape info in the format '<sop>H
|
||||
W<eop>' where H and W are token dimensions (d32). Example: "A beautiful sunset<sop>36 24<eop>"
|
||||
generates a 1152x768 image.
|
||||
image: Optional condition images for image-to-image generation.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels. If not provided, derived from prompt shape info.
|
||||
width (`int`, *optional*):
|
||||
The width in pixels. If not provided, derived from prompt shape info.
|
||||
num_inference_steps (`int`, *optional*, defaults to `50`):
|
||||
The number of denoising steps for DiT.
|
||||
guidance_scale (`float`, *optional*, defaults to `1.5`):
|
||||
Guidance scale for classifier-free guidance.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to `1`):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
Random generator for reproducibility.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
Output format: "pil", "np", or "latent".
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`GlmImagePipelineOutput`] or `tuple`: Generated images.
|
||||
"""
|
||||
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._current_timestep = None
|
||||
self._interrupt = False
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
assert batch_size == 1, "batch_size must be 1"
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
ar_condition_images = None
|
||||
if image is not None:
|
||||
if not isinstance(image, list):
|
||||
image = [image]
|
||||
ar_condition_images = []
|
||||
for img in image:
|
||||
if isinstance(img, PIL.Image.Image):
|
||||
ar_condition_images.append(img)
|
||||
elif isinstance(img, torch.Tensor):
|
||||
img_np = img.cpu().numpy()
|
||||
if img_np.ndim == 4:
|
||||
img_np = img_np[0]
|
||||
if img_np.shape[0] in [1, 3, 4]:
|
||||
img_np = img_np.transpose(1, 2, 0)
|
||||
if img_np.max() <= 1.0:
|
||||
img_np = (img_np * 255).astype(np.uint8)
|
||||
ar_condition_images.append(PIL.Image.fromarray(img_np))
|
||||
else:
|
||||
ar_condition_images.append(PIL.Image.fromarray(img))
|
||||
|
||||
prior_token_id, ar_height, ar_width = self.generate_prior_tokens(
|
||||
prompt=prompt[0] if isinstance(prompt, list) else prompt,
|
||||
image=ar_condition_images,
|
||||
height=height,
|
||||
width=width,
|
||||
)
|
||||
|
||||
height = height or ar_height
|
||||
width = width or ar_width
|
||||
|
||||
# 3. Encode input prompt
|
||||
prompt_embeds, negative_prompt_embeds = self.encode_prompt(
|
||||
prompt,
|
||||
self.do_classifier_free_guidance,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
# 4. process images
|
||||
condition_images_prior_token_id = None
|
||||
if image is not None:
|
||||
preprocessed_condition_images = []
|
||||
condition_images_prior_token_id = []
|
||||
for img in image:
|
||||
image_height, image_width = img.size[::-1] if isinstance(img, PIL.Image.Image) else img.shape[:2]
|
||||
multiple_of = self.vae_scale_factor * self.transformer.config.patch_size
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
img = self.image_processor.preprocess(img, height=image_height, width=image_width)
|
||||
preprocessed_condition_images.append(img)
|
||||
image = preprocessed_condition_images
|
||||
|
||||
# 5. Prepare latents and (optional) image kv cache
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_channels_latents=latent_channels,
|
||||
height=height,
|
||||
width=width,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
kv_caches = GlmImageKVCache(num_layers=self.transformer.config.num_layers)
|
||||
|
||||
if image is not None and condition_images_prior_token_id is not None:
|
||||
kv_caches.set_mode("write")
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(self.vae.device, self.vae.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(self.vae.device, self.vae.dtype)
|
||||
)
|
||||
empty_glyph_hiddens = torch.zeros_like(prompt_embeds)[:1, :0, ...]
|
||||
for condition_image, condition_image_prior_token_id in zip(image, condition_images_prior_token_id):
|
||||
condition_image = condition_image.to(device=device, dtype=self.vae.dtype)
|
||||
condition_latent = retrieve_latents(
|
||||
self.vae.encode(condition_image), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = (condition_latent - latents_mean) / latents_std
|
||||
_ = self.transformer(
|
||||
hidden_states=condition_latent,
|
||||
encoder_hidden_states=empty_glyph_hiddens,
|
||||
prior_token_id=condition_image_prior_token_id,
|
||||
prior_token_drop=torch.full_like(condition_image_prior_token_id, False, dtype=torch.bool),
|
||||
timestep=torch.zeros((1,), device=device),
|
||||
target_size=torch.tensor([condition_image.shape[-2:]], device=device),
|
||||
crop_coords=torch.zeros((1, 2), device=device),
|
||||
attention_kwargs=attention_kwargs,
|
||||
kv_caches=kv_caches,
|
||||
)
|
||||
|
||||
# 6. Prepare additional timestep conditions
|
||||
target_size = (height, width)
|
||||
target_size = torch.tensor([target_size], dtype=prompt_embeds.dtype, device=device)
|
||||
crops_coords_top_left = torch.tensor([crops_coords_top_left], dtype=prompt_embeds.dtype, device=device)
|
||||
|
||||
target_size = target_size.repeat(batch_size * num_images_per_prompt, 1)
|
||||
crops_coords_top_left = crops_coords_top_left.repeat(batch_size * num_images_per_prompt, 1)
|
||||
|
||||
# Prepare timesteps
|
||||
image_seq_len = ((height // self.vae_scale_factor) * (width // self.vae_scale_factor)) // (
|
||||
self.transformer.config.patch_size**2
|
||||
)
|
||||
timesteps = (
|
||||
np.linspace(self.scheduler.config.num_train_timesteps, 1.0, num_inference_steps + 1)[:-1]
|
||||
if timesteps is None
|
||||
else np.array(timesteps)
|
||||
)
|
||||
timesteps = timesteps.astype(np.int64).astype(np.float32)
|
||||
sigmas = timesteps / self.scheduler.config.num_train_timesteps if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("base_shift", 0.25),
|
||||
self.scheduler.config.get("max_shift", 0.75),
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas, mu=mu
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 7. Denoising loop
|
||||
transformer_dtype = self.transformer.dtype
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
prior_token_drop_cond = torch.full_like(prior_token_id, False, dtype=torch.bool)
|
||||
prior_token_drop_uncond = torch.full_like(prior_token_id, True, dtype=torch.bool)
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
|
||||
timestep = t.expand(latents.shape[0]) - 1
|
||||
|
||||
if image is not None:
|
||||
kv_caches.set_mode("read")
|
||||
|
||||
noise_pred_cond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
prior_token_id=prior_token_id,
|
||||
prior_token_drop=prior_token_drop_cond,
|
||||
timestep=timestep,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_caches=kv_caches,
|
||||
)[0].float()
|
||||
|
||||
# perform guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
if image is not None:
|
||||
kv_caches.set_mode("skip")
|
||||
noise_pred_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
prior_token_id=prior_token_id,
|
||||
prior_token_drop=prior_token_drop_uncond,
|
||||
timestep=timestep,
|
||||
target_size=target_size,
|
||||
crop_coords=crops_coords_top_left,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
kv_caches=kv_caches,
|
||||
)[0].float()
|
||||
|
||||
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
||||
else:
|
||||
noise_pred = noise_pred_cond
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, self.scheduler.sigmas[i], callback_kwargs)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
self._current_timestep = None
|
||||
kv_caches.clear()
|
||||
|
||||
if not output_type == "latent":
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = (
|
||||
torch.tensor(self.vae.config.latents_std)
|
||||
.view(1, self.vae.config.latent_channels, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents = latents * latents_std + latents_mean
|
||||
image = self.vae.decode(latents, return_dict=False, generator=generator)[0]
|
||||
else:
|
||||
image = latents
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return GlmImagePipelineOutput(images=image)
|
||||
@@ -1,21 +0,0 @@
from dataclasses import dataclass
from typing import List, Union

import numpy as np
import PIL.Image

from ...utils import BaseOutput


@dataclass
class GlmImagePipelineOutput(BaseOutput):
    """
    Output class for CogView3 pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)
            List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
            num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
    """

    images: Union[List[PIL.Image.Image], np.ndarray]
@@ -11,8 +11,8 @@ from ...utils import (


_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["GlmImagePipelineOutput"]}
_import_structure = {}


try:
    if not (is_transformers_available() and is_torch_available()):
@@ -22,15 +22,28 @@ except OptionalDependencyNotAvailable:

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_glm_image"] = ["GlmImagePipeline"]
    _import_structure["connectors"] = ["LTX2TextConnectors"]
    _import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
    _import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
    _import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
    _import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
    _import_structure["vocoder"] = ["LTX2Vocoder"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_glm_image import GlmImagePipeline
        from .connectors import LTX2TextConnectors
        from .latent_upsampler import LTX2LatentUpsamplerModel
        from .pipeline_ltx2 import LTX2Pipeline
        from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
        from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
        from .vocoder import LTX2Vocoder

else:
    import sys

@@ -43,5 +56,3 @@ else:

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
    for name, value in _additional_imports.items():
        setattr(sys.modules[__name__], name, value)
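For orientation, a minimal loading sketch for the pipeline classes registered above; the checkpoint id, dtype, and device are illustrative assumptions rather than part of this diff, and the actual generation call is defined in `pipeline_ltx2.py`:

```python
import torch

# Exposed through the new `ltx2` entries in `diffusers.pipelines`.
from diffusers.pipelines.ltx2 import LTX2Pipeline

# Hypothetical checkpoint id; substitute the actual LTX-2 repository.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```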
src/diffusers/pipelines/ltx2/connectors.py (new file)
@@ -0,0 +1,325 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_seq_len: int = 4096,
|
||||
theta: float = 10000.0,
|
||||
double_precision: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
num_attention_heads: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
if rope_type not in ["interleaved", "split"]:
|
||||
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
|
||||
|
||||
self.dim = dim
|
||||
self.base_seq_len = base_seq_len
|
||||
self.theta = theta
|
||||
self.double_precision = double_precision
|
||||
self.rope_type = rope_type
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size: int,
|
||||
pos: int,
|
||||
device: Union[str, torch.device],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Get 1D position ids
|
||||
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
|
||||
# Get fractional indices relative to self.base_seq_len
|
||||
grid_1d = grid_1d / self.base_seq_len
|
||||
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
||||
|
||||
# 2. Calculate 1D RoPE frequencies
|
||||
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
|
||||
freqs_dtype = torch.float64 if self.double_precision else torch.float32
|
||||
pow_indices = torch.pow(
|
||||
self.theta,
|
||||
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
|
||||
)
|
||||
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
|
||||
|
||||
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
|
||||
# (self.dim // 2,).
|
||||
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
|
||||
|
||||
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
if self.rope_type == "interleaved":
|
||||
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
if self.dim % num_rope_elems != 0:
|
||||
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
|
||||
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
elif self.rope_type == "split":
|
||||
expected_freqs = self.dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq = freqs.cos()
|
||||
sin_freq = freqs.sin()
|
||||
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape freqs to be compatible with multi-head attention
|
||||
b = cos_freq.shape[0]
|
||||
t = cos_freq.shape[1]
|
||||
|
||||
cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
|
||||
sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
|
||||
|
||||
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
|
||||
|
||||
class LTX2TransformerBlock1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.attn1 = LTX2Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
ff_hidden_states = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2ConnectorTransformer1d(nn.Module):
|
||||
"""
|
||||
A 1D sequence transformer for modalities such as text.
|
||||
|
||||
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.learnable_registers = None
|
||||
if num_learnable_registers is not None:
|
||||
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
|
||||
self.learnable_registers = torch.nn.Parameter(init_registers)
|
||||
|
||||
self.rope = LTX2RotaryPosEmbed1d(
|
||||
self.inner_dim,
|
||||
base_seq_len=rope_base_seq_len,
|
||||
theta=rope_theta,
|
||||
double_precision=rope_double_precision,
|
||||
rope_type=rope_type,
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
LTX2TransformerBlock1d(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_binarize_threshold: float = -9000.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# hidden_states shape: [batch_size, seq_len, hidden_dim]
|
||||
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# 1. Replace padding with learned registers, if using
|
||||
if self.learnable_registers is not None:
|
||||
if seq_len % self.num_learnable_registers != 0:
|
||||
raise ValueError(
|
||||
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
|
||||
f" of learnable registers {self.num_learnable_registers}"
|
||||
)
|
||||
|
||||
num_register_repeats = seq_len // self.num_learnable_registers
|
||||
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
|
||||
|
||||
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
|
||||
if binary_attn_mask.ndim == 4:
|
||||
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
|
||||
|
||||
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
|
||||
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
|
||||
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
|
||||
padded_hidden_states = [
|
||||
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
|
||||
]
|
||||
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
|
||||
|
||||
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
|
||||
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
|
||||
|
||||
# Overwrite attention_mask with an all-zeros mask if using registers.
|
||||
attention_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# 2. Calculate 1D RoPE positional embeddings
|
||||
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
|
||||
|
||||
# 3. Run 1D transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextConnectors(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
|
||||
streams.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
caption_channels: int,
|
||||
text_proj_in_factor: int,
|
||||
video_connector_num_attention_heads: int,
|
||||
video_connector_attention_head_dim: int,
|
||||
video_connector_num_layers: int,
|
||||
video_connector_num_learnable_registers: int | None,
|
||||
audio_connector_num_attention_heads: int,
|
||||
audio_connector_attention_head_dim: int,
|
||||
audio_connector_num_layers: int,
|
||||
audio_connector_num_learnable_registers: int | None,
|
||||
connector_rope_base_seq_len: int,
|
||||
rope_theta: float,
|
||||
rope_double_precision: bool,
|
||||
causal_temporal_positioning: bool,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
num_layers=video_connector_num_layers,
|
||||
num_learnable_registers=video_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
attention_head_dim=audio_connector_attention_head_dim,
|
||||
num_layers=audio_connector_num_layers,
|
||||
num_learnable_registers=audio_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
|
||||
):
|
||||
# Convert to additive attention mask, if necessary
|
||||
if not additive_mask:
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
src/diffusers/pipelines/ltx2/export_utils.py (new file)
@@ -0,0 +1,134 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import is_av_available
|
||||
|
||||
|
||||
_CAN_USE_AV = is_av_available()
|
||||
if _CAN_USE_AV:
|
||||
import av
|
||||
else:
|
||||
raise ImportError(
|
||||
"PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
return audio_stream
|
||||
|
||||
|
||||
def _resample_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
|
||||
) -> None:
|
||||
cc = audio_stream.codec_context
|
||||
|
||||
# Use the encoder's format/layout/rate as the *target*
|
||||
target_format = cc.format or "fltp" # AAC → usually fltp
|
||||
target_layout = cc.layout or "stereo"
|
||||
target_rate = cc.sample_rate or frame_in.sample_rate
|
||||
|
||||
audio_resampler = av.audio.resampler.AudioResampler(
|
||||
format=target_format,
|
||||
layout=target_layout,
|
||||
rate=target_rate,
|
||||
)
|
||||
|
||||
audio_next_pts = 0
|
||||
for rframe in audio_resampler.resample(frame_in):
|
||||
if rframe.pts is None:
|
||||
rframe.pts = audio_next_pts
|
||||
audio_next_pts += rframe.samples
|
||||
rframe.sample_rate = frame_in.sample_rate
|
||||
container.mux(audio_stream.encode(rframe))
|
||||
|
||||
# flush audio encoder
|
||||
for packet in audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
|
||||
container: av.container.Container,
|
||||
audio_stream: av.audio.AudioStream,
|
||||
samples: torch.Tensor,
|
||||
audio_sample_rate: int,
|
||||
) -> None:
|
||||
if samples.ndim == 1:
|
||||
samples = samples[:, None]
|
||||
|
||||
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||
samples = samples.T
|
||||
|
||||
if samples.shape[1] != 2:
|
||||
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||
|
||||
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||
if samples.dtype != torch.int16:
|
||||
samples = torch.clip(samples, -1.0, 1.0)
|
||||
samples = (samples * 32767.0).to(torch.int16)
|
||||
|
||||
frame_in = av.AudioFrame.from_ndarray(
|
||||
samples.contiguous().reshape(1, -1).cpu().numpy(),
|
||||
format="s16",
|
||||
layout="stereo",
|
||||
)
|
||||
frame_in.sample_rate = audio_sample_rate
|
||||
|
||||
_resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
) -> None:
|
||||
video_np = video.cpu().numpy()
|
||||
|
||||
_, height, width, _ = video_np.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
if audio is not None:
|
||||
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||
|
||||
container.close()
|
||||
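A minimal usage sketch for the `encode_video` helper above, assuming PyAV is installed; the dummy tensors and output path are illustrative, with shapes chosen to match the function body (`[frames, height, width, 3]` uint8 video, float stereo audio in `[-1, 1]`):

```python
import torch

from diffusers.pipelines.ltx2.export_utils import encode_video

# Illustrative dummy data: 24 RGB frames at 256x256 plus one second of silent stereo audio.
video = torch.randint(0, 256, (24, 256, 256, 3), dtype=torch.uint8)  # [frames, height, width, channels]
audio = torch.zeros(48000, 2)  # float samples in [-1, 1], two channels

encode_video(video, fps=24, audio=audio, audio_sample_rate=48000, output_path="sample.mp4")
```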
src/diffusers/pipelines/ltx2/latent_upsampler.py (new file)
@@ -0,0 +1,285 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
RATIONAL_RESAMPLER_SCALE_MAPPING = {
|
||||
0.75: (3, 4),
|
||||
1.5: (3, 2),
|
||||
2.0: (2, 1),
|
||||
4.0: (4, 1),
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.activation(hidden_states + residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
if dims not in [1, 2, 3]:
|
||||
raise ValueError("dims must be 1, 2, or 3")
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:3]))
|
||||
.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
.flatten(6, 7)
|
||||
.flatten(4, 5)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 2:
|
||||
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 1:
|
||||
# temporal: b (c p1) f h w -> b c (f p1) h w
|
||||
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
|
||||
|
||||
|
||||
class BlurDownsample(torch.nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only to the
H and W dimensions. Works for dims=2 or dims=3 (applied per frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||
super().__init__()
|
||||
|
||||
if dims not in (2, 3):
|
||||
raise ValueError(f"`dims` must be either 2 or 3 but is {dims}")
|
||||
if kernel_size < 3 or kernel_size % 2 != 1:
|
||||
raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}")
|
||||
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
# Separable binomial kernel: for the default kernel_size=5 the coefficients are [1, 4, 6, 4, 1],
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and provides a smooth
# approximation of a Gaussian filter (often called a "binomial filter").
# The 2D kernel is constructed as the outer product and normalized.
|
||||
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
if self.dims == 2:
|
||||
c = x.shape[1]
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, _, _ = x.shape
|
||||
x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
|
||||
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W]
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(torch.nn.Module):
|
||||
"""
|
||||
Scales the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample
|
||||
by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the
|
||||
input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the
|
||||
(integer) denominator.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None)
|
||||
if num_denom is None:
|
||||
raise ValueError(
|
||||
f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}"
|
||||
)
|
||||
self.num, self.den = num_denom
|
||||
|
||||
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Expected x shape: [B * F, C, H, W]
|
||||
# b, _, f, h, w = x.shape
|
||||
# x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
# x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W]
|
||||
return x
|
||||
|
||||
|
||||
class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `128`):
|
||||
Number of channels in the input latent
|
||||
mid_channels (`int`, defaults to `512`):
|
||||
Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`, defaults to `4`):
|
||||
Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`, defaults to `3`):
|
||||
Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`, defaults to `True`):
|
||||
Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`, defaults to `False`):
|
||||
Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
rational_spatial_scale: Optional[float] = 2.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
|
||||
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_spatial_scale is not None:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.dims == 2:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
if self.temporal_upsample:
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states[:, :, 1:, :, :]
|
||||
else:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
|
||||
return hidden_states
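
To make the tensor bookkeeping concrete, here is a small sketch that runs a dummy latent through the upsampler; the channel counts are deliberately tiny, made-up values (the configuration above defaults to `in_channels=128`, `mid_channels=1024`) so the example stays cheap.

```python
import torch

from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel

# Hypothetical tiny configuration for illustration only (not the released checkpoint's sizes).
upsampler = LTX2LatentUpsamplerModel(in_channels=32, mid_channels=64, num_blocks_per_stage=1)

latents = torch.randn(1, 32, 2, 8, 8)  # [batch, channels, frames, height, width]
with torch.no_grad():
    upsampled = upsampler(latents)
print(upsampled.shape)  # torch.Size([1, 32, 2, 16, 16]): H and W doubled by the default scale of 2.0
```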
|
||||
src/diffusers/pipelines/ltx2/pipeline_ltx2.py (new file, 1141 lines; diff suppressed because it is too large)
src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py (new file, 1238 lines; diff suppressed because it is too large)
src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py (new file, 442 lines)
@@ -0,0 +1,442 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLLTX2Video
|
||||
from ...utils import get_logger, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..ltx.pipeline_output import LTXPipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .latent_upsampler import LTX2LatentUpsamplerModel
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
|
||||
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
>>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
|
||||
... )
|
||||
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> frame_rate = 24.0
|
||||
>>> video, audio = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=768,
|
||||
... height=512,
|
||||
... num_frames=121,
|
||||
... frame_rate=frame_rate,
|
||||
... num_inference_steps=40,
|
||||
... guidance_scale=4.0,
|
||||
... output_type="pil",
|
||||
... return_dict=False,
|
||||
... )
|
||||
|
||||
>>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
|
||||
... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
|
||||
>>> upsample_pipe.vae.enable_tiling()
|
||||
>>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
>>> video = upsample_pipe(
|
||||
... video=video,
|
||||
... width=768,
|
||||
... height=512,
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )[0]
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
... fps=frame_rate,
|
||||
... audio=audio[0].float().cpu(),
|
||||
... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
|
||||
... output_path="video.mp4",
|
||||
... )
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTX2LatentUpsamplePipeline(DiffusionPipeline):
|
||||
model_cpu_offload_seq = "vae->latent_upsampler"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLLTX2Video,
|
||||
latent_upsampler: LTX2LatentUpsamplerModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_frames: int = 121,
|
||||
height: int = 512,
|
||||
width: int = 768,
|
||||
spatial_patch_size: int = 1,
|
||||
temporal_patch_size: int = 1,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
if latents.ndim == 3:
|
||||
# Convert token seq [B, S, D] to latent video [B, C, F, H, W]
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
latents = self._unpack_latents(
|
||||
latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
|
||||
)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
# NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
|
||||
# init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
return init_latents
|
||||
|
||||
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
|
||||
"""
|
||||
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
|
||||
tensor.
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor`):
|
||||
Input latents to normalize
|
||||
reference_latents (`torch.Tensor`):
|
||||
The reference latents providing style statistics.
|
||||
factor (`float`):
|
||||
Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The transformed latent tensor
|
||||
"""
|
||||
result = latents.clone()
|
||||
|
||||
for i in range(latents.size(0)):
|
||||
for c in range(latents.size(1)):
|
||||
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
|
||||
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
|
||||
|
||||
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
|
||||
|
||||
result = torch.lerp(latents, result, factor)
|
||||
return result
|
||||
|
||||
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
|
||||
"""
|
||||
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
|
||||
smooth way using a sigmoid-based compression.
|
||||
|
||||
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
|
||||
when controlling dynamic behavior with a `compression` factor.
|
||||
|
||||
Args:
|
||||
latents (`torch.Tensor`):
Input latent tensor with arbitrary shape. Expected to be roughly in the [-1, 1] or [0, 1] range.
compression (`float`):
Compression strength in the range [0, 1].
- 0.0: No tone-mapping (identity transform)
- 1.0: Full compression effect

Returns:
`torch.Tensor`:
The tone-mapped latent tensor of the same shape as the input.
|
||||
"""
|
||||
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
|
||||
scale_factor = compression * 0.75
|
||||
abs_latents = torch.abs(latents)
|
||||
|
||||
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
|
||||
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
|
||||
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
|
||||
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
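# For intuition: at full compression (compression=1.0, so scale_factor=0.75) a latent with
# |x| = 1.0 hits sigmoid(0) = 0.5 and is scaled by 1 - 0.8 * 0.75 * 0.5 = 0.7; magnitudes much
# larger than 1 approach a scale of 0.4, while magnitudes much smaller than 1 keep a scale
# close to 1.0, so the largest values are compressed the most.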
|
||||
|
||||
filtered = latents * scales
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
|
||||
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if video is not None and latents is not None:
|
||||
raise ValueError("Only one of `video` or `latents` can be provided.")
|
||||
if video is None and latents is None:
|
||||
raise ValueError("One of `video` or `latents` has to be provided.")
|
||||
|
||||
if not (0 <= tone_map_compression_ratio <= 1):
|
||||
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
height: int = 512,
|
||||
width: int = 768,
|
||||
num_frames: int = 121,
|
||||
spatial_patch_size: int = 1,
|
||||
temporal_patch_size: int = 1,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
latents_normalized: bool = False,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
adain_factor: float = 0.0,
|
||||
tone_map_compression_ratio: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[PipelineImageInput]`, *optional*):
|
||||
The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be
|
||||
supplied.
|
||||
height (`int`, *optional*, defaults to `512`):
|
||||
The height in pixels of the input video (not the generated video, which will have a larger resolution).
|
||||
width (`int`, *optional*, defaults to `768`):
|
||||
The width in pixels of the input video (not the generated video, which will have a larger resolution).
|
||||
num_frames (`int`, *optional*, defaults to `121`):
|
||||
The number of frames in the input video.
|
||||
spatial_patch_size (`int`, *optional*, defaults to `1`):
|
||||
The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
|
||||
temporal_patch_size (`int`, *optional*, defaults to `1`):
|
||||
The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
|
||||
necessary.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
|
||||
patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
|
||||
latent_channels, latent_frames, latent_height, latent_width)`.
|
||||
latents_normalized (`bool`, *optional*, defaults to `False`):
|
||||
If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
|
||||
`True`, the `latents` will be denormalized before being supplied to the latent upsampler.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
adain_factor (`float`, *optional*, defaults to `0.0`):
|
||||
Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
|
||||
Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
|
||||
tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
|
||||
The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
|
||||
This is useful for regularizing high-variance latents or for conditioning outputs during generation.
|
||||
Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
|
||||
the full compression effect.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the upsampled video.
|
||||
"""
|
||||
|
||||
self.check_inputs(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
latents=latents,
|
||||
tone_map_compression_ratio=tone_map_compression_ratio,
|
||||
)
|
||||
|
||||
if video is not None:
|
||||
# Batched video input is not yet tested/supported. TODO: take a look later
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = latents.shape[0]
|
||||
device = self._execution_device
|
||||
|
||||
if video is not None:
|
||||
num_frames = len(video)
|
||||
if num_frames % self.vae_temporal_compression_ratio != 1:
|
||||
num_frames = (
|
||||
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
|
||||
)
|
||||
video = video[:num_frames]
|
||||
logger.warning(
|
||||
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
video = video.to(device=device, dtype=torch.float32)
|
||||
|
||||
latents_supplied = latents is not None
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
spatial_patch_size=spatial_patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
if latents_supplied and latents_normalized:
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(self.latent_upsampler.dtype)
|
||||
latents_upsampled = self.latent_upsampler(latents)
|
||||
|
||||
if adain_factor > 0.0:
|
||||
latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
|
||||
else:
|
||||
latents = latents_upsampled
|
||||
|
||||
if tone_map_compression_ratio > 0.0:
|
||||
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
else:
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
src/diffusers/pipelines/ltx2/pipeline_output.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from dataclasses import dataclass

import torch

from diffusers.utils import BaseOutput


@dataclass
class LTX2PipelineOutput(BaseOutput):
r"""
Output class for LTX 2.0 pipelines.

Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
audio (`torch.Tensor`, `np.ndarray`):
The generated audio corresponding to the video `frames`, as a NumPy array or Torch tensor of shape
`(batch_size, channels, samples)`.
"""

frames: torch.Tensor
audio: torch.Tensor
|
||||
src/diffusers/pipelines/ltx2/vocoder.py (new file, 159 lines)
@@ -0,0 +1,159 @@
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilations: Tuple[int, ...] = (1, 3, 5),
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
padding_mode: str = "same",
|
||||
):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
xt = conv1(xt)
|
||||
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
|
||||
xt = conv2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
hidden_channels: int = 1024,
|
||||
out_channels: int = 2,
|
||||
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
|
||||
upsample_factors: List[int] = [6, 5, 2, 2, 2],
|
||||
resnet_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
output_sampling_rate: int = 24000,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_upsample_layers = len(upsample_kernel_sizes)
|
||||
self.resnets_per_upsample = len(resnet_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.total_upsample_factor = math.prod(upsample_factors)
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
if self.num_upsample_layers != len(upsample_factors):
|
||||
raise ValueError(
|
||||
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
|
||||
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
|
||||
)
|
||||
|
||||
if self.resnets_per_upsample != len(resnet_dilations):
|
||||
raise ValueError(
|
||||
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
|
||||
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
|
||||
)
|
||||
|
||||
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsamplers = nn.ModuleList()
|
||||
self.resnets = nn.ModuleList()
|
||||
input_channels = hidden_channels
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
output_channels = input_channels // 2
|
||||
self.upsamplers.append(
|
||||
nn.ConvTranspose1d(
|
||||
input_channels, # hidden_channels // (2 ** i)
|
||||
output_channels, # hidden_channels // (2 ** (i + 1))
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
|
||||
self.resnets.append(
|
||||
ResBlock(
|
||||
output_channels,
|
||||
kernel_size,
|
||||
dilations=dilations,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
)
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
|
||||
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
|
||||
`True`.
|
||||
time_last (`bool`, *optional*, defaults to `False`):
|
||||
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
|
||||
# Ensure that the time/frame dimension is last
|
||||
if not time_last:
|
||||
hidden_states = hidden_states.transpose(2, 3)
|
||||
# Combine channels and frequency (mel bins) dimensions
|
||||
hidden_states = hidden_states.flatten(1, 2)
|
||||
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for i in range(self.num_upsample_layers):
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = self.upsamplers[i](hidden_states)
|
||||
|
||||
# Run all resnets in parallel on hidden_states
|
||||
start = i * self.resnets_per_upsample
|
||||
end = (i + 1) * self.resnets_per_upsample
|
||||
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
|
||||
|
||||
hidden_states = torch.mean(resnet_outputs, dim=0)
|
||||
|
||||
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
|
||||
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
|
||||
return hidden_states
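
As with the upsampler, a quick shape check clarifies what the vocoder expects and produces; the configuration below is a scaled-down stand-in (the defaults are `in_channels=128`, `hidden_channels=1024`), and the input layout follows the docstring above.

```python
import torch

from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder

# Hypothetical tiny vocoder for illustration; num_channels * num_mel_bins must equal in_channels.
vocoder = LTX2Vocoder(in_channels=8, hidden_channels=32)

mel = torch.randn(1, 1, 10, 8)  # (batch, num_channels, time, num_mel_bins), time_last=False
with torch.no_grad():
    waveform = vocoder(mel)
print(waveform.shape)  # torch.Size([1, 2, 2400]): 10 frames * prod(upsample_factors) = 2400 stereo samples
```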
|
||||
@@ -109,7 +109,7 @@ LIBRARIES = []
|
||||
for library in LOADABLE_CLASSES:
|
||||
LIBRARIES.append(library)
|
||||
|
||||
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device()]
|
||||
SUPPORTED_DEVICE_MAP = ["balanced"] + [get_device(), "cpu"]
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
@@ -462,8 +462,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
pipeline_is_sequentially_offloaded = any(
|
||||
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
||||
)
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` to remove the existing device map from the pipeline."
|
||||
@@ -1164,7 +1163,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
self._maybe_raise_error_if_group_offload_active(raise_error=True)
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
|
||||
@@ -1286,7 +1285,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
||||
self.remove_all_hooks()
|
||||
|
||||
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
||||
is_pipeline_device_mapped = self._is_pipeline_device_mapped()
|
||||
if is_pipeline_device_mapped:
|
||||
raise ValueError(
|
||||
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
|
||||
@@ -2171,6 +2170,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_pipeline_device_mapped(self):
|
||||
# We support passing a plain device string such as `device_map="cuda"`. This is helpful in case
# users want to pass `device_map="cpu"` when initializing a pipeline. This explicit declaration is desirable
# in limited-VRAM environments because quantized models often initialize directly on the accelerator.
|
||||
device_map = self.hf_device_map
|
||||
is_device_type_map = False
|
||||
if isinstance(device_map, str):
|
||||
try:
|
||||
torch.device(device_map)
|
||||
is_device_type_map = True
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
return not is_device_type_map and isinstance(device_map, dict) and len(device_map) > 1
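
The intent of the helper, sketched below under the assumption of a CUDA machine and a placeholder repo id, is to distinguish a plain device string from a real component-to-device mapping.

```python
import torch

from diffusers import DiffusionPipeline

# Hypothetical repo id; only the `device_map` values matter here.
pipe = DiffusionPipeline.from_pretrained("some/repo", device_map="cuda", torch_dtype=torch.float16)

# "cuda" parses as a torch.device, so the pipeline is not treated as device-mapped:
# `pipe.to(...)` and the CPU-offload helpers above remain usable.
assert not pipe._is_pipeline_device_mapped()

# A "balanced" map instead stores a multi-entry dict in `pipe.hf_device_map`, for which the
# helper returns True and those same methods raise until `reset_device_map()` is called.
```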
|
||||
|
||||
|
||||
class StableDiffusionMixin:
|
||||
r"""
|
||||
|
||||
@@ -143,7 +143,20 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the input sample by scaling it according to the EDM formulation.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample tensor to precondition.
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The scaled input sample.
|
||||
"""
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
@@ -155,7 +168,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigma.atan() / math.pi * 2
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
|
||||
def precondition_outputs(self, sample, model_output, sigma):
|
||||
def precondition_outputs(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
model_output: torch.Tensor,
|
||||
sigma: Union[float, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the model outputs according to the EDM formulation.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample tensor.
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The denoised sample computed by combining the skip connection and output scaling.
|
||||
"""
|
||||
sigma_data = self.config.sigma_data
|
||||
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
|
||||
|
||||
@@ -173,13 +206,13 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
||||
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
|
||||
need to scale the denoising model input depending on the current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The input sample tensor.
|
||||
timestep (`float` or `torch.Tensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
|
||||
Returns:
|
||||
@@ -242,8 +275,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.noise_sampler = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
|
||||
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
def _compute_karras_sigmas(
|
||||
self,
|
||||
ramp: torch.Tensor,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
|
||||
|
||||
Args:
|
||||
ramp (`torch.Tensor`):
|
||||
A tensor of values in [0, 1] representing the interpolation positions.
|
||||
sigma_min (`float`, *optional*):
|
||||
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
|
||||
sigma_max (`float`, *optional*):
|
||||
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The computed Karras sigma schedule.
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
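
The sigma computation itself lies outside this hunk's context; for reference, the schedule these docstrings describe interpolates in `sigma**(1/rho)` space. A minimal sketch, assuming the paper's default `rho = 7`:

```python
import torch

def karras_sigmas(ramp: torch.Tensor, sigma_min: float, sigma_max: float, rho: float = 7.0) -> torch.Tensor:
    # Interpolate between sigma_max and sigma_min in sigma**(1/rho) space, then raise back.
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho

sigmas = karras_sigmas(torch.linspace(0, 1, 10), sigma_min=0.002, sigma_max=80.0)
# sigmas decreases monotonically from sigma_max to sigma_min
```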
|
||||
|
||||
@@ -254,10 +306,27 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
|
||||
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
||||
"""Implementation closely follows k-diffusion.
|
||||
|
||||
def _compute_exponential_sigmas(
|
||||
self,
|
||||
ramp: torch.Tensor,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
|
||||
|
||||
Args:
|
||||
ramp (`torch.Tensor`):
|
||||
A tensor of values representing the interpolation positions.
|
||||
sigma_min (`float`, *optional*):
|
||||
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
|
||||
sigma_max (`float`, *optional*):
|
||||
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The computed exponential sigma schedule.
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
@@ -354,7 +423,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.Tensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -540,7 +612,10 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
[g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed()
|
||||
)
|
||||
self.noise_sampler = BrownianTreeNoiseSampler(
|
||||
model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed
|
||||
model_output,
|
||||
sigma_min=self.config.sigma_min,
|
||||
sigma_max=self.config.sigma_max,
|
||||
seed=seed,
|
||||
)
|
||||
noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to(
|
||||
model_output.device
|
||||
@@ -612,7 +687,18 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
"""
|
||||
Compute the input conditioning factor for the EDM formulation.
|
||||
|
||||
Args:
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`float` or `torch.Tensor`:
|
||||
The input conditioning factor `c_in`.
|
||||
"""
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
|
||||
@@ -175,13 +175,37 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self._begin_index = begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the input sample by scaling it according to the EDM formulation.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample tensor to precondition.
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The scaled input sample.
|
||||
"""
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_noise
|
||||
def precondition_noise(self, sigma):
|
||||
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the noise level by applying a logarithmic transformation.
|
||||
|
||||
Args:
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The sigma (noise level) value to precondition.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The preconditioned noise value computed as `0.25 * log(sigma)`.
|
||||
"""
|
||||
if not isinstance(sigma, torch.Tensor):
|
||||
sigma = torch.tensor([sigma])
|
||||
|
||||
@@ -190,7 +214,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return c_noise
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs
|
||||
def precondition_outputs(self, sample, model_output, sigma):
|
||||
def precondition_outputs(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
model_output: torch.Tensor,
|
||||
sigma: Union[float, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the model outputs according to the EDM formulation.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample tensor.
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The denoised sample computed by combining the skip connection and output scaling.
|
||||
"""
|
||||
sigma_data = self.config.sigma_data
|
||||
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
|
||||
|
||||
@@ -208,13 +252,13 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
||||
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
|
||||
need to scale the denoising model input depending on the current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The input sample tensor.
|
||||
timestep (`float` or `torch.Tensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
|
||||
Returns:
|
||||
@@ -274,8 +318,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas
|
||||
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
def _compute_karras_sigmas(
|
||||
self,
|
||||
ramp: torch.Tensor,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
|
||||
|
||||
Args:
|
||||
ramp (`torch.Tensor`):
|
||||
A tensor of values in [0, 1] representing the interpolation positions.
|
||||
sigma_min (`float`, *optional*):
|
||||
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
|
||||
sigma_max (`float`, *optional*):
|
||||
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The computed Karras sigma schedule.
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
|
||||
@@ -286,10 +349,27 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return sigmas
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas
|
||||
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
||||
"""Implementation closely follows k-diffusion.
|
||||
|
||||
def _compute_exponential_sigmas(
|
||||
self,
|
||||
ramp: torch.Tensor,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
|
||||
|
||||
Args:
|
||||
ramp (`torch.Tensor`):
|
||||
A tensor of values representing the interpolation positions.
|
||||
sigma_min (`float`, *optional*):
|
||||
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
|
||||
sigma_max (`float`, *optional*):
|
||||
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The computed exponential sigma schedule.
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
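# Hedged sketch of the exponential schedule (log-spaced sigmas), in the spirit of the
# k-diffusion `get_sigmas_exponential` helper referenced in the docstring above; the
# in-library `_compute_exponential_sigmas` may consume `ramp` slightly differently.
import math

import torch


def exponential_sigmas_sketch(num_steps: int, sigma_min: float = 0.002, sigma_max: float = 80.0) -> torch.Tensor:
    # Interpolate uniformly in log(sigma) and exponentiate, descending from sigma_max to sigma_min.
    return torch.linspace(math.log(sigma_max), math.log(sigma_min), num_steps).exp()


print(exponential_sigmas_sketch(5))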
@@ -433,7 +513,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
`torch.Tensor`:
|
||||
The sample tensor at the previous timestep.
|
||||
"""
|
||||
sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
|
||||
sigma_t, sigma_s = (
|
||||
self.sigmas[self.step_index + 1],
|
||||
self.sigmas[self.step_index],
|
||||
)
|
||||
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
|
||||
alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
|
||||
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
|
||||
@@ -684,7 +767,10 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if self.config.algorithm_type == "sde-dpmsolver++":
|
||||
noise = randn_tensor(
|
||||
model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
|
||||
model_output.shape,
|
||||
generator=generator,
|
||||
device=model_output.device,
|
||||
dtype=model_output.dtype,
|
||||
)
|
||||
else:
|
||||
noise = None
|
||||
@@ -757,7 +843,18 @@ class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
|
||||
return noisy_samples
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._get_conditioning_c_in
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
"""
|
||||
Compute the input conditioning factor for the EDM formulation.
|
||||
|
||||
Args:
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`float` or `torch.Tensor`:
|
||||
The input conditioning factor `c_in`.
|
||||
"""
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import List, Literal, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -57,29 +57,28 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
sigma_min (`float`, *optional*, defaults to 0.002):
|
||||
sigma_min (`float`, *optional*, defaults to `0.002`):
|
||||
Minimum noise magnitude in the sigma schedule. This was set to 0.002 in the EDM paper [1]; a reasonable
|
||||
range is [0, 10].
|
||||
sigma_max (`float`, *optional*, defaults to 80.0):
|
||||
sigma_max (`float`, *optional*, defaults to `80.0`):
|
||||
Maximum noise magnitude in the sigma schedule. This was set to 80.0 in the EDM paper [1]; a reasonable
|
||||
range is [0.2, 80.0].
|
||||
sigma_data (`float`, *optional*, defaults to 0.5):
|
||||
sigma_data (`float`, *optional*, defaults to `0.5`):
|
||||
The standard deviation of the data distribution. This is set to 0.5 in the EDM paper [1].
|
||||
sigma_schedule (`str`, *optional*, defaults to `karras`):
|
||||
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
|
||||
(https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
|
||||
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
|
||||
num_train_timesteps (`int`, defaults to 1000):
|
||||
sigma_schedule (`Literal["karras", "exponential"]`, *optional*, defaults to `"karras"`):
|
||||
Sigma schedule to compute the `sigmas`. By default, we use the schedule introduced in the EDM paper
|
||||
(https://huggingface.co/papers/2206.00364). The `"exponential"` schedule was incorporated in this model:
|
||||
https://huggingface.co/stabilityai/cosxl.
|
||||
num_train_timesteps (`int`, *optional*, defaults to `1000`):
|
||||
The number of diffusion steps to train the model.
|
||||
prediction_type (`str`, defaults to `epsilon`, *optional*):
|
||||
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
|
||||
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
|
||||
Video](https://huggingface.co/papers/2210.02303) paper).
|
||||
rho (`float`, *optional*, defaults to 7.0):
|
||||
prediction_type (`Literal["epsilon", "v_prediction"]`, *optional*, defaults to `"epsilon"`):
|
||||
Prediction type of the scheduler function. `"epsilon"` predicts the noise of the diffusion process, and
|
||||
`"v_prediction"` (see section 2.4 of [Imagen Video](https://huggingface.co/papers/2210.02303) paper).
|
||||
rho (`float`, *optional*, defaults to `7.0`):
|
||||
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
|
||||
final_sigmas_type (`str`, defaults to `"zero"`):
|
||||
final_sigmas_type (`Literal["zero", "sigma_min"]`, *optional*, defaults to `"zero"`):
|
||||
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
|
||||
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
|
||||
sigma is the same as the last sigma in the training schedule. If `"zero"`, the final sigma is set to 0.
|
||||
"""
|
||||
|
||||
_compatibles = []
|
||||
@@ -91,12 +90,12 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigma_min: float = 0.002,
|
||||
sigma_max: float = 80.0,
|
||||
sigma_data: float = 0.5,
|
||||
sigma_schedule: str = "karras",
|
||||
sigma_schedule: Literal["karras", "exponential"] = "karras",
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "epsilon",
|
||||
prediction_type: Literal["epsilon", "v_prediction"] = "epsilon",
|
||||
rho: float = 7.0,
|
||||
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
|
||||
):
|
||||
final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
|
||||
) -> None:
|
||||
if sigma_schedule not in ["karras", "exponential"]:
|
||||
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
|
||||
|
||||
@@ -131,26 +130,41 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
# standard deviation of the initial noise distribution
|
||||
def init_noise_sigma(self) -> float:
|
||||
"""
|
||||
Return the standard deviation of the initial noise distribution.
|
||||
|
||||
Returns:
|
||||
`float`:
|
||||
The initial noise sigma value computed as `(sigma_max**2 + 1) ** 0.5`.
|
||||
"""
|
||||
return (self.config.sigma_max**2 + 1) ** 0.5
|
||||
|
||||
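# Hedged usage sketch of the two hooks documented above: pipelines typically scale the
# initial Gaussian latents by `init_noise_sigma` and pass each denoising input through
# `scale_model_input`. The zero `model_output` is a stand-in for a real denoiser call.
import torch

from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()
scheduler.set_timesteps(num_inference_steps=30)

latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(latents, t)
    model_output = torch.zeros_like(model_input)  # placeholder for e.g. unet(model_input, t).sample
    latents = scheduler.step(model_output, t, latents).prev_sample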
@property
|
||||
def step_index(self):
|
||||
def step_index(self) -> Optional[int]:
|
||||
"""
|
||||
The index counter for current timestep. It will increase 1 after each scheduler step.
|
||||
Return the index counter for the current timestep. The index will increase by 1 after each scheduler step.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The current step index, or `None` if not yet initialized.
|
||||
"""
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
def begin_index(self) -> Optional[int]:
|
||||
"""
|
||||
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
|
||||
Return the index for the first timestep. This should be set from the pipeline with the `set_begin_index`
|
||||
method.
|
||||
|
||||
Returns:
|
||||
`int` or `None`:
|
||||
The begin index, or `None` if not yet set.
|
||||
"""
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
def set_begin_index(self, begin_index: int = 0) -> None:
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
@@ -160,12 +174,36 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def precondition_inputs(self, sample, sigma):
|
||||
def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the input sample by scaling it according to the EDM formulation.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample tensor to precondition.
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The scaled input sample.
|
||||
"""
|
||||
c_in = self._get_conditioning_c_in(sigma)
|
||||
scaled_sample = sample * c_in
|
||||
return scaled_sample
|
||||
|
||||
def precondition_noise(self, sigma):
|
||||
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the noise level by applying a logarithmic transformation.
|
||||
|
||||
Args:
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The sigma (noise level) value to precondition.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The preconditioned noise value computed as `0.25 * log(sigma)`.
|
||||
"""
|
||||
if not isinstance(sigma, torch.Tensor):
|
||||
sigma = torch.tensor([sigma])
|
||||
|
||||
@@ -173,7 +211,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return c_noise
|
||||
|
||||
def precondition_outputs(self, sample, model_output, sigma):
|
||||
def precondition_outputs(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
model_output: torch.Tensor,
|
||||
sigma: Union[float, torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Precondition the model outputs according to the EDM formulation.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample tensor.
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The denoised sample computed by combining the skip connection and output scaling.
|
||||
"""
|
||||
sigma_data = self.config.sigma_data
|
||||
c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
|
||||
|
||||
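# Hedged sketch of the EDM preconditioning factors behind `precondition_inputs`,
# `precondition_noise`, and `precondition_outputs`; `c_out` follows the Karras et al.
# (2022) paper and is not reproduced verbatim from this hunk.
import torch


def edm_precondition_sketch(sample: torch.Tensor, model_output: torch.Tensor, sigma: float, sigma_data: float = 0.5):
    c_in = 1 / (sigma**2 + sigma_data**2) ** 0.5                    # input scaling
    c_noise = 0.25 * torch.log(torch.tensor(sigma))                 # noise-level conditioning
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)             # skip-connection weight
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5  # output scaling (EDM paper)
    denoised = c_skip * sample + c_out * model_output
    return c_in, c_noise, denoised


c_in, c_noise, denoised = edm_precondition_sketch(torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8), sigma=1.0)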
@@ -190,13 +248,13 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
|
||||
"""
|
||||
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
|
||||
current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm.
|
||||
Scale the denoising model input to match the Euler algorithm. Ensures interchangeability with schedulers that
|
||||
need to scale the denoising model input depending on the current timestep.
|
||||
|
||||
Args:
|
||||
sample (`torch.Tensor`):
|
||||
The input sample.
|
||||
timestep (`int`, *optional*):
|
||||
The input sample tensor.
|
||||
timestep (`float` or `torch.Tensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
|
||||
Returns:
|
||||
@@ -214,19 +272,19 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
num_inference_steps (`int`, *optional*):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
|
||||
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
|
||||
sigmas (`torch.Tensor` or `List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process. If not defined, the default behavior when
|
||||
`num_inference_steps` is passed will be used.
|
||||
"""
|
||||
@@ -262,8 +320,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
|
||||
|
||||
# Taken from https://github.com/crowsonkb/k-diffusion/blob/686dbad0f39640ea25c8a8c6a6e56bb40eacefa2/k_diffusion/sampling.py#L17
|
||||
def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
||||
"""Constructs the noise schedule of Karras et al. (2022)."""
|
||||
def _compute_karras_sigmas(
|
||||
self,
|
||||
ramp: torch.Tensor,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Construct the noise schedule of [Karras et al. (2022)](https://huggingface.co/papers/2206.00364).
|
||||
|
||||
Args:
|
||||
ramp (`torch.Tensor`):
|
||||
A tensor of values in [0, 1] representing the interpolation positions.
|
||||
sigma_min (`float`, *optional*):
|
||||
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
|
||||
sigma_max (`float`, *optional*):
|
||||
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The computed Karras sigma schedule.
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
|
||||
@@ -273,10 +350,27 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
||||
return sigmas
|
||||
|
||||
def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor:
|
||||
"""Implementation closely follows k-diffusion.
|
||||
|
||||
def _compute_exponential_sigmas(
|
||||
self,
|
||||
ramp: torch.Tensor,
|
||||
sigma_min: Optional[float] = None,
|
||||
sigma_max: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Compute the exponential sigma schedule. Implementation closely follows k-diffusion:
|
||||
https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26
|
||||
|
||||
Args:
|
||||
ramp (`torch.Tensor`):
|
||||
A tensor of values representing the interpolation positions.
|
||||
sigma_min (`float`, *optional*):
|
||||
Minimum sigma value. If `None`, uses `self.config.sigma_min`.
|
||||
sigma_max (`float`, *optional*):
|
||||
Maximum sigma value. If `None`, uses `self.config.sigma_max`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The computed exponential sigma schedule.
|
||||
"""
|
||||
sigma_min = sigma_min or self.config.sigma_min
|
||||
sigma_max = sigma_max or self.config.sigma_max
|
||||
@@ -342,32 +436,38 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
pred_original_sample: Optional[torch.Tensor] = None,
|
||||
) -> Union[EDMEulerSchedulerOutput, Tuple]:
|
||||
) -> Union[EDMEulerSchedulerOutput, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
"""
|
||||
Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion
|
||||
process from the learned model outputs (most often the predicted noise).
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from learned diffusion model.
|
||||
timestep (`float`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`float` or `torch.Tensor`):
|
||||
The current discrete timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
s_churn (`float`):
|
||||
s_tmin (`float`):
|
||||
s_tmax (`float`):
|
||||
s_noise (`float`, defaults to 1.0):
|
||||
s_churn (`float`, *optional*, defaults to `0.0`):
|
||||
The amount of stochasticity to add at each step. Higher values add more noise.
|
||||
s_tmin (`float`, *optional*, defaults to `0.0`):
|
||||
The minimum sigma threshold below which no noise is added.
|
||||
s_tmax (`float`, *optional*, defaults to `float("inf")`):
|
||||
The maximum sigma threshold above which no noise is added.
|
||||
s_noise (`float`, *optional*, defaults to `1.0`):
|
||||
Scaling factor for noise added to the sample.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
return_dict (`bool`):
|
||||
Whether or not to return a [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or tuple.
|
||||
A random number generator for reproducibility.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or tuple.
|
||||
pred_original_sample (`torch.Tensor`, *optional*):
|
||||
The predicted denoised sample from a previous step. If provided, skips recomputation.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EDMEulerSchedulerOutput`] is
|
||||
returned, otherwise a tuple is returned where the first element is the sample tensor.
|
||||
[`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, an [`~schedulers.scheduling_edm_euler.EDMEulerSchedulerOutput`] is
|
||||
returned, otherwise a tuple is returned where the first element is the previous sample tensor and the
|
||||
second element is the predicted original sample tensor.
|
||||
"""
|
||||
|
||||
if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
|
||||
@@ -399,7 +499,10 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
if gamma > 0:
|
||||
noise = randn_tensor(
|
||||
model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
|
||||
model_output.shape,
|
||||
dtype=model_output.dtype,
|
||||
device=model_output.device,
|
||||
generator=generator,
|
||||
)
|
||||
eps = noise * s_noise
|
||||
sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
|
||||
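# Hedged sketch of how `s_churn`, `s_tmin`, `s_tmax`, and `s_noise` interact in the
# stochastic Euler step; the `gamma` formula follows the Karras et al. sampler and is
# an assumption, while the noise-injection line mirrors the hunk above.
import torch


def churn_step_sketch(sample, sigma, num_sigmas, s_churn=0.0, s_tmin=0.0, s_tmax=float("inf"), s_noise=1.0):
    gamma = min(s_churn / (num_sigmas - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0
    sigma_hat = sigma * (1 + gamma)
    if gamma > 0:
        # Add just enough noise to lift the sample from noise level `sigma` to `sigma_hat`.
        eps = torch.randn_like(sample) * s_noise
        sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5
    return sample, sigma_hat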
@@ -478,9 +581,20 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
|
||||
noisy_samples = original_samples + noise * sigma
|
||||
return noisy_samples
|
||||
|
||||
def _get_conditioning_c_in(self, sigma):
|
||||
def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
|
||||
"""
|
||||
Compute the input conditioning factor for the EDM formulation.
|
||||
|
||||
Args:
|
||||
sigma (`float` or `torch.Tensor`):
|
||||
The current sigma (noise level) value.
|
||||
|
||||
Returns:
|
||||
`float` or `torch.Tensor`:
|
||||
The input conditioning factor `c_in`.
|
||||
"""
|
||||
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
|
||||
return c_in
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
return self.config.num_train_timesteps
|
||||
|
||||
@@ -66,6 +66,7 @@ from .import_utils import (
|
||||
is_accelerate_version,
|
||||
is_aiter_available,
|
||||
is_aiter_version,
|
||||
is_av_available,
|
||||
is_better_profanity_available,
|
||||
is_bitsandbytes_available,
|
||||
is_bitsandbytes_version,
|
||||
|
||||
@@ -502,6 +502,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Video(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class AutoencoderKLLTXVideo(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -967,21 +997,6 @@ class HiDreamImageTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class GlmImageTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class HunyuanDiT2DControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -1162,6 +1177,21 @@ class LongCatImageTransformer2DModel(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LTX2VideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class LTXVideoTransformer3DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -1877,6 +1877,51 @@ class LongCatImagePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2ImageToVideoPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2LatentUpsamplePipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTX2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LTXConditionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -230,6 +230,7 @@ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_at
|
||||
_aiter_available, _aiter_version = _is_package_available("aiter")
|
||||
_kornia_available, _kornia_version = _is_package_available("kornia")
|
||||
_nvidia_modelopt_available, _nvidia_modelopt_version = _is_package_available("modelopt", get_dist_name=True)
|
||||
_av_available, _av_version = _is_package_available("av")
|
||||
|
||||
|
||||
def is_torch_available():
|
||||
@@ -420,6 +421,10 @@ def is_kornia_available():
|
||||
return _kornia_available
|
||||
|
||||
|
||||
def is_av_available():
|
||||
return _av_available
|
||||
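# Hedged sketch of the availability-guard pattern for the new optional `av` (PyAV)
# dependency; that `is_av_available` is re-exported from `diffusers.utils` is an
# assumption based on the import list above.
from diffusers.utils import is_av_available


def require_av() -> None:
    # Fail early with an actionable message when the optional dependency is missing.
    if not is_av_available():
        raise ImportError("Install PyAV (`pip install av`) to export LTX-2 videos with audio.")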
|
||||
|
||||
# docstyle-ignore
|
||||
FLAX_IMPORT_ERROR = """
|
||||
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Audio
|
||||
|
||||
from ...testing_utils import (
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
class AutoencoderKLLTX2AudioTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTX2Audio
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 2, # stereo,
|
||||
"output_channels": 2,
|
||||
"latent_channels": 4,
|
||||
"base_channels": 16,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"resolution": 16,
|
||||
"attn_resolutions": None,
|
||||
"num_res_blocks": 2,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"mel_bins": 16,
|
||||
"is_causal": True,
|
||||
"double_z": True,
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_channels = 2
|
||||
num_frames = 8
|
||||
num_mel_bins = 16
|
||||
|
||||
spectrogram = floats_tensor((batch_size, num_channels, num_frames, num_mel_bins)).to(torch_device)
|
||||
|
||||
input_dict = {"sample": spectrogram}
|
||||
return input_dict
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (2, 5, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (2, 5, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
# Overriding as output shape is not the same as input shape for LTX 2.0 audio VAE
|
||||
def test_output(self):
|
||||
super().test_output(expected_output_shape=(2, 2, 5, 16))
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTX2Audio does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
tests/models/autoencoders/test_models_autoencoder_ltx2_video.py (new file, 103 lines)
@@ -0,0 +1,103 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from diffusers import AutoencoderKLLTX2Video
|
||||
|
||||
from ...testing_utils import (
|
||||
enable_full_determinism,
|
||||
floats_tensor,
|
||||
torch_device,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
from .testing_utils import AutoencoderTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class AutoencoderKLLTX2VideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
|
||||
model_class = AutoencoderKLLTX2Video
|
||||
main_input_name = "sample"
|
||||
base_precision = 1e-2
|
||||
|
||||
def get_autoencoder_kl_ltx_video_config(self):
|
||||
return {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 8,
|
||||
"block_out_channels": (8, 8, 8, 8),
|
||||
"decoder_block_out_channels": (16, 32, 64),
|
||||
"layers_per_block": (1, 1, 1, 1, 1),
|
||||
"decoder_layers_per_block": (1, 1, 1, 1),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
# Full model uses `reflect` but this does not have deterministic backward implementation, so use `zeros`
|
||||
"decoder_spatial_padding_mode": "zeros",
|
||||
}
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 2
|
||||
num_frames = 9
|
||||
num_channels = 3
|
||||
sizes = (16, 16)
|
||||
|
||||
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
|
||||
|
||||
input_dict = {"sample": image}
|
||||
return input_dict
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (3, 9, 16, 16)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = self.get_autoencoder_kl_ltx_video_config()
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {
|
||||
"LTX2VideoEncoder3d",
|
||||
"LTX2VideoDecoder3d",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoMidBlock3d",
|
||||
"LTX2VideoUpBlock3d",
|
||||
}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
@unittest.skip("Unsupported test.")
|
||||
def test_outputs_equivalence(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("AutoencoderKLLTXVideo does not support `norm_num_groups` because it does not use GroupNorm.")
|
||||
def test_forward_with_norm_groups(self):
|
||||
pass
|
||||
tests/models/transformers/test_models_transformer_ltx2.py (new file, 222 lines)
@@ -0,0 +1,222 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import LTX2VideoTransformer3DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTX2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = LTX2VideoTransformer3DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
# Common
|
||||
batch_size = 2
|
||||
|
||||
# Video
|
||||
num_frames = 2
|
||||
num_channels = 4
|
||||
height = 16
|
||||
width = 16
|
||||
|
||||
# Audio
|
||||
audio_num_frames = 9
|
||||
audio_num_channels = 2
|
||||
num_mel_bins = 2
|
||||
|
||||
# Text
|
||||
embedding_dim = 16
|
||||
sequence_length = 16
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_frames * height * width, num_channels)).to(torch_device)
|
||||
audio_hidden_states = torch.randn((batch_size, audio_num_frames, audio_num_channels * num_mel_bins)).to(
|
||||
torch_device
|
||||
)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
audio_encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
encoder_attention_mask = torch.ones((batch_size, sequence_length)).bool().to(torch_device)
|
||||
timestep = torch.rand((batch_size,)).to(torch_device) * 1000
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"audio_hidden_states": audio_hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"audio_encoder_hidden_states": audio_encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"num_frames": num_frames,
|
||||
"height": height,
|
||||
"width": width,
|
||||
"audio_num_frames": audio_num_frames,
|
||||
"fps": 25.0,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (512, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (512, 4)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"audio_in_channels": 4,
|
||||
"audio_out_channels": 4,
|
||||
"audio_num_attention_heads": 2,
|
||||
"audio_attention_head_dim": 4,
|
||||
"audio_cross_attention_dim": 8,
|
||||
"num_layers": 2,
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"caption_channels": 16,
|
||||
"rope_double_precision": False,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"LTX2VideoTransformer3DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
# def test_ltx2_consistency(self, seed=0, dtype=torch.float32):
|
||||
# torch.manual_seed(seed)
|
||||
# init_dict, _ = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
# # Calculate dummy inputs in a custom manner to ensure compatibility with original code
|
||||
# batch_size = 2
|
||||
# num_frames = 9
|
||||
# latent_frames = 2
|
||||
# text_embedding_dim = 16
|
||||
# text_seq_len = 16
|
||||
# fps = 25.0
|
||||
# sampling_rate = 16000.0
|
||||
# hop_length = 160.0
|
||||
|
||||
# sigma = torch.rand((1,), generator=torch.manual_seed(seed), dtype=dtype, device="cpu") * 1000
|
||||
# timestep = (sigma * torch.ones((batch_size,), dtype=dtype, device="cpu")).to(device=torch_device)
|
||||
|
||||
# num_channels = 4
|
||||
# latent_height = 4
|
||||
# latent_width = 4
|
||||
# hidden_states = torch.randn(
|
||||
# (batch_size, num_channels, latent_frames, latent_height, latent_width),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
# # Patchify video latents (with patch_size (1, 1, 1))
|
||||
# hidden_states = hidden_states.reshape(batch_size, -1, latent_frames, 1, latent_height, 1, latent_width, 1)
|
||||
# hidden_states = hidden_states.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
# encoder_hidden_states = torch.randn(
|
||||
# (batch_size, text_seq_len, text_embedding_dim),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
|
||||
# audio_num_channels = 2
|
||||
# num_mel_bins = 2
|
||||
# latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
|
||||
# audio_hidden_states = torch.randn(
|
||||
# (batch_size, audio_num_channels, latent_length, num_mel_bins),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
# # Patchify audio latents
|
||||
# audio_hidden_states = audio_hidden_states.transpose(1, 2).flatten(2, 3)
|
||||
# audio_encoder_hidden_states = torch.randn(
|
||||
# (batch_size, text_seq_len, text_embedding_dim),
|
||||
# generator=torch.manual_seed(seed),
|
||||
# dtype=dtype,
|
||||
# device="cpu",
|
||||
# )
|
||||
|
||||
# inputs_dict = {
|
||||
# "hidden_states": hidden_states.to(device=torch_device),
|
||||
# "audio_hidden_states": audio_hidden_states.to(device=torch_device),
|
||||
# "encoder_hidden_states": encoder_hidden_states.to(device=torch_device),
|
||||
# "audio_encoder_hidden_states": audio_encoder_hidden_states.to(device=torch_device),
|
||||
# "timestep": timestep,
|
||||
# "num_frames": latent_frames,
|
||||
# "height": latent_height,
|
||||
# "width": latent_width,
|
||||
# "audio_num_frames": num_frames,
|
||||
# "fps": 25.0,
|
||||
# }
|
||||
|
||||
# model = self.model_class.from_pretrained(
|
||||
# "diffusers-internal-dev/dummy-ltx2",
|
||||
# subfolder="transformer",
|
||||
# device_map="cpu",
|
||||
# )
|
||||
# # torch.manual_seed(seed)
|
||||
# # model = self.model_class(**init_dict)
|
||||
# model.to(torch_device)
|
||||
# model.eval()
|
||||
|
||||
# with attention_backend("native"):
|
||||
# with torch.no_grad():
|
||||
# output = model(**inputs_dict)
|
||||
|
||||
# video_output, audio_output = output.to_tuple()
|
||||
|
||||
# self.assertIsNotNone(video_output)
|
||||
# self.assertIsNotNone(audio_output)
|
||||
|
||||
# # input & output have to have the same shape
|
||||
# video_expected_shape = (batch_size, latent_frames * latent_height * latent_width, num_channels)
|
||||
# self.assertEqual(video_output.shape, video_expected_shape, "Video input and output shapes do not match")
|
||||
# audio_expected_shape = (batch_size, latent_length, audio_num_channels * num_mel_bins)
|
||||
# self.assertEqual(audio_output.shape, audio_expected_shape, "Audio input and output shapes do not match")
|
||||
|
||||
# # Check against expected slice
|
||||
# # fmt: off
|
||||
# video_expected_slice = torch.tensor([0.4783, 1.6954, -1.2092, 0.1762, 0.7801, 1.2025, -1.4525, -0.2721, 0.3354, 1.9144, -1.5546, 0.0831, 0.4391, 1.7012, -1.7373, -0.2676])
|
||||
# audio_expected_slice = torch.tensor([-0.4236, 0.4750, 0.3901, -0.4339, -0.2782, 0.4357, 0.4526, -0.3927, -0.0980, 0.4870, 0.3964, -0.3169, -0.3974, 0.4408, 0.3809, -0.4692])
|
||||
# # fmt: on
|
||||
|
||||
# video_output_flat = video_output.cpu().flatten().float()
|
||||
# video_generated_slice = torch.cat([video_output_flat[:8], video_output_flat[-8:]])
|
||||
# self.assertTrue(torch.allclose(video_generated_slice, video_expected_slice, atol=1e-4))
|
||||
|
||||
# audio_output_flat = audio_output.cpu().flatten().float()
|
||||
# audio_generated_slice = torch.cat([audio_output_flat[:8], audio_output_flat[-8:]])
|
||||
# self.assertTrue(torch.allclose(audio_generated_slice, audio_expected_slice, atol=1e-4))
|
||||
|
||||
|
||||
class LTX2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = LTX2VideoTransformer3DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return LTX2TransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
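# Hedged sketch of the token layout these tests assume: video latents are patchified
# from (B, C, F, H, W) to (B, F*H*W, C) with patch size (1, 1, 1), and the audio latent
# length follows the formula in the commented-out consistency test above.
import torch

batch, channels, frames, height, width = 2, 4, 2, 16, 16
video_latents = torch.randn(batch, channels, frames, height, width)

# Patchify: move channels last for each (frame, row, column) position, then flatten to tokens.
tokens = video_latents.reshape(batch, channels, frames, 1, height, 1, width, 1)
tokens = tokens.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
assert tokens.shape == (batch, frames * height * width, channels)

# Audio latent length for `num_frames` of video at `fps`, assuming 4x temporal compression.
sampling_rate, hop_length, num_frames, fps = 16000.0, 160.0, 9, 25.0
audio_latent_length = int((sampling_rate / hop_length / 4) * (num_frames / fps))
assert audio_latent_length == 9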
tests/pipelines/ltx2/__init__.py (new file, 0 lines)
tests/pipelines/ltx2/test_ltx2.py (new file, 239 lines)
@@ -0,0 +1,239 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTX2Pipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"audio_latents",
|
||||
"output_type",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_attention_slicing = False
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
|
||||
|
||||
def get_dummy_components(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = LTX2VideoTransformer3DModel(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=16,
|
||||
audio_in_channels=4,
|
||||
audio_out_channels=4,
|
||||
audio_num_attention_heads=2,
|
||||
audio_attention_head_dim=4,
|
||||
audio_cross_attention_dim=8,
|
||||
num_layers=2,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
rope_double_precision=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
connectors = LTX2TextConnectors(
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
|
||||
video_connector_num_attention_heads=4,
|
||||
video_connector_attention_head_dim=8,
|
||||
video_connector_num_layers=1,
|
||||
video_connector_num_learnable_registers=None,
|
||||
audio_connector_num_attention_heads=4,
|
||||
audio_connector_attention_head_dim=8,
|
||||
audio_connector_num_layers=1,
|
||||
audio_connector_num_learnable_registers=None,
|
||||
connector_rope_base_seq_len=32,
|
||||
rope_theta=10000.0,
|
||||
rope_double_precision=False,
|
||||
causal_temporal_positioning=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTX2Video(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=4,
|
||||
block_out_channels=(8,),
|
||||
decoder_block_out_channels=(8,),
|
||||
layers_per_block=(1,),
|
||||
decoder_layers_per_block=(1, 1),
|
||||
spatio_temporal_scaling=(True,),
|
||||
decoder_spatio_temporal_scaling=(True,),
|
||||
decoder_inject_noise=(False, False),
|
||||
downsample_type=("spatial",),
|
||||
upsample_residual=(False,),
|
||||
upsample_factor=(1,),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
decoder_causal=False,
|
||||
)
|
||||
vae.use_framewise_encoding = False
|
||||
vae.use_framewise_decoding = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
audio_vae = AutoencoderKLLTX2Audio(
|
||||
base_channels=4,
|
||||
output_channels=2,
|
||||
ch_mult=(1,),
|
||||
num_res_blocks=1,
|
||||
attn_resolutions=None,
|
||||
in_channels=2,
|
||||
resolution=32,
|
||||
latent_channels=2,
|
||||
norm_type="pixel",
|
||||
causality_axis="height",
|
||||
dropout=0.0,
|
||||
mid_block_add_attention=False,
|
||||
sample_rate=16000,
|
||||
mel_hop_length=160,
|
||||
is_causal=True,
|
||||
mel_bins=8,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vocoder = LTX2Vocoder(
|
||||
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
|
||||
hidden_channels=32,
|
||||
out_channels=2,
|
||||
upsample_kernel_sizes=[4, 4],
|
||||
upsample_factors=[2, 2],
|
||||
resnet_kernel_sizes=[3],
|
||||
resnet_dilations=[[1, 3, 5]],
|
||||
leaky_relu_negative_slope=0.1,
|
||||
output_sampling_rate=16000,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"audio_vae": audio_vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
inputs = {
|
||||
"prompt": "a robot dancing",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 5,
|
||||
"frame_rate": 25.0,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = pipe(**inputs)
|
||||
video = output.frames
|
||||
audio = output.audio
|
||||
|
||||
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
|
||||
self.assertEqual(audio.shape[0], 1)
|
||||
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
# fmt: off
|
||||
expected_video_slice = torch.tensor(
|
||||
[
|
||||
0.4331, 0.6203, 0.3245, 0.7294, 0.4822, 0.5703, 0.2999, 0.7700, 0.4961, 0.4242, 0.4581, 0.4351, 0.1137, 0.4437, 0.6304, 0.3184
|
||||
]
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
video = video.flatten()
|
||||
audio = audio.flatten()
|
||||
generated_video_slice = torch.cat([video[:8], video[-8:]])
|
||||
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
|
||||
|
||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
|
||||
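# Hedged usage sketch mirroring the fast test above; the "Lightricks/LTX-2" checkpoint
# id, dtype, and generation settings are illustrative assumptions, not recommended
# defaults. The pipeline returns decoded frames plus a vocoder-generated audio waveform.
import torch

from diffusers import LTX2Pipeline

pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")
output = pipe(
    prompt="a robot dancing",
    negative_prompt="",
    num_inference_steps=30,
    guidance_scale=4.0,
    height=512,
    width=768,
    num_frames=121,
    frame_rate=25.0,
)
video, audio = output.frames, output.audio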
tests/pipelines/ltx2/test_ltx2_image2video.py (new file, 241 lines)
@@ -0,0 +1,241 @@
|
||||
# Copyright 2025 The HuggingFace Team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2TextConnectors
|
||||
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder
|
||||
|
||||
from ...testing_utils import enable_full_determinism
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = LTX2ImageToVideoPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image"})
|
||||
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
|
||||
required_optional_params = frozenset(
|
||||
[
|
||||
"num_inference_steps",
|
||||
"generator",
|
||||
"latents",
|
||||
"audio_latents",
|
||||
"return_dict",
|
||||
"callback_on_step_end",
|
||||
"callback_on_step_end_tensor_inputs",
|
||||
]
|
||||
)
|
||||
test_attention_slicing = False
|
||||
test_xformers_attention = False
|
||||
supports_dduf = False
|
||||
|
||||
base_text_encoder_ckpt_id = "hf-internal-testing/tiny-gemma3"
|
||||
|
||||
def get_dummy_components(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(self.base_text_encoder_ckpt_id)
|
||||
|
||||
torch.manual_seed(0)
|
||||
transformer = LTX2VideoTransformer3DModel(
|
||||
in_channels=4,
|
||||
out_channels=4,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=8,
|
||||
cross_attention_dim=16,
|
||||
audio_in_channels=4,
|
||||
audio_out_channels=4,
|
||||
audio_num_attention_heads=2,
|
||||
audio_attention_head_dim=4,
|
||||
audio_cross_attention_dim=8,
|
||||
num_layers=2,
|
||||
qk_norm="rms_norm_across_heads",
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
rope_double_precision=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
connectors = LTX2TextConnectors(
|
||||
caption_channels=text_encoder.config.text_config.hidden_size,
|
||||
text_proj_in_factor=text_encoder.config.text_config.num_hidden_layers + 1,
|
||||
video_connector_num_attention_heads=4,
|
||||
video_connector_attention_head_dim=8,
|
||||
video_connector_num_layers=1,
|
||||
video_connector_num_learnable_registers=None,
|
||||
audio_connector_num_attention_heads=4,
|
||||
audio_connector_attention_head_dim=8,
|
||||
audio_connector_num_layers=1,
|
||||
audio_connector_num_learnable_registers=None,
|
||||
connector_rope_base_seq_len=32,
|
||||
rope_theta=10000.0,
|
||||
rope_double_precision=False,
|
||||
causal_temporal_positioning=False,
|
||||
rope_type="split",
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKLLTX2Video(
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
latent_channels=4,
|
||||
block_out_channels=(8,),
|
||||
decoder_block_out_channels=(8,),
|
||||
layers_per_block=(1,),
|
||||
decoder_layers_per_block=(1, 1),
|
||||
spatio_temporal_scaling=(True,),
|
||||
decoder_spatio_temporal_scaling=(True,),
|
||||
decoder_inject_noise=(False, False),
|
||||
downsample_type=("spatial",),
|
||||
upsample_residual=(False,),
|
||||
upsample_factor=(1,),
|
||||
timestep_conditioning=False,
|
||||
patch_size=1,
|
||||
patch_size_t=1,
|
||||
encoder_causal=True,
|
||||
decoder_causal=False,
|
||||
)
|
||||
vae.use_framewise_encoding = False
|
||||
vae.use_framewise_decoding = False
|
||||
|
||||
torch.manual_seed(0)
|
||||
audio_vae = AutoencoderKLLTX2Audio(
|
||||
base_channels=4,
|
||||
output_channels=2,
|
||||
ch_mult=(1,),
|
||||
num_res_blocks=1,
|
||||
attn_resolutions=None,
|
||||
in_channels=2,
|
||||
resolution=32,
|
||||
latent_channels=2,
|
||||
norm_type="pixel",
|
||||
causality_axis="height",
|
||||
dropout=0.0,
|
||||
mid_block_add_attention=False,
|
||||
sample_rate=16000,
|
||||
mel_hop_length=160,
|
||||
is_causal=True,
|
||||
mel_bins=8,
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vocoder = LTX2Vocoder(
|
||||
in_channels=audio_vae.config.output_channels * audio_vae.config.mel_bins,
|
||||
hidden_channels=32,
|
||||
out_channels=2,
|
||||
upsample_kernel_sizes=[4, 4],
|
||||
upsample_factors=[2, 2],
|
||||
resnet_kernel_sizes=[3],
|
||||
resnet_dilations=[[1, 3, 5]],
|
||||
leaky_relu_negative_slope=0.1,
|
||||
output_sampling_rate=16000,
|
||||
)
|
||||
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
components = {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"audio_vae": audio_vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
|
||||
image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
|
||||
|
||||
inputs = {
|
||||
"image": image,
|
||||
"prompt": "a robot dancing",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"num_frames": 5,
|
||||
"frame_rate": 25.0,
|
||||
"max_sequence_length": 16,
|
||||
"output_type": "pt",
|
||||
}
|
||||
|
||||
return inputs
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
output = pipe(**inputs)
|
||||
video = output.frames
|
||||
audio = output.audio
|
||||
|
||||
self.assertEqual(video.shape, (1, 5, 3, 32, 32))
|
||||
self.assertEqual(audio.shape[0], 1)
|
||||
self.assertEqual(audio.shape[1], components["vocoder"].config.out_channels)
|
||||
|
||||
# fmt: off
|
||||
expected_video_slice = torch.tensor(
|
||||
[
|
||||
0.3573, 0.8382, 0.3581, 0.6114, 0.3682, 0.7969, 0.2552, 0.6399, 0.3113, 0.1497, 0.3249, 0.5395, 0.3498, 0.4526, 0.4536, 0.4555
|
||||
]
|
||||
)
|
||||
expected_audio_slice = torch.tensor(
|
||||
[
|
||||
0.0236, 0.0499, 0.1230, 0.1094, 0.1713, 0.1044, 0.1729, 0.1009, 0.0672, -0.0069, 0.0688, 0.0097, 0.0808, 0.1231, 0.0986, 0.0739
|
||||
]
|
||||
)
|
||||
# fmt: on
|
||||
|
||||
video = video.flatten()
|
||||
audio = audio.flatten()
|
||||
generated_video_slice = torch.cat([video[:8], video[-8:]])
|
||||
generated_audio_slice = torch.cat([audio[:8], audio[-8:]])
|
||||
|
||||
assert torch.allclose(expected_video_slice, generated_video_slice, atol=1e-4, rtol=1e-4)
|
||||
assert torch.allclose(expected_audio_slice, generated_audio_slice, atol=1e-4, rtol=1e-4)
|
||||
|
||||
def test_inference_batch_single_identical(self):
|
||||
self._test_inference_batch_single_identical(batch_size=2, expected_max_diff=2e-2)
|
||||