mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-17 05:55:59 +08:00
Compare commits
44 Commits
z-image-te
...
ltx2-3-pip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6ee66c9d66 | ||
|
|
e5aa719241 | ||
|
|
17b53f0866 | ||
|
|
93247a0a36 | ||
|
|
8a5807349d | ||
|
|
145e8e48c8 | ||
|
|
f1a812aa6a | ||
|
|
4bc1c59a67 | ||
|
|
89f8cc4384 | ||
|
|
764f7ede33 | ||
|
|
6188af2215 | ||
|
|
8d0f3e1ba8 | ||
|
|
094caf398f | ||
|
|
63b3c9f223 | ||
|
|
f78c3dae5a | ||
|
|
ab0e5b5cbb | ||
|
|
81c354d879 | ||
|
|
0a2c26d0a4 | ||
|
|
07c5ba8eee | ||
|
|
c0bb2ef21f | ||
|
|
d018534de1 | ||
|
|
652d363ded | ||
|
|
f875031d2f | ||
|
|
897aed72fa | ||
|
|
5056aa8203 | ||
|
|
de3f869b5c | ||
|
|
fbb50d964d | ||
|
|
e719d32c63 | ||
|
|
420628039a | ||
|
|
50da4df0ba | ||
|
|
c5e1fcc4b7 | ||
|
|
0528fde41d | ||
|
|
13292dde4d | ||
|
|
4dfcfeb3aa | ||
|
|
19004efc6c | ||
|
|
835bed615d | ||
|
|
4ff31688c7 | ||
|
|
5a44adb0b0 | ||
|
|
1e89cb3652 | ||
|
|
236eb8db64 | ||
|
|
cde67486cf | ||
|
|
f768f8dae8 | ||
|
|
e90b90a3cc | ||
|
|
6c7e720dd8 |
77
.ai/AGENTS.md
Normal file
77
.ai/AGENTS.md
Normal file
@@ -0,0 +1,77 @@
|
||||
# Diffusers — Agent Guide
|
||||
|
||||
## Coding style
|
||||
|
||||
Strive to write code as simple and explicit as possible.
|
||||
|
||||
- Minimize small helper/utility functions — inline the logic instead. A reader should be able to follow the full flow without jumping between functions.
|
||||
- No defensive code or unused code paths — do not add fallback paths, safety checks, or configuration options "just in case". When porting from a research repo, delete training-time code paths, experimental flags, and ablation branches entirely — only keep the inference path you are actually integrating.
|
||||
- Do not guess user intent and silently correct behavior. Make the expected inputs clear in the docstring, and raise a concise error for unsupported cases rather than adding complex fallback logic.
|
||||
|
||||
---
|
||||
|
||||
### Dependencies
|
||||
- No new mandatory dependency without discussion (e.g. `einops`)
|
||||
- Optional deps guarded with `is_X_available()` and a dummy in `utils/dummy_*.py`
|
||||
|
||||
## Code formatting
|
||||
- `make style` and `make fix-copies` should be run as the final step before opening a PR
|
||||
|
||||
### Copied Code
|
||||
- Many classes are kept in sync with a source via a `# Copied from ...` header comment
|
||||
- Do not edit a `# Copied from` block directly — run `make fix-copies` to propagate changes from the source
|
||||
- Remove the header to intentionally break the link
|
||||
|
||||
### Models
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid introducing graph breaks wherever possible, for better compatibility with `torch.compile`. For example, DO NOT arbitrarily insert NumPy operations in the forward implementations.
|
||||
- Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
### Pipeline
|
||||
- All pipelines must inherit from `DiffusionPipeline`. Consult implementations in `src/diffusers/pipelines` in case you need references.
|
||||
- DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
|
||||
|
||||
|
||||
### Tests
|
||||
- Slow tests gated with `@slow` and `RUN_SLOW=1`
|
||||
- All model-level tests must initially be written with the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes. Any additional tests should be added after discussion with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
|
||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -178,4 +178,8 @@ tags
|
||||
.ruff_cache
|
||||
|
||||
# wandb
|
||||
wandb
|
||||
wandb
|
||||
|
||||
# AI agent generated symlinks
|
||||
/AGENTS.md
|
||||
/CLAUDE.md
|
||||
13
Makefile
13
Makefile
@@ -1,4 +1,4 @@
|
||||
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples
|
||||
.PHONY: deps_table_update modified_only_fixup extra_style_checks quality style fixup fix-copies test test-examples codex claude clean-ai
|
||||
|
||||
# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
|
||||
export PYTHONPATH = src
|
||||
@@ -98,3 +98,14 @@ post-release:
|
||||
|
||||
post-patch:
|
||||
python utils/release.py --post_release --patch
|
||||
|
||||
# AI agent symlinks
|
||||
|
||||
codex:
|
||||
ln -snf .ai/AGENTS.md AGENTS.md
|
||||
|
||||
claude:
|
||||
ln -snf .ai/AGENTS.md CLAUDE.md
|
||||
|
||||
clean-ai:
|
||||
rm -f AGENTS.md CLAUDE.md
|
||||
|
||||
@@ -532,8 +532,6 @@
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
@@ -677,6 +675,8 @@
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/helios
|
||||
|
||||
@@ -17,3 +17,7 @@ A Transformer model for image-like data from [Flux2](https://hf.co/black-forest-
|
||||
## Flux2Transformer2DModel
|
||||
|
||||
[[autodoc]] Flux2Transformer2DModel
|
||||
|
||||
## Flux2Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.transformers.transformer_flux2.Flux2Transformer2DModelOutput
|
||||
|
||||
@@ -21,29 +21,31 @@
|
||||
> [!TIP]
|
||||
> Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers) to learn how to explore the tradeoff between scheduler speed and quality, and see the [reuse components across pipelines](../../using-diffusers/loading#reuse-a-pipeline) section to learn how to efficiently load the same components into multiple pipelines.
|
||||
|
||||
## Loading original format checkpoints
|
||||
|
||||
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
|
||||
## Basic usage
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
|
||||
from diffusers import Cosmos2_5_PredictBasePipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
|
||||
transformer = CosmosTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
model_id = "nvidia/Cosmos-Predict2.5-2B"
|
||||
pipe = Cosmos2_5_PredictBasePipeline.from_pretrained(
|
||||
model_id, revision="diffusers/base/post-trained", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
prompt = "As the red light shifts to green, the red bus at the intersection begins to move forward, its headlights cutting through the falling snow. The snowy tire tracks deepen as the vehicle inches ahead, casting fresh lines onto the slushy road. Around it, streetlights glow warmer, illuminating the drifting flakes and wet reflections on the asphalt. Other cars behind start to edge forward, their beams joining the scene. The stillness of the urban street transitions into motion as the quiet snowfall is punctuated by the slow advance of traffic through the frosty city corridor."
|
||||
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
).images[0]
|
||||
output.save("output.png")
|
||||
image=None,
|
||||
video=None,
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=93,
|
||||
generator=torch.Generator().manual_seed(1),
|
||||
).frames[0]
|
||||
export_to_video(output, "text2world.mp4", fps=16)
|
||||
```
|
||||
|
||||
## Cosmos2_5_TransferPipeline
|
||||
|
||||
@@ -41,5 +41,11 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
|
||||
## Flux2KleinPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Flux2KleinKVPipeline
|
||||
|
||||
[[autodoc]] Flux2KleinKVPipeline
|
||||
- all
|
||||
- __call__
|
||||
@@ -44,6 +44,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [ControlNet with Stable Diffusion XL](controlnet_sdxl) | text2image |
|
||||
| [ControlNet-XS](controlnetxs) | text2image |
|
||||
| [ControlNet-XS with Stable Diffusion XL](controlnetxs_sdxl) | text2image |
|
||||
| [Cosmos](cosmos) | text2video, video2video |
|
||||
| [Dance Diffusion](dance_diffusion) | unconditional audio generation |
|
||||
| [DDIM](ddim) | unconditional image generation |
|
||||
| [DDPM](ddpm) | unconditional image generation |
|
||||
|
||||
@@ -565,4 +565,16 @@ $ git push --set-upstream origin your-branch-for-syncing
|
||||
|
||||
### Style guide
|
||||
|
||||
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
|
||||
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).
|
||||
|
||||
|
||||
## Coding with AI agents
|
||||
|
||||
The repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks.
|
||||
|
||||
- **Source of truth** — edit `.ai/AGENTS.md` (and any future `.ai/skills/`)
|
||||
- **Don't edit** generated root-level `AGENTS.md` or `CLAUDE.md` — they are symlinks
|
||||
- Setup commands:
|
||||
- `make codex` — symlink for OpenAI Codex
|
||||
- `make claude` — symlink for Claude Code
|
||||
- `make clean-ai` — remove generated symlinks
|
||||
@@ -7,7 +7,7 @@ import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Processor
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
@@ -17,7 +17,7 @@ from diffusers import (
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder, LTX2VocoderWithBWE
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -44,6 +44,12 @@ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
**LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT,
|
||||
"audio_prompt_adaln_single": "audio_prompt_adaln",
|
||||
"prompt_adaln_single": "prompt_adaln",
|
||||
}
|
||||
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
@@ -72,6 +78,13 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_3_VIDEO_VAE_RENAME_DICT = {
|
||||
**LTX_2_0_VIDEO_VAE_RENAME_DICT,
|
||||
# Decoder extra blocks
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
@@ -84,10 +97,34 @@ LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
LTX_2_3_VOCODER_RENAME_DICT = {
|
||||
# Upsamplers ("ups" --> "upsamplers") are renamed via LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP rather than here, due to a name clash
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
"act_post": "act_out",
|
||||
"downsample.lowpass": "downsample",
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# LTX-2.3 uses per-modality embedding projections
|
||||
"text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in",
|
||||
"text_embedding_projection.video_aggregate_embed": "video_text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
@@ -129,23 +166,24 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_3_vocoder_upsamplers(key: str, state_dict: dict[str, Any]) -> None:
    """Rename one LTX-2.3 vocoder upsampler entry of ``state_dict`` in place.

    If ``key`` contains ``".weight"`` or ``".bias"`` and also contains
    ``".ups."``, the value stored under ``key`` is moved to the same key with
    every ``".ups."`` replaced by ``".upsamplers."``. Any other key is left
    untouched.

    Args:
        key: Parameter name to inspect; it must be present in ``state_dict``
            whenever a rename is triggered.
        state_dict: Mapping of parameter names to values, mutated in place.

    Raises:
        KeyError: If a rename is triggered but ``key`` is not in ``state_dict``.
    """
    # Skip anything that is not a weight or bias entry.
    if ".weight" not in key and ".bias" not in key:
        return

    if ".ups." in key:
        # Pop first so the old name is removed before the new one is inserted.
        state_dict[key.replace(".ups.", ".upsamplers.")] = state_dict.pop(key)
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
@@ -155,13 +193,19 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP = {
|
||||
".ups.": convert_ltx2_3_vocoder_upsamplers,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"text_embedding_projection",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
@@ -225,7 +269,7 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
@@ -238,6 +282,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"gated_attn": False,
|
||||
"cross_attn_mod": False,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
@@ -249,6 +295,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"audio_gated_attn": False,
|
||||
"audio_cross_attn_mod": False,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
@@ -263,10 +311,62 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
"use_prompt_embeddings": True,
|
||||
"perturbed_attn": False,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"gated_attn": True,
|
||||
"cross_attn_mod": True,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"audio_gated_attn": True,
|
||||
"audio_cross_attn_mod": True,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
"use_prompt_embeddings": False,
|
||||
"perturbed_attn": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -293,7 +393,7 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
@@ -301,20 +401,52 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"video_gated_attn": False,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"audio_gated_attn": False,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
"per_modality_projections": False,
|
||||
"proj_bias": False,
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 32,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 8,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"video_gated_attn": True,
|
||||
"audio_connector_num_attention_heads": 32,
|
||||
"audio_connector_attention_head_dim": 64,
|
||||
"audio_connector_num_layers": 8,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"audio_gated_attn": True,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
"per_modality_projections": True,
|
||||
"video_hidden_dim": 4096,
|
||||
"audio_hidden_dim": 2048,
|
||||
"proj_bias": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
@@ -416,7 +548,7 @@ def get_ltx2_video_vae_config(
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -435,6 +567,7 @@ def get_ltx2_video_vae_config(
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_type": ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
@@ -451,6 +584,44 @@ def get_ltx2_video_vae_config(
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 1024),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 512, 1024),
|
||||
"layers_per_block": (4, 6, 4, 2, 2),
|
||||
"decoder_layers_per_block": (4, 6, 4, 2, 2),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (2, 2, 1, 2),
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "zeros",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -485,7 +656,7 @@ def convert_ltx2_video_vae(
|
||||
def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
@@ -508,6 +679,31 @@ def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, A
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
}, # Same config as LTX-2.0
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -540,7 +736,7 @@ def convert_ltx2_audio_vae(original_state_dict: dict[str, Any], version: str) ->
|
||||
def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
@@ -549,21 +745,71 @@ def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"act_fn": "leaky_relu",
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"antialias": False,
|
||||
"final_act_fn": "tanh",
|
||||
"final_bias": True,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1536,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [11, 4, 4, 4, 4, 4],
|
||||
"upsample_factors": [5, 2, 2, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"act_fn": "snakebeta",
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"antialias": True,
|
||||
"antialias_ratio": 2,
|
||||
"antialias_kernel_size": 12,
|
||||
"final_act_fn": None,
|
||||
"final_bias": False,
|
||||
"bwe_in_channels": 128,
|
||||
"bwe_hidden_channels": 512,
|
||||
"bwe_out_channels": 2,
|
||||
"bwe_upsample_kernel_sizes": [12, 11, 4, 4, 4],
|
||||
"bwe_upsample_factors": [6, 5, 2, 2, 2],
|
||||
"bwe_resnet_kernel_sizes": [3, 7, 11],
|
||||
"bwe_resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"bwe_act_fn": "snakebeta",
|
||||
"bwe_leaky_relu_negative_slope": 0.1,
|
||||
"bwe_antialias": True,
|
||||
"bwe_antialias_ratio": 2,
|
||||
"bwe_antialias_kernel_size": 12,
|
||||
"bwe_final_act_fn": None,
|
||||
"bwe_final_bias": False,
|
||||
"filter_length": 512,
|
||||
"hop_length": 80,
|
||||
"window_length": 512,
|
||||
"num_mel_channels": 64,
|
||||
"input_sampling_rate": 16000,
|
||||
"output_sampling_rate": 48000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: dict[str, Any], version: str) -> dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
if version == "2.3":
|
||||
vocoder_cls = LTX2VocoderWithBWE
|
||||
else:
|
||||
vocoder_cls = LTX2Vocoder
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
vocoder = vocoder_cls.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
@@ -594,6 +840,18 @@ def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
"use_rational_resampler": True,
|
||||
}
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
"use_rational_resampler": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
@@ -651,13 +909,17 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: dict[str, Any], prefi
|
||||
model_state_dict = {}
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
model_state_dict[param_name.removeprefix(prefix)] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
connector_prefixes = ["text_embedding_projection"]
|
||||
for param_name, param in combined_ckpt.items():
|
||||
for prefix in connector_prefixes:
|
||||
if param_name.startswith(prefix):
|
||||
# Check to make sure we're not overwriting an existing key
|
||||
if param_name not in model_state_dict:
|
||||
model_state_dict[param_name] = combined_ckpt[param_name]
|
||||
|
||||
return model_state_dict
|
||||
|
||||
@@ -686,7 +948,7 @@ def get_args():
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0"],
|
||||
choices=["test", "2.0", "2.3"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
@@ -748,6 +1010,11 @@ def get_args():
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_processor",
|
||||
action="store_true",
|
||||
help="Whether to add a Gemma3Processor to the pipeline for prompt enhancement.",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
@@ -756,6 +1023,12 @@ def get_args():
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument(
|
||||
"--upsample_output_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path where converted upsampling pipeline should be saved",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -787,7 +1060,7 @@ def main(args):
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.text_encoder,
|
||||
args.connectors,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
@@ -852,7 +1125,12 @@ def main(args):
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
if args.add_processor:
|
||||
processor = Gemma3Processor.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
processor.save_pretrained(os.path.join(args.output_path, "processor"))
|
||||
|
||||
if args.latent_upsampler or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
@@ -866,14 +1144,26 @@ def main(args):
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
is_distilled_ckpt = "distilled" in args.combined_filename
|
||||
if is_distilled_ckpt:
|
||||
# Disable dynamic shifting and terminal shift so that distilled sigmas are used as-is
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=False,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=None,
|
||||
)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
@@ -891,10 +1181,12 @@ def main(args):
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
# As two diffusers pipelines cannot be in the same directory, save the upsampling pipeline to its own directory
|
||||
if args.upsample_output_path:
|
||||
upsample_output_path = args.upsample_output_path
|
||||
else:
|
||||
upsample_output_path = args.output_path
|
||||
pipe.save_pretrained(upsample_output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -510,6 +510,7 @@ else:
|
||||
"EasyAnimateControlPipeline",
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimatePipeline",
|
||||
"Flux2KleinKVPipeline",
|
||||
"Flux2KleinPipeline",
|
||||
"Flux2Pipeline",
|
||||
"FluxControlImg2ImgPipeline",
|
||||
@@ -1266,6 +1267,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateControlPipeline,
|
||||
EasyAnimateInpaintPipeline,
|
||||
EasyAnimatePipeline,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2KleinPipeline,
|
||||
Flux2Pipeline,
|
||||
FluxControlImg2ImgPipeline,
|
||||
|
||||
@@ -2156,6 +2156,9 @@ def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
# LTX-2.3
|
||||
"audio_prompt_adaln_single": "audio_prompt_adaln",
|
||||
"prompt_adaln_single": "prompt_adaln",
|
||||
}
|
||||
else:
|
||||
rename_dict = {"aggregate_embed": "text_proj_in"}
|
||||
@@ -2538,8 +2541,12 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
alpha_tensor = state_dict.pop(alpha_key, None)
|
||||
if alpha_tensor is None:
|
||||
return 1.0, 1.0
|
||||
scale = (
|
||||
alpha_tensor.item() / rank
|
||||
) # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
|
||||
@@ -60,6 +60,16 @@ class ContextParallelConfig:
|
||||
rotate_method (`str`, *optional*, defaults to `"allgather"`):
|
||||
Method to use for rotating key/value states across devices in ring attention. Currently, only `"allgather"`
|
||||
is supported.
|
||||
ulysses_anything (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable "Ulysses Anything" mode, which supports arbitrary sequence lengths and head counts that
|
||||
are not evenly divisible by `ulysses_degree`. When enabled, `ulysses_degree` must be greater than 1 and
|
||||
`ring_degree` must be 1.
|
||||
mesh (`torch.distributed.device_mesh.DeviceMesh`, *optional*):
|
||||
A custom device mesh to use for context parallelism. If provided, this mesh will be used instead of
|
||||
creating a new one. This is useful when combining context parallelism with other parallelism strategies
|
||||
(e.g., FSDP, tensor parallelism) that share the same device mesh. The mesh must have both "ring" and
|
||||
"ulysses" dimensions. Use size 1 for dimensions not being used (e.g., `mesh_shape=(2, 1, 4)` with
|
||||
`mesh_dim_names=("ring", "ulysses", "fsdp")` for ring attention only with FSDP).
|
||||
|
||||
"""
|
||||
|
||||
@@ -68,6 +78,7 @@ class ContextParallelConfig:
|
||||
convert_to_fp32: bool = True
|
||||
# TODO: support alltoall
|
||||
rotate_method: Literal["allgather", "alltoall"] = "allgather"
|
||||
mesh: torch.distributed.device_mesh.DeviceMesh | None = None
|
||||
# Whether to enable ulysses anything attention to support
|
||||
# any sequence lengths and any head numbers.
|
||||
ulysses_anything: bool = False
|
||||
@@ -124,7 +135,7 @@ class ContextParallelConfig:
|
||||
f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
|
||||
)
|
||||
|
||||
self._flattened_mesh = self._mesh._flatten()
|
||||
self._flattened_mesh = self._mesh["ring", "ulysses"]._flatten()
|
||||
self._ring_mesh = self._mesh["ring"]
|
||||
self._ulysses_mesh = self._mesh["ulysses"]
|
||||
self._ring_local_rank = self._ring_mesh.get_local_rank()
|
||||
|
||||
@@ -237,7 +237,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
|
||||
|
||||
|
||||
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
|
||||
class LTXVideoDownsampler3d(nn.Module):
|
||||
class LTX2VideoDownsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -285,10 +285,11 @@ class LTXVideoDownsampler3d(nn.Module):
|
||||
|
||||
|
||||
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
class LTX2VideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int | None = None,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
@@ -300,7 +301,8 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
self.residual = residual
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
out_channels = out_channels or in_channels
|
||||
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
|
||||
self.conv = LTX2VideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -408,7 +410,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "spatial":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
LTX2VideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(1, 2, 2),
|
||||
@@ -417,7 +419,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "temporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
LTX2VideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 1, 1),
|
||||
@@ -426,7 +428,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "spatiotemporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
LTX2VideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 2, 2),
|
||||
@@ -580,6 +582,7 @@ class LTX2VideoUpBlock3d(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
upsample_type: str = "spatiotemporal",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool = False,
|
||||
@@ -609,17 +612,38 @@ class LTX2VideoUpBlock3d(nn.Module):
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoUpsampler3d(
|
||||
out_channels * upscale_factor,
|
||||
self.upsamplers = nn.ModuleList()
|
||||
|
||||
if upsample_type == "spatial":
|
||||
self.upsamplers.append(
|
||||
LTX2VideoUpsampler3d(
|
||||
in_channels=out_channels * upscale_factor,
|
||||
stride=(1, 2, 2),
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
)
|
||||
elif upsample_type == "temporal":
|
||||
self.upsamplers.append(
|
||||
LTX2VideoUpsampler3d(
|
||||
in_channels=out_channels * upscale_factor,
|
||||
stride=(2, 1, 1),
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
)
|
||||
elif upsample_type == "spatiotemporal":
|
||||
self.upsamplers.append(
|
||||
LTX2VideoUpsampler3d(
|
||||
in_channels=out_channels * upscale_factor,
|
||||
stride=(2, 2, 2),
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
resnets = []
|
||||
for _ in range(num_layers):
|
||||
@@ -716,7 +740,7 @@ class LTX2VideoEncoder3d(nn.Module):
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
|
||||
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
|
||||
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
patch_size: int = 4,
|
||||
@@ -726,6 +750,9 @@ class LTX2VideoEncoder3d(nn.Module):
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
num_encoder_blocks = len(layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
@@ -860,19 +887,27 @@ class LTX2VideoDecoder3d(nn.Module):
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: tuple[int, ...] = (256, 512, 1024),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
|
||||
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
|
||||
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: tuple[bool, ...] = (False, False, False),
|
||||
inject_noise: bool | tuple[bool, ...] = (False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: tuple[bool, ...] = (True, True, True),
|
||||
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
|
||||
upsample_factor: tuple[bool, ...] = (2, 2, 2),
|
||||
spatial_padding_mode: str = "reflect",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
num_decoder_blocks = len(layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
|
||||
if isinstance(inject_noise, bool):
|
||||
inject_noise = (inject_noise,) * num_decoder_blocks
|
||||
if isinstance(upsample_residual, bool):
|
||||
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
@@ -917,6 +952,7 @@ class LTX2VideoDecoder3d(nn.Module):
|
||||
num_layers=layers_per_block[i + 1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
upsample_type=upsample_type[i],
|
||||
inject_noise=inject_noise[i + 1],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
upsample_residual=upsample_residual[i],
|
||||
@@ -1058,11 +1094,12 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024),
|
||||
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
|
||||
decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
|
||||
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
|
||||
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
|
||||
decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
|
||||
decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False),
|
||||
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
upsample_residual: tuple[bool, ...] = (True, True, True),
|
||||
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
|
||||
upsample_factor: tuple[int, ...] = (2, 2, 2),
|
||||
timestep_conditioning: bool = False,
|
||||
patch_size: int = 4,
|
||||
@@ -1077,6 +1114,16 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
temporal_compression_ratio: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
num_encoder_blocks = len(layers_per_block)
|
||||
num_decoder_blocks = len(decoder_layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
|
||||
if isinstance(decoder_spatio_temporal_scaling, bool):
|
||||
decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1)
|
||||
if isinstance(decoder_inject_noise, bool):
|
||||
decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks
|
||||
if isinstance(upsample_residual, bool):
|
||||
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
|
||||
|
||||
self.encoder = LTX2VideoEncoder3d(
|
||||
in_channels=in_channels,
|
||||
@@ -1098,6 +1145,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
upsample_type=upsample_type,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
|
||||
@@ -1567,7 +1567,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
mesh = None
|
||||
if config.context_parallel_config is not None:
|
||||
cp_config = config.context_parallel_config
|
||||
mesh = torch.distributed.device_mesh.init_device_mesh(
|
||||
mesh = cp_config.mesh or torch.distributed.device_mesh.init_device_mesh(
|
||||
device_type=device_type,
|
||||
mesh_shape=cp_config.mesh_shape,
|
||||
mesh_dim_names=cp_config.mesh_dim_names,
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -21,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils import BaseOutput, apply_lora_scale, logging
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -32,7 +33,6 @@ from ..embeddings import (
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import AdaLayerNormContinuous
|
||||
|
||||
@@ -40,6 +40,216 @@ from ..normalization import AdaLayerNormContinuous
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
@dataclass
class Flux2Transformer2DModelOutput(BaseOutput):
    """
    Result container returned by [`Flux2Transformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            Output hidden states, conditioned on the `encoder_hidden_states` input.
        kv_cache (`Flux2KVCache`, *optional*):
            KV cache populated with the reference image tokens. Only returned when `kv_cache_mode="extract"`;
            defaults to `None`.
    """

    # Annotations are kept as strings so they are not evaluated when the class is created.
    sample: "torch.Tensor"  # noqa: F821
    kv_cache: "Flux2KVCache | None" = None
|
||||
|
||||
|
||||
class Flux2KVLayerCache:
    """Reference-token key/value cache for a single attention layer of the Flux2 Klein KV model.

    Holds the post-RoPE K and V projections of the reference image tokens, captured during the first denoising step.
    Both tensors are laid out as (batch_size, num_ref_tokens, num_heads, head_dim).
    """

    def __init__(self):
        # Both slots stay `None` until `store` is called.
        self.k_ref: torch.Tensor | None = None
        self.v_ref: torch.Tensor | None = None

    def store(self, k_ref: torch.Tensor, v_ref: torch.Tensor):
        """Keep the given reference-token K/V tensors, replacing whatever was cached before."""
        self.k_ref, self.v_ref = k_ref, v_ref

    def get(self) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the cached `(k_ref, v_ref)` pair.

        Raises:
            RuntimeError: If no keys have been stored (`k_ref` is still `None`).
        """
        if self.k_ref is not None:
            return self.k_ref, self.v_ref
        raise RuntimeError("KV cache has not been populated yet.")

    def clear(self):
        """Reset both slots to `None` so the cache reads as unpopulated."""
        self.k_ref = self.v_ref = None
|
||||
|
||||
|
||||
class Flux2KVCache:
    """Bundle of per-layer reference-token KV caches for the whole transformer.

    Double-stream and single-stream transformer blocks each get their own list of `Flux2KVLayerCache` objects,
    indexed by layer position.
    """

    def __init__(self, num_double_layers: int, num_single_layers: int):
        # One independent layer cache per block, created empty.
        self.double_block_caches = [Flux2KVLayerCache() for _ in range(num_double_layers)]
        self.single_block_caches = [Flux2KVLayerCache() for _ in range(num_single_layers)]
        self.num_ref_tokens: int = 0

    def get_double(self, layer_idx: int) -> Flux2KVLayerCache:
        """Return the layer cache of the double-stream block at `layer_idx`."""
        return self.double_block_caches[layer_idx]

    def get_single(self, layer_idx: int) -> Flux2KVLayerCache:
        """Return the layer cache of the single-stream block at `layer_idx`."""
        return self.single_block_caches[layer_idx]

    def clear(self):
        """Clear every layer cache (double-stream first, then single-stream) and reset `num_ref_tokens` to 0."""
        for layer_cache in self.double_block_caches + self.single_block_caches:
            layer_cache.clear()
        self.num_ref_tokens = 0
|
||||
|
||||
|
||||
def _flux2_kv_causal_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    num_txt_tokens: int,
    num_ref_tokens: int,
    kv_cache: Flux2KVLayerCache | None = None,
    backend=None,
) -> torch.Tensor:
    """Attention for KV caching in which reference tokens attend only to themselves.

    All tensors follow the diffusers convention: (batch_size, seq_len, num_heads, head_dim).

    Three cases are handled:

    - Cached mode (`kv_cache` given): the sequence layout is [txt, img]. The cached reference K/V are inserted between
      the txt and img keys/values and every query attends to the result.
    - Plain mode (no cache and `num_ref_tokens == 0`): standard full attention.
    - Extract mode (no cache, `num_ref_tokens != 0`): the sequence layout is [txt, ref, img]. txt+img queries attend
      to all tokens, ref queries attend only to ref tokens. The output keeps the [txt, ref, img] order.
    """
    if kv_cache is not None:
        # Cached mode: splice the stored reference K/V in right after the text tokens.
        cached_k, cached_v = kv_cache.get()
        keys_with_ref = torch.cat([key[:, :num_txt_tokens], cached_k, key[:, num_txt_tokens:]], dim=1)
        values_with_ref = torch.cat([value[:, :num_txt_tokens], cached_v, value[:, num_txt_tokens:]], dim=1)
        return dispatch_attention_fn(query, keys_with_ref, values_with_ref, backend=backend)

    if num_ref_tokens == 0:
        # Nothing to isolate and nothing cached: standard full attention.
        return dispatch_attention_fn(query, key, value, backend=backend)

    # Extract mode: positions [txt_end, ref_end) along the sequence axis hold the reference tokens.
    txt_end = num_txt_tokens
    ref_end = num_txt_tokens + num_ref_tokens
    ref_span = slice(txt_end, ref_end)

    # txt+img queries see the full [txt, ref, img] key/value sequence.
    non_ref_query = torch.cat([query[:, :txt_end], query[:, ref_end:]], dim=1)
    full_key = torch.cat([key[:, :txt_end], key[:, ref_span], key[:, ref_end:]], dim=1)
    full_value = torch.cat([value[:, :txt_end], value[:, ref_span], value[:, ref_end:]], dim=1)
    non_ref_out = dispatch_attention_fn(non_ref_query, full_key, full_value, backend=backend)

    # ref queries only see ref keys/values.
    ref_out = dispatch_attention_fn(query[:, ref_span], key[:, ref_span], value[:, ref_span], backend=backend)

    # Reassemble in the original [txt, ref, img] order; the first `txt_end` rows of `non_ref_out` are the txt part.
    return torch.cat([non_ref_out[:, :txt_end], ref_out, non_ref_out[:, txt_end:]], dim=1)
|
||||
|
||||
|
||||
def _blend_mod_params(
|
||||
img_params: tuple[torch.Tensor, ...],
|
||||
ref_params: tuple[torch.Tensor, ...],
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
"""Blend modulation parameters so that the first `num_ref` positions use `ref_params`."""
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[rm.expand(B, num_ref, -1), im.expand(B, seq_len, -1)[:, num_ref:, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return tuple(blended)
|
||||
|
||||
|
||||
def _blend_double_block_mods(
    img_mod: torch.Tensor,
    ref_mod: torch.Tensor,
    num_ref: int,
    seq_len: int,
) -> torch.Tensor:
    """Blend double-block image-stream modulations for a [ref, img] sequence layout.

    Operates on raw modulation tensors (before `Flux2Modulation.split`). The result is again a raw tensor that is
    compatible with `Flux2Modulation.split(mod, 2)`.
    """
    if img_mod.ndim == 2:
        img_mod, ref_mod = img_mod.unsqueeze(1), ref_mod.unsqueeze(1)

    # Six chunks along the last dim, handled as two consecutive groups of three.
    img_parts = torch.chunk(img_mod, 6, dim=-1)
    ref_parts = torch.chunk(ref_mod, 6, dim=-1)

    merged = []
    for start in (0, 3):
        stop = start + 3
        merged.extend(_blend_mod_params(img_parts[start:stop], ref_parts[start:stop], num_ref, seq_len))
    # Concatenate in the original chunk order to restore the raw layout.
    return torch.cat(merged, dim=-1)
|
||||
|
||||
|
||||
def _blend_single_block_mods(
|
||||
single_mod: torch.Tensor,
|
||||
ref_mod: torch.Tensor,
|
||||
num_txt: int,
|
||||
num_ref: int,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
"""Blend single-block modulations for a [txt, ref, img] sequence layout.
|
||||
|
||||
Takes raw modulation tensors and returns a blended raw tensor compatible with `Flux2Modulation.split(mod, 1)`.
|
||||
"""
|
||||
if single_mod.ndim == 2:
|
||||
single_mod = single_mod.unsqueeze(1)
|
||||
ref_mod = ref_mod.unsqueeze(1)
|
||||
img_params = torch.chunk(single_mod, 3, dim=-1)
|
||||
ref_params = torch.chunk(ref_mod, 3, dim=-1)
|
||||
|
||||
blended = []
|
||||
for im, rm in zip(img_params, ref_params):
|
||||
if im.ndim == 2:
|
||||
im = im.unsqueeze(1)
|
||||
rm = rm.unsqueeze(1)
|
||||
B = im.shape[0]
|
||||
im_expanded = im.expand(B, seq_len, -1)
|
||||
rm_expanded = rm.expand(B, num_ref, -1)
|
||||
blended.append(
|
||||
torch.cat(
|
||||
[im_expanded[:, :num_txt, :], rm_expanded, im_expanded[:, num_txt + num_ref :, :]],
|
||||
dim=1,
|
||||
)
|
||||
)
|
||||
return torch.cat(blended, dim=-1)
|
||||
|
||||
|
||||
def _get_projections(attn: "Flux2Attention", hidden_states, encoder_hidden_states=None):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
@@ -181,9 +391,108 @@ class Flux2AttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVAttnProcessor:
    """
    Attention processor for Flux2 double-stream blocks that can cache the K/V of reference image tokens.

    - `kv_cache_mode="extract"` with `num_ref_tokens > 0`: the post-RoPE reference K/V are stored in `kv_cache` (when
      one is given) and causal attention is used (ref tokens self-attend only, txt+img attend to all).
    - `kv_cache_mode="cached"` with a `kv_cache`: the cached reference K/V are injected during attention.
    - Otherwise: behaves identically to `Flux2AttnProcessor`.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2Attention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        kv_cache: Flux2KVLayerCache | None = None,
        kv_cache_mode: str | None = None,
        num_ref_tokens: int = 0,
    ) -> torch.Tensor:
        """Run attention for one double-stream block.

        Returns:
            `hidden_states` alone when `encoder_hidden_states` is `None`, otherwise the tuple
            `(hidden_states, encoder_hidden_states)`.
        """
        query, key, value, encoder_query, encoder_key, encoder_value = _get_qkv_projections(
            attn, hidden_states, encoder_hidden_states
        )

        # Split the last dim into (heads, head_dim); Q and K are normalized per head.
        head_shape = (attn.heads, -1)
        query = attn.norm_q(query.unflatten(-1, head_shape))
        key = attn.norm_k(key.unflatten(-1, head_shape))
        value = value.unflatten(-1, head_shape)

        if attn.added_kv_proj_dim is not None:
            encoder_query = attn.norm_added_q(encoder_query.unflatten(-1, head_shape))
            encoder_key = attn.norm_added_k(encoder_key.unflatten(-1, head_shape))
            encoder_value = encoder_value.unflatten(-1, head_shape)

            # Text (encoder) tokens are placed in front of the image-stream tokens.
            query = torch.cat([encoder_query, query], dim=1)
            key = torch.cat([encoder_key, key], dim=1)
            value = torch.cat([encoder_value, value], dim=1)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        has_text = encoder_hidden_states is not None
        num_txt_tokens = encoder_hidden_states.shape[1] if has_text else 0

        if kv_cache_mode == "extract" and num_ref_tokens > 0:
            if kv_cache is not None:
                # Reference tokens sit right after the text tokens; keep copies of their K/V for later steps.
                ref_span = slice(num_txt_tokens, num_txt_tokens + num_ref_tokens)
                kv_cache.store(key[:, ref_span].clone(), value[:, ref_span].clone())
            hidden_states = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
            )
        elif kv_cache_mode == "cached" and kv_cache is not None:
            hidden_states = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
            )
        else:
            hidden_states = dispatch_attention_fn(
                query,
                key,
                value,
                attn_mask=attention_mask,
                backend=self._attention_backend,
                parallel_config=self._parallel_config,
            )

        # Merge (heads, head_dim) back into one feature dim and match the query dtype.
        hidden_states = hidden_states.flatten(2, 3).to(query.dtype)

        if not has_text:
            hidden_states = attn.to_out[0](hidden_states)
            return attn.to_out[1](hidden_states)

        # Undo the [txt, rest] concatenation before the output projections.
        encoder_hidden_states, hidden_states = hidden_states.split_with_sizes(
            [num_txt_tokens, hidden_states.shape[1] - num_txt_tokens], dim=1
        )
        encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
class Flux2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = Flux2AttnProcessor
|
||||
_available_processors = [Flux2AttnProcessor]
|
||||
_available_processors = [Flux2AttnProcessor, Flux2KVAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -312,6 +621,90 @@ class Flux2ParallelSelfAttnProcessor:
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Flux2KVParallelSelfAttnProcessor:
    """
    KV-cache-aware attention processor for the Flux2 single-stream (parallel attention + MLP) blocks.

    The attention path is selected from `kv_cache_mode`:

    - `"extract"` with `num_ref_tokens > 0`: the key/value slice `[num_txt_tokens, num_txt_tokens + num_ref_tokens)`
      along dim 1 is cloned into `kv_cache` (when one is given) and attention runs through
      `_flux2_kv_causal_attention`.
    - `"cached"` with a `kv_cache`: attention runs through `_flux2_kv_causal_attention` with the cache passed in.
    - otherwise: attention runs through `dispatch_attention_fn`; this is the only path that uses `attention_mask`
      and `_parallel_config`. With no KV arguments the processor behaves identically to
      `Flux2ParallelSelfAttnProcessor`.
    """

    _attention_backend = None
    _parallel_config = None

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(f"{self.__class__.__name__} requires PyTorch 2.0. Please upgrade your pytorch version.")

    def __call__(
        self,
        attn: "Flux2ParallelSelfAttention",
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        image_rotary_emb: torch.Tensor | None = None,
        kv_cache: Flux2KVLayerCache | None = None,
        kv_cache_mode: str | None = None,
        num_txt_tokens: int = 0,
        num_ref_tokens: int = 0,
    ) -> torch.Tensor:
        """
        Run the fused attention + MLP block, optionally storing or reusing reference-token K/V.

        Args:
            attn: Module providing `to_qkv_mlp_proj`, `norm_q`, `norm_k`, `mlp_act_fn`, `to_out` and the size
                attributes (`inner_dim`, `mlp_hidden_dim`, `mlp_mult_factor`, `heads`).
            hidden_states: Input sequence fed to the fused projection.
            attention_mask: Forwarded as `attn_mask` on the non-KV path only.
            image_rotary_emb: Rotary embedding applied to query and key when not `None`.
            kv_cache: Per-layer cache written to in "extract" mode and read in "cached" mode.
            kv_cache_mode: `"extract"`, `"cached"` or `None`.
            num_txt_tokens: Offset (along dim 1) at which the reference tokens start.
            num_ref_tokens: Number of reference tokens following the text tokens.

        Returns:
            The output of `attn.to_out` applied to the concatenated attention and MLP activations.
        """
        # Fused input projection, split into its attention (QKV) and MLP parts
        fused_proj = attn.to_qkv_mlp_proj(hidden_states)
        qkv_width = 3 * attn.inner_dim
        mlp_width = attn.mlp_hidden_dim * attn.mlp_mult_factor
        qkv, mlp_hidden_states = torch.split(fused_proj, [qkv_width, mlp_width], dim=-1)

        # Split the last dimension into (heads, per-head width) for each of Q, K, V
        query, key, value = (part.unflatten(-1, (attn.heads, -1)) for part in qkv.chunk(3, dim=-1))

        query = attn.norm_q(query)
        key = attn.norm_k(key)

        if image_rotary_emb is not None:
            query = apply_rotary_emb(query, image_rotary_emb, sequence_dim=1)
            key = apply_rotary_emb(key, image_rotary_emb, sequence_dim=1)

        if kv_cache_mode == "extract" and num_ref_tokens > 0:
            if kv_cache is not None:
                # Reference tokens sit right after the text tokens in the combined sequence
                ref_span = slice(num_txt_tokens, num_txt_tokens + num_ref_tokens)
                kv_cache.store(key[:, ref_span].clone(), value[:, ref_span].clone())
            attn_output = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, num_ref_tokens, backend=self._attention_backend
            )
        elif kv_cache_mode == "cached" and kv_cache is not None:
            # No reference tokens in the live sequence; they come from the cache instead
            attn_output = _flux2_kv_causal_attention(
                query, key, value, num_txt_tokens, 0, kv_cache=kv_cache, backend=self._attention_backend
            )
        else:
            attn_output = dispatch_attention_fn(
                query,
                key,
                value,
                attn_mask=attention_mask,
                backend=self._attention_backend,
                parallel_config=self._parallel_config,
            )

        # Merge dims 2 and 3 of the attention output and restore the query dtype
        attn_output = attn_output.flatten(2, 3).to(query.dtype)

        # Feedforward activation on the MLP half of the fused projection
        mlp_hidden_states = attn.mlp_act_fn(mlp_hidden_states)

        # Single output projection over the concatenated attention and MLP features
        return attn.to_out(torch.cat([attn_output, mlp_hidden_states], dim=-1))
|
||||
|
||||
|
||||
class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
Flux 2 parallel self-attention for the Flux 2 single-stream transformer blocks.
|
||||
@@ -322,7 +715,7 @@ class Flux2ParallelSelfAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = Flux2ParallelSelfAttnProcessor
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor]
|
||||
_available_processors = [Flux2ParallelSelfAttnProcessor, Flux2KVParallelSelfAttnProcessor]
|
||||
# Does not support QKV fusion as the QKV projections are always fused
|
||||
_supports_qkv_fusion = False
|
||||
|
||||
@@ -780,6 +1173,8 @@ class Flux2Transformer2DModel(
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
_skip_keys = ["kv_cache"]
|
||||
|
||||
@apply_lora_scale("joint_attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
@@ -791,19 +1186,21 @@ class Flux2Transformer2DModel(
|
||||
guidance: torch.Tensor = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor | Transformer2DModelOutput:
|
||||
kv_cache: "Flux2KVCache | None" = None,
|
||||
kv_cache_mode: str | None = None,
|
||||
num_ref_tokens: int = 0,
|
||||
ref_fixed_timestep: float = 0.0,
|
||||
) -> torch.Tensor | Flux2Transformer2DModelOutput:
|
||||
"""
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
The [`Flux2Transformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
|
||||
Input `hidden_states`.
|
||||
encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
timestep ( `torch.LongTensor`):
|
||||
timestep (`torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
@@ -811,13 +1208,23 @@ class Flux2Transformer2DModel(
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`Flux2Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
kv_cache (`Flux2KVCache`, *optional*):
|
||||
KV cache for reference image tokens. When `kv_cache_mode` is "extract", a new cache is created and
|
||||
returned. When "cached", the provided cache is used to inject ref K/V during attention.
|
||||
kv_cache_mode (`str`, *optional*):
|
||||
One of "extract" (first step with ref tokens) or "cached" (subsequent steps using cached ref K/V). When
|
||||
`None`, standard forward pass without KV caching.
|
||||
num_ref_tokens (`int`, defaults to `0`):
|
||||
Number of reference image tokens prepended to `hidden_states` (only used when
|
||||
`kv_cache_mode="extract"`).
|
||||
ref_fixed_timestep (`float`, defaults to `0.0`):
|
||||
Fixed timestep for reference token modulation (only used when `kv_cache_mode="extract"`).
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, a [`Flux2Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
`tuple` where the first element is the sample tensor. When `kv_cache_mode="extract"`, also returns the
|
||||
populated `Flux2KVCache`.
|
||||
"""
|
||||
# 0. Handle input arguments
|
||||
|
||||
num_txt_tokens = encoder_hidden_states.shape[1]
|
||||
|
||||
# 1. Calculate timestep embedding and modulation parameters
|
||||
@@ -832,13 +1239,33 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)
|
||||
|
||||
# KV extract mode: create cache and blend modulations for ref tokens
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
num_img_tokens = hidden_states.shape[1] # includes ref tokens
|
||||
|
||||
kv_cache = Flux2KVCache(
|
||||
num_double_layers=len(self.transformer_blocks),
|
||||
num_single_layers=len(self.single_transformer_blocks),
|
||||
)
|
||||
kv_cache.num_ref_tokens = num_ref_tokens
|
||||
|
||||
# Ref tokens use a fixed timestep for modulation
|
||||
ref_timestep = torch.full_like(timestep, ref_fixed_timestep * 1000)
|
||||
ref_temb = self.time_guidance_embed(ref_timestep, guidance)
|
||||
|
||||
ref_double_mod_img = self.double_stream_modulation_img(ref_temb)
|
||||
ref_single_mod = self.single_stream_modulation(ref_temb)
|
||||
|
||||
# Blend double block img modulation: [ref_mod, img_mod]
|
||||
double_stream_mod_img = _blend_double_block_mods(
|
||||
double_stream_mod_img, ref_double_mod_img, num_ref_tokens, num_img_tokens
|
||||
)
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
# 3. Calculate RoPE embeddings from image and text tokens
|
||||
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
|
||||
# text prompts of different lengths. Is this a use case we want to support?
|
||||
if img_ids.ndim == 3:
|
||||
img_ids = img_ids[0]
|
||||
if txt_ids.ndim == 3:
|
||||
@@ -851,8 +1278,29 @@ class Flux2Transformer2DModel(
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
)
|
||||
|
||||
# 4. Double Stream Transformer Blocks
|
||||
# 4. Build joint_attention_kwargs with KV cache info
|
||||
if kv_cache_mode == "extract":
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "extract",
|
||||
"num_ref_tokens": num_ref_tokens,
|
||||
}
|
||||
elif kv_cache_mode == "cached" and kv_cache is not None:
|
||||
kv_attn_kwargs = {
|
||||
**(joint_attention_kwargs or {}),
|
||||
"kv_cache": None,
|
||||
"kv_cache_mode": "cached",
|
||||
"num_ref_tokens": kv_cache.num_ref_tokens,
|
||||
}
|
||||
else:
|
||||
kv_attn_kwargs = joint_attention_kwargs
|
||||
|
||||
# 5. Double Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs["kv_cache"] = kv_cache.get_double(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -861,7 +1309,7 @@ class Flux2Transformer2DModel(
|
||||
double_stream_mod_img,
|
||||
double_stream_mod_txt,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
kv_attn_kwargs,
|
||||
)
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
@@ -870,13 +1318,30 @@ class Flux2Transformer2DModel(
|
||||
temb_mod_img=double_stream_mod_img,
|
||||
temb_mod_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=kv_attn_kwargs,
|
||||
)
|
||||
|
||||
# Concatenate text and image streams for single-block inference
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
# 5. Single Stream Transformer Blocks
|
||||
# Blend single block modulation for extract mode: [txt_mod, ref_mod, img_mod]
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
total_single_len = hidden_states.shape[1]
|
||||
single_stream_mod = _blend_single_block_mods(
|
||||
single_stream_mod, ref_single_mod, num_txt_tokens, num_ref_tokens, total_single_len
|
||||
)
|
||||
|
||||
# Build single-block KV kwargs (single blocks need num_txt_tokens)
|
||||
if kv_cache_mode is not None:
|
||||
kv_attn_kwargs_single = {**kv_attn_kwargs, "num_txt_tokens": num_txt_tokens}
|
||||
else:
|
||||
kv_attn_kwargs_single = kv_attn_kwargs
|
||||
|
||||
# 6. Single Stream Transformer Blocks
|
||||
for index_block, block in enumerate(self.single_transformer_blocks):
|
||||
if kv_cache_mode is not None and kv_cache is not None:
|
||||
kv_attn_kwargs_single["kv_cache"] = kv_cache.get_single(index_block)
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -884,7 +1349,7 @@ class Flux2Transformer2DModel(
|
||||
None,
|
||||
single_stream_mod,
|
||||
concat_rotary_emb,
|
||||
joint_attention_kwargs,
|
||||
kv_attn_kwargs_single,
|
||||
)
|
||||
else:
|
||||
hidden_states = block(
|
||||
@@ -892,16 +1357,25 @@ class Flux2Transformer2DModel(
|
||||
encoder_hidden_states=None,
|
||||
temb_mod=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
joint_attention_kwargs=kv_attn_kwargs_single,
|
||||
)
|
||||
# Remove text tokens from concatenated stream
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 6. Output layers
|
||||
# Remove text tokens (and ref tokens in extract mode) from concatenated stream
|
||||
if kv_cache_mode == "extract" and num_ref_tokens > 0:
|
||||
hidden_states = hidden_states[:, num_txt_tokens + num_ref_tokens :, ...]
|
||||
else:
|
||||
hidden_states = hidden_states[:, num_txt_tokens:, ...]
|
||||
|
||||
# 7. Output layers
|
||||
hidden_states = self.norm_out(hidden_states, temb)
|
||||
output = self.proj_out(hidden_states)
|
||||
|
||||
if kv_cache_mode == "extract":
|
||||
if not return_dict:
|
||||
return (output, kv_cache)
|
||||
return Flux2Transformer2DModelOutput(sample=output, kv_cache=kv_cache)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
return Flux2Transformer2DModelOutput(sample=output)
|
||||
|
||||
@@ -178,6 +178,10 @@ class LTX2AudioVideoAttnProcessor:
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
@@ -212,6 +216,112 @@ class LTX2AudioVideoAttnProcessor:
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2PerturbedAttnProcessor:
|
||||
r"""
|
||||
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
raise ValueError(
|
||||
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "LTX2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
if all_perturbed is None:
|
||||
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
|
||||
|
||||
if all_perturbed:
|
||||
# Skip attention, use the value projection value
|
||||
hidden_states = value
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if query_rotary_emb is not None:
|
||||
if attn.rope_type == "interleaved":
|
||||
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_interleaved_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
elif attn.rope_type == "split":
|
||||
query = apply_split_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_split_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if perturbation_mask is not None:
|
||||
value = value.flatten(2, 3)
|
||||
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
@@ -224,7 +334,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor]
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -240,6 +350,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
norm_eps: float = 1e-6,
|
||||
norm_elementwise_affine: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
apply_gated_attention: bool = False,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -266,6 +377,12 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if apply_gated_attention:
|
||||
# Per head gate values
|
||||
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
@@ -321,6 +438,10 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
audio_num_attention_heads: int,
|
||||
audio_attention_head_dim,
|
||||
audio_cross_attention_dim: int,
|
||||
video_gated_attn: bool = False,
|
||||
video_cross_attn_adaln: bool = False,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_adaln: bool = False,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = True,
|
||||
@@ -328,9 +449,16 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
eps: float = 1e-6,
|
||||
elementwise_affine: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
perturbed_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.perturbed_attn = perturbed_attn
|
||||
if perturbed_attn:
|
||||
attn_processor_cls = LTX2PerturbedAttnProcessor
|
||||
else:
|
||||
attn_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
|
||||
# 1. Self-Attention (video and audio)
|
||||
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.attn1 = LTX2Attention(
|
||||
@@ -343,6 +471,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -356,6 +486,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 2. Prompt Cross-Attention
|
||||
@@ -370,6 +502,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -383,6 +517,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
@@ -398,6 +534,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
|
||||
@@ -412,6 +550,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 4. Feedforward layers
|
||||
@@ -422,14 +562,36 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
|
||||
|
||||
# 5. Per-Layer Modulation Parameters
|
||||
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
|
||||
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
|
||||
# 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q
|
||||
self.video_cross_attn_adaln = video_cross_attn_adaln
|
||||
self.audio_cross_attn_adaln = audio_cross_attn_adaln
|
||||
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
|
||||
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
|
||||
|
||||
# Prompt cross-attn (attn2) additional modulation params
|
||||
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
|
||||
if self.cross_attn_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim))
|
||||
|
||||
# Per-layer a2v, v2a Cross-Attention mod params
|
||||
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
|
||||
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
|
||||
|
||||
@staticmethod
def get_mod_params(
    scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
) -> tuple[torch.Tensor, ...]:
    """
    Add a modulation table to a timestep embedding and split the sum into one tensor per modulation parameter.

    Args:
        scale_shift_table: Table whose first dimension indexes the modulation parameters.
        temb: Embedding that is reshaped to `(batch_size, temb.shape[1], num_params, -1)` before the addition.
        batch_size: Leading dimension used for the reshape of `temb`.

    Returns:
        A tuple with `scale_shift_table.shape[0]` tensors, obtained by unbinding dim 2 of the broadcast sum.
    """
    param_count = scale_shift_table.shape[0]
    # Two leading singleton dims let the table broadcast over the first two dims of `temb`
    table = scale_shift_table[None, None].to(temb.device)
    per_token = temb.reshape(batch_size, temb.shape[1], param_count, -1)
    return (table + per_token).unbind(dim=2)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -442,143 +604,181 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
temb_ca_audio_scale_shift: torch.Tensor,
|
||||
temb_ca_gate: torch.Tensor,
|
||||
temb_ca_audio_gate: torch.Tensor,
|
||||
temb_prompt: torch.Tensor | None = None,
|
||||
temb_prompt_audio: torch.Tensor | None = None,
|
||||
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
self_attention_mask: torch.Tensor | None = None,
|
||||
audio_self_attention_mask: torch.Tensor | None = None,
|
||||
a2v_cross_attention_mask: torch.Tensor | None = None,
|
||||
v2a_cross_attention_mask: torch.Tensor | None = None,
|
||||
use_a2v_cross_attention: bool = True,
|
||||
use_v2a_cross_attention: bool = True,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Video and Audio Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
# 1.1. Video Self-Attention
|
||||
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
|
||||
if self.video_cross_attn_adaln:
|
||||
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.size(1), num_ada_params, -1
|
||||
)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
query_rotary_emb=video_rotary_emb,
|
||||
)
|
||||
video_self_attn_args = {
|
||||
"hidden_states": norm_hidden_states,
|
||||
"encoder_hidden_states": None,
|
||||
"query_rotary_emb": video_rotary_emb,
|
||||
"attention_mask": self_attention_mask,
|
||||
}
|
||||
if self.perturbed_attn:
|
||||
video_self_attn_args["perturbation_mask"] = perturbation_mask
|
||||
video_self_attn_args["all_perturbed"] = all_perturbed
|
||||
|
||||
attn_hidden_states = self.attn1(**video_self_attn_args)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
|
||||
|
||||
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
|
||||
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
|
||||
batch_size, temb_audio.size(1), num_audio_ada_params, -1
|
||||
)
|
||||
# 1.2. Audio Self-Attention
|
||||
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
|
||||
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
|
||||
audio_ada_values.unbind(dim=2)
|
||||
audio_ada_params[:6]
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
|
||||
|
||||
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn1(
|
||||
hidden_states=norm_audio_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
query_rotary_emb=audio_rotary_emb,
|
||||
)
|
||||
audio_self_attn_args = {
|
||||
"hidden_states": norm_audio_hidden_states,
|
||||
"encoder_hidden_states": None,
|
||||
"query_rotary_emb": audio_rotary_emb,
|
||||
"attention_mask": audio_self_attention_mask,
|
||||
}
|
||||
if self.perturbed_attn:
|
||||
audio_self_attn_args["perturbation_mask"] = perturbation_mask
|
||||
audio_self_attn_args["all_perturbed"] = all_perturbed
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args)
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
|
||||
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
|
||||
if self.cross_attn_adaln:
|
||||
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
|
||||
shift_text_kv, scale_text_kv = video_prompt_ada_params
|
||||
|
||||
audio_prompt_ada_params = self.get_mod_params(
|
||||
self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size
|
||||
)
|
||||
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
|
||||
|
||||
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text)
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
if self.video_cross_attn_adaln:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.video_cross_attn_adaln:
|
||||
attn_hidden_states = attn_hidden_states * gate_text_q
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
# 2.2. Audio-Text Cross-Attention
|
||||
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
|
||||
if self.audio_cross_attn_adaln:
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn2(
|
||||
norm_audio_hidden_states,
|
||||
encoder_hidden_states=audio_encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=audio_encoder_attention_mask,
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
norm_hidden_states = self.audio_to_video_norm(hidden_states)
|
||||
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
|
||||
if use_a2v_cross_attention or use_v2a_cross_attention:
|
||||
norm_hidden_states = self.audio_to_video_norm(hidden_states)
|
||||
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
|
||||
|
||||
# Combine global and per-layer cross attention modulation parameters
|
||||
# Video
|
||||
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
# 3.1. Combine global and per-layer cross attention modulation parameters
|
||||
# Video
|
||||
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
video_ca_scale_shift_table = (
|
||||
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
|
||||
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
video_ca_gate = (
|
||||
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
|
||||
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
|
||||
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
|
||||
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
|
||||
a2v_gate = video_ca_gate[0].squeeze(2)
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
|
||||
a2v_gate = video_ca_gate_param[0].squeeze(2)
|
||||
|
||||
# Audio
|
||||
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
# Audio
|
||||
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
audio_ca_scale_shift_table = (
|
||||
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
|
||||
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
audio_ca_gate = (
|
||||
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
|
||||
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
audio_ca_ada_params = self.get_mod_params(
|
||||
audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size
|
||||
)
|
||||
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
|
||||
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
|
||||
v2a_gate = audio_ca_gate[0].squeeze(2)
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
|
||||
v2a_gate = audio_ca_gate_param[0].squeeze(2)
|
||||
|
||||
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_a2v_ca_scale.squeeze(2)
|
||||
) + audio_a2v_ca_shift.squeeze(2)
|
||||
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
if use_a2v_cross_attention:
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_a2v_ca_scale.squeeze(2)
|
||||
) + video_a2v_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_a2v_ca_scale.squeeze(2)
|
||||
) + audio_a2v_ca_shift.squeeze(2)
|
||||
|
||||
a2v_attn_hidden_states = self.audio_to_video_attn(
|
||||
mod_norm_hidden_states,
|
||||
encoder_hidden_states=mod_norm_audio_hidden_states,
|
||||
query_rotary_emb=ca_video_rotary_emb,
|
||||
key_rotary_emb=ca_audio_rotary_emb,
|
||||
attention_mask=a2v_cross_attention_mask,
|
||||
)
|
||||
a2v_attn_hidden_states = self.audio_to_video_attn(
|
||||
mod_norm_hidden_states,
|
||||
encoder_hidden_states=mod_norm_audio_hidden_states,
|
||||
query_rotary_emb=ca_video_rotary_emb,
|
||||
key_rotary_emb=ca_audio_rotary_emb,
|
||||
attention_mask=a2v_cross_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
|
||||
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
|
||||
|
||||
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_v2a_ca_scale.squeeze(2)
|
||||
) + audio_v2a_ca_shift.squeeze(2)
|
||||
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
if use_v2a_cross_attention:
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_v2a_ca_scale.squeeze(2)
|
||||
) + video_v2a_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_v2a_ca_scale.squeeze(2)
|
||||
) + audio_v2a_ca_shift.squeeze(2)
|
||||
|
||||
v2a_attn_hidden_states = self.video_to_audio_attn(
|
||||
mod_norm_audio_hidden_states,
|
||||
encoder_hidden_states=mod_norm_hidden_states,
|
||||
query_rotary_emb=ca_audio_rotary_emb,
|
||||
key_rotary_emb=ca_video_rotary_emb,
|
||||
attention_mask=v2a_cross_attention_mask,
|
||||
)
|
||||
v2a_attn_hidden_states = self.video_to_audio_attn(
|
||||
mod_norm_audio_hidden_states,
|
||||
encoder_hidden_states=mod_norm_hidden_states,
|
||||
query_rotary_emb=ca_audio_rotary_emb,
|
||||
key_rotary_emb=ca_video_rotary_emb,
|
||||
attention_mask=v2a_cross_attention_mask,
|
||||
)
|
||||
|
||||
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
|
||||
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
|
||||
|
||||
# 4. Feedforward
|
||||
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
@@ -918,6 +1118,8 @@ class LTX2VideoTransformer3DModel(
|
||||
pos_embed_max_pos: int = 20,
|
||||
base_height: int = 2048,
|
||||
base_width: int = 2048,
|
||||
gated_attn: bool = False,
|
||||
cross_attn_mod: bool = False,
|
||||
audio_in_channels: int = 128, # Audio Arguments
|
||||
audio_out_channels: int | None = 128,
|
||||
audio_patch_size: int = 1,
|
||||
@@ -929,6 +1131,8 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_pos_embed_max_pos: int = 20,
|
||||
audio_sampling_rate: int = 16000,
|
||||
audio_hop_length: int = 160,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_mod: bool = False,
|
||||
num_layers: int = 48, # Shared arguments
|
||||
activation_fn: str = "gelu-approximate",
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
@@ -943,6 +1147,8 @@ class LTX2VideoTransformer3DModel(
|
||||
timestep_scale_multiplier: int = 1000,
|
||||
cross_attn_timestep_scale_multiplier: int = 1000,
|
||||
rope_type: str = "interleaved",
|
||||
use_prompt_embeddings=True,
|
||||
perturbed_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -956,17 +1162,25 @@ class LTX2VideoTransformer3DModel(
|
||||
self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)
|
||||
|
||||
# 2. Prompt embeddings
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=audio_inner_dim
|
||||
)
|
||||
if use_prompt_embeddings:
|
||||
# LTX-2.0; LTX-2.3 uses per-modality feature projections in the connector instead
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=audio_inner_dim
|
||||
)
|
||||
|
||||
# 3. Timestep Modulation Params and Embedding
|
||||
self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod # used by LTX-2.3
|
||||
|
||||
# 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding
|
||||
# time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters
|
||||
self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False)
|
||||
video_time_emb_mod_params = 9 if cross_attn_mod else 6
|
||||
audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6
|
||||
self.time_embed = LTX2AdaLayerNormSingle(
|
||||
inner_dim, num_mod_params=video_time_emb_mod_params, use_additional_conditions=False
|
||||
)
|
||||
self.audio_time_embed = LTX2AdaLayerNormSingle(
|
||||
audio_inner_dim, num_mod_params=6, use_additional_conditions=False
|
||||
audio_inner_dim, num_mod_params=audio_time_emb_mod_params, use_additional_conditions=False
|
||||
)
|
||||
|
||||
# 3.2. Global Cross Attention Modulation Parameters
|
||||
@@ -995,6 +1209,13 @@ class LTX2VideoTransformer3DModel(
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5)
|
||||
|
||||
# 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3)
|
||||
if self.prompt_modulation:
|
||||
self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False)
|
||||
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
|
||||
audio_inner_dim, num_mod_params=2, use_additional_conditions=False
|
||||
)
|
||||
|
||||
# 4. Rotary Positional Embeddings (RoPE)
|
||||
# Self-Attention
|
||||
self.rope = LTX2AudioVideoRotaryPosEmbed(
|
||||
@@ -1071,6 +1292,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_attention_heads=audio_num_attention_heads,
|
||||
audio_attention_head_dim=audio_attention_head_dim,
|
||||
audio_cross_attention_dim=audio_cross_attention_dim,
|
||||
video_gated_attn=gated_attn,
|
||||
video_cross_attn_adaln=cross_attn_mod,
|
||||
audio_gated_attn=audio_gated_attn,
|
||||
audio_cross_attn_adaln=audio_cross_attn_mod,
|
||||
qk_norm=qk_norm,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
@@ -1078,6 +1303,7 @@ class LTX2VideoTransformer3DModel(
|
||||
eps=norm_eps,
|
||||
elementwise_affine=norm_elementwise_affine,
|
||||
rope_type=rope_type,
|
||||
perturbed_attn=perturbed_attn,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -1101,8 +1327,12 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
audio_timestep: torch.LongTensor | None = None,
|
||||
sigma: torch.Tensor | None = None,
|
||||
audio_sigma: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
self_attention_mask: torch.Tensor | None = None,
|
||||
audio_self_attention_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
@@ -1110,6 +1340,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_frames: int | None = None,
|
||||
video_coords: torch.Tensor | None = None,
|
||||
audio_coords: torch.Tensor | None = None,
|
||||
isolate_modalities: bool = False,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -1131,10 +1365,19 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_timestep (`torch.Tensor`, *optional*):
|
||||
Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
|
||||
params. This is only used by certain pipelines such as the I2V pipeline.
|
||||
sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape `(batch_size,)`. Used for video prompt cross attention modulation in
|
||||
models such as LTX-2.3.
|
||||
audio_sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape `(batch_size,)`. Used for audio prompt cross attention modulation in
|
||||
models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to
|
||||
the provided `sigma` value.
|
||||
encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
|
||||
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)` for audio modeling.
|
||||
self_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative self-attention mask of shape `(batch_size, seq_len, seq_len)`.
|
||||
num_frames (`int`, *optional*):
|
||||
The number of latent video frames. Used if calculating the video coordinates for RoPE.
|
||||
height (`int`, *optional*):
|
||||
@@ -1152,6 +1395,21 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_coords (`torch.Tensor`, *optional*):
|
||||
The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
|
||||
`(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
|
||||
isolate_modalities (`bool`, *optional*, defaults to `False`):
|
||||
Whether to isolate each modality by turning off cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention (for all blocks). Use for modality guidance in LTX-2.3.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The transformer block indices at which to apply spatio-temporal guidance (STG), which shortcuts the
|
||||
self-attention operations by simply using the values rather than the full scaled dot-product attention
|
||||
(SDPA) operation. If `None` or empty, STG will not be applied to any block.
|
||||
perturbation_mask (`torch.Tensor`, *optional*):
|
||||
Perturbation mask for STG of shape `(batch_size,)` or `(batch_size, 1, 1)`. Should be 0 at batch
|
||||
elements where STG should be applied and 1 elsewhere. If STG is being used but `perturbation_mask` is
|
||||
not supplied, will default to applying STG (perturbing) all batch elements.
|
||||
use_cross_timestep (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
attention_kwargs (`dict[str, Any]`, *optional*):
|
||||
Optional dict of keyword args to be passed to the attention processor.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -1165,6 +1423,7 @@ class LTX2VideoTransformer3DModel(
|
||||
"""
|
||||
# Determine timestep for audio.
|
||||
audio_timestep = audio_timestep if audio_timestep is not None else timestep
|
||||
audio_sigma = audio_sigma if audio_sigma is not None else sigma
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
@@ -1175,6 +1434,32 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
|
||||
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)
|
||||
|
||||
if self_attention_mask is not None and self_attention_mask.ndim == 3:
|
||||
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
|
||||
# number and positive values are mapped to their logarithm.
|
||||
dtype_finfo = torch.finfo(hidden_states.dtype)
|
||||
additive_self_attn_mask = torch.full_like(self_attention_mask, dtype_finfo.min, dtype=hidden_states.dtype)
|
||||
unmasked_entries = self_attention_mask > 0
|
||||
if torch.any(unmasked_entries):
|
||||
additive_self_attn_mask[unmasked_entries] = torch.log(
|
||||
self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
|
||||
).to(hidden_states.dtype)
|
||||
self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
|
||||
|
||||
if audio_self_attention_mask is not None and audio_self_attention_mask.ndim == 3:
|
||||
# Convert to additive attention mask in log-space where 0 (masked) values get mapped to a large negative
|
||||
# number and positive values are mapped to their logarithm.
|
||||
dtype_finfo = torch.finfo(audio_hidden_states.dtype)
|
||||
additive_self_attn_mask = torch.full_like(
|
||||
audio_self_attention_mask, dtype_finfo.min, dtype=audio_hidden_states.dtype
|
||||
)
|
||||
unmasked_entries = audio_self_attention_mask > 0
|
||||
if torch.any(unmasked_entries):
|
||||
additive_self_attn_mask[unmasked_entries] = torch.log(
|
||||
audio_self_attention_mask[unmasked_entries].clamp(min=dtype_finfo.tiny)
|
||||
).to(audio_hidden_states.dtype)
|
||||
audio_self_attention_mask = additive_self_attn_mask.unsqueeze(1) # [batch_size, 1, seq_len, seq_len]
|
||||
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Prepare RoPE positional embeddings
|
||||
@@ -1223,14 +1508,28 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
|
||||
audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
|
||||
|
||||
if self.prompt_modulation:
|
||||
# LTX-2.3
|
||||
temb_prompt, _ = self.prompt_adaln(
|
||||
sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
temb_prompt_audio, _ = self.audio_prompt_adaln(
|
||||
audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype
|
||||
)
|
||||
temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1))
|
||||
temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1))
|
||||
else:
|
||||
temb_prompt = temb_prompt_audio = None
|
||||
|
||||
# 3.2. Prepare global modality cross attention modulation parameters
|
||||
video_ca_timestep = audio_sigma.flatten() if use_cross_timestep else timestep.flatten()
|
||||
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
|
||||
timestep.flatten(),
|
||||
video_ca_timestep,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
|
||||
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
|
||||
video_ca_timestep * timestep_cross_attn_gate_scale_factor,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
@@ -1239,13 +1538,14 @@ class LTX2VideoTransformer3DModel(
|
||||
)
|
||||
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
|
||||
|
||||
audio_ca_timestep = sigma.flatten() if use_cross_timestep else audio_timestep.flatten()
|
||||
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
|
||||
audio_timestep.flatten(),
|
||||
audio_ca_timestep,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
|
||||
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
|
||||
audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
@@ -1254,15 +1554,30 @@ class LTX2VideoTransformer3DModel(
|
||||
)
|
||||
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
|
||||
|
||||
# 4. Prepare prompt embeddings
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
# 4. Prepare prompt embeddings (LTX-2.0)
|
||||
if self.config.use_prompt_embeddings:
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
|
||||
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
|
||||
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states.view(
|
||||
batch_size, -1, audio_hidden_states.size(-1)
|
||||
)
|
||||
|
||||
# 5. Run transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or []
|
||||
if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None:
|
||||
# If STG is being used and perturbation_mask is not supplied, default to perturbing all batch elements.
|
||||
perturbation_mask = torch.zeros((batch_size,))
|
||||
if perturbation_mask is not None and perturbation_mask.ndim == 1:
|
||||
perturbation_mask = perturbation_mask[:, None, None] # unsqueeze to 3D to broadcast with hidden_states
|
||||
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
|
||||
stg_blocks = set(spatio_temporal_guidance_blocks)
|
||||
|
||||
for block_idx, block in enumerate(self.transformer_blocks):
|
||||
block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None
|
||||
block_all_perturbed = all_perturbed if block_idx in stg_blocks else False
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -1276,12 +1591,22 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_cross_attn_scale_shift,
|
||||
video_cross_attn_a2v_gate,
|
||||
audio_cross_attn_v2a_gate,
|
||||
temb_prompt,
|
||||
temb_prompt_audio,
|
||||
video_rotary_emb,
|
||||
audio_rotary_emb,
|
||||
video_cross_attn_rotary_emb,
|
||||
audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
audio_encoder_attention_mask,
|
||||
self_attention_mask,
|
||||
audio_self_attention_mask,
|
||||
None, # a2v_cross_attention_mask
|
||||
None, # v2a_cross_attention_mask
|
||||
not isolate_modalities, # use_a2v_cross_attention
|
||||
not isolate_modalities, # use_v2a_cross_attention
|
||||
block_perturbation_mask,
|
||||
block_all_perturbed,
|
||||
)
|
||||
else:
|
||||
hidden_states, audio_hidden_states = block(
|
||||
@@ -1295,12 +1620,22 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
|
||||
temb_ca_gate=video_cross_attn_a2v_gate,
|
||||
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
|
||||
temb_prompt=temb_prompt,
|
||||
temb_prompt_audio=temb_prompt_audio,
|
||||
video_rotary_emb=video_rotary_emb,
|
||||
audio_rotary_emb=audio_rotary_emb,
|
||||
ca_video_rotary_emb=video_cross_attn_rotary_emb,
|
||||
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
audio_encoder_attention_mask=audio_encoder_attention_mask,
|
||||
self_attention_mask=self_attention_mask,
|
||||
audio_self_attention_mask=audio_self_attention_mask,
|
||||
a2v_cross_attention_mask=None,
|
||||
v2a_cross_attention_mask=None,
|
||||
use_a2v_cross_attention=not isolate_modalities,
|
||||
use_v2a_cross_attention=not isolate_modalities,
|
||||
perturbation_mask=block_perturbation_mask,
|
||||
all_perturbed=block_all_perturbed,
|
||||
)
|
||||
|
||||
# 6. Output layers (including unpatchification)
|
||||
|
||||
@@ -129,7 +129,7 @@ else:
|
||||
]
|
||||
_import_structure["bria"] = ["BriaPipeline"]
|
||||
_import_structure["bria_fibo"] = ["BriaFiboPipeline", "BriaFiboEditPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline"]
|
||||
_import_structure["flux2"] = ["Flux2Pipeline", "Flux2KleinPipeline", "Flux2KleinKVPipeline"]
|
||||
_import_structure["flux"] = [
|
||||
"FluxControlPipeline",
|
||||
"FluxControlInpaintPipeline",
|
||||
@@ -671,7 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPriorReduxPipeline,
|
||||
ReduxImageEncoder,
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .flux2 import Flux2KleinKVPipeline, Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
|
||||
@@ -95,6 +95,7 @@ from .pag import (
|
||||
StableDiffusionXLPAGPipeline,
|
||||
)
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .prx import PRXPipeline
|
||||
from .qwenimage import (
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
@@ -185,6 +186,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("z-image-controlnet-inpaint", ZImageControlNetInpaintPipeline),
|
||||
("z-image-omni", ZImageOmniPipeline),
|
||||
("ovis", OvisImagePipeline),
|
||||
("prx", PRXPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -82,13 +82,16 @@ EXAMPLE_DOC_STRING = """
|
||||
```python
|
||||
>>> import cv2
|
||||
>>> import numpy as np
|
||||
>>> from PIL import Image
|
||||
>>> import torch
|
||||
>>> from diffusers import Cosmos2_5_TransferPipeline, AutoModel
|
||||
>>> from diffusers.utils import export_to_video, load_video
|
||||
|
||||
>>> model_id = "nvidia/Cosmos-Transfer2.5-2B"
|
||||
>>> # Load a Transfer2.5 controlnet variant (edge, depth, seg, or blur)
|
||||
>>> controlnet = AutoModel.from_pretrained(model_id, revision="diffusers/controlnet/general/edge")
|
||||
>>> controlnet = AutoModel.from_pretrained(
|
||||
... model_id, revision="diffusers/controlnet/general/edge", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe = Cosmos2_5_TransferPipeline.from_pretrained(
|
||||
... model_id, controlnet=controlnet, revision="diffusers/general", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
|
||||
@@ -24,6 +24,7 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipeline_flux2"] = ["Flux2Pipeline"]
|
||||
_import_structure["pipeline_flux2_klein"] = ["Flux2KleinPipeline"]
|
||||
_import_structure["pipeline_flux2_klein_kv"] = ["Flux2KleinKVPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .pipeline_flux2 import Flux2Pipeline
|
||||
from .pipeline_flux2_klein import Flux2KleinPipeline
|
||||
from .pipeline_flux2_klein_kv import Flux2KleinKVPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
|
||||
@@ -744,7 +744,7 @@ class Flux2Pipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: list[PIL.Image.Image, PIL.Image.Image] | None = None,
|
||||
image: PIL.Image.Image | list[PIL.Image.Image] | None = None,
|
||||
prompt: str | list[str] = None,
|
||||
height: int | None = None,
|
||||
width: int | None = None,
|
||||
|
||||
886
src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
Normal file
886
src/diffusers/pipelines/flux2/pipeline_flux2_klein_kv.py
Normal file
@@ -0,0 +1,886 @@
|
||||
# Copyright 2025 Black Forest Labs and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
from transformers import Qwen2TokenizerFast, Qwen3ForCausalLM
|
||||
|
||||
from ...loaders import Flux2LoraLoaderMixin
|
||||
from ...models import AutoencoderKLFlux2, Flux2Transformer2DModel
|
||||
from ...models.transformers.transformer_flux2 import Flux2KVAttnProcessor, Flux2KVParallelSelfAttnProcessor
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import is_torch_xla_available, logging, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .image_processor import Flux2ImageProcessor
|
||||
from .pipeline_output import Flux2PipelineOutput
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from PIL import Image
|
||||
>>> from diffusers import Flux2KleinKVPipeline
|
||||
|
||||
>>> pipe = Flux2KleinKVPipeline.from_pretrained(
|
||||
... "black-forest-labs/FLUX.2-klein-9b-kv", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> pipe.to("cuda")
|
||||
>>> ref_image = Image.open("reference.png")
|
||||
>>> image = pipe("A cat dressed like a wizard", image=ref_image, num_inference_steps=4).images[0]
|
||||
>>> image.save("flux2_kv_output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.compute_empirical_mu
|
||||
def compute_empirical_mu(image_seq_len: int, num_steps: int) -> float:
    r"""
    Compute the empirical `mu` value from the image sequence length and the number of steps.

    Two fitted lines in `image_seq_len` give `mu` at 10 steps and at 200 steps. For `image_seq_len > 4300` the
    200-step line is returned directly and `num_steps` is ignored. Otherwise `mu` is taken from the straight line
    (in `num_steps`) that passes through the 10-step and 200-step values.

    Args:
        image_seq_len (`int`):
            Length of the image token sequence.
        num_steps (`int`):
            Number of steps at which to evaluate `mu`.

    Returns:
        `float`: The computed `mu`.
    """
    slope_10, intercept_10 = 8.73809524e-05, 1.89833333
    slope_200, intercept_200 = 0.00016927, 0.45666666

    # The 200-step line is needed on both paths, so evaluate it once up front.
    mu_at_200 = slope_200 * image_seq_len + intercept_200
    if image_seq_len > 4300:
        return float(mu_at_200)

    mu_at_10 = slope_10 * image_seq_len + intercept_10

    # Line through (10, mu_at_10) and (200, mu_at_200), evaluated at `num_steps`.
    step_slope = (mu_at_200 - mu_at_10) / 190.0
    step_intercept = mu_at_200 - 200.0 * step_slope
    return float(step_slope * num_steps + step_intercept)
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`list[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`list[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
|
||||
r"""
|
||||
The Flux2 Klein KV pipeline for text-to-image generation with KV-cached reference image conditioning.
|
||||
|
||||
On the first denoising step, reference image tokens are included in the forward pass and their attention K/V
|
||||
projections are cached. On subsequent steps, the cached K/V are reused without recomputing, providing faster
|
||||
inference when using reference images.
|
||||
|
||||
Reference:
|
||||
[https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence](https://bfl.ai/blog/flux2-klein-towards-interactive-visual-intelligence)
|
||||
|
||||
Args:
|
||||
transformer ([`Flux2Transformer2DModel`]):
|
||||
Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLFlux2`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`Qwen3ForCausalLM`]):
|
||||
[Qwen3ForCausalLM](https://huggingface.co/docs/transformers/en/model_doc/qwen3#transformers.Qwen3ForCausalLM)
|
||||
tokenizer (`Qwen2TokenizerFast`):
|
||||
Tokenizer of class
|
||||
[Qwen2TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/qwen2#transformers.Qwen2TokenizerFast).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKLFlux2,
    text_encoder: Qwen3ForCausalLM,
    tokenizer: Qwen2TokenizerFast,
    transformer: Flux2Transformer2DModel,
    is_distilled: bool = True,
):
    """
    Register the pipeline components and install the KV-cache-aware attention processors.

    Args:
        scheduler: Scheduler used in the denoising loop.
        vae: Autoencoder used to encode reference images and decode latents.
        text_encoder: Text encoder that produces the prompt embeddings.
        tokenizer: Tokenizer paired with `text_encoder`.
        transformer: Denoising transformer; its attention processors are replaced here.
        is_distilled: Accepted for signature compatibility; not read by this constructor.
    """
    super().__init__()

    self.register_modules(
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        scheduler=scheduler,
        transformer=transformer,
    )

    # One factor of 2 per VAE block transition; fall back to 8 when no VAE is registered.
    if getattr(self, "vae", None):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    else:
        self.vae_scale_factor = 8

    # Flux latents are packed as 2x2 patches, so latent height/width must be divisible by the
    # patch size. The image processor therefore works with the VAE scale factor times the patch size.
    self.image_processor = Flux2ImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
    self.tokenizer_max_length = 512
    self.default_sample_size = 128

    # Swap in the KV-cache-aware attention processors.
    self._set_kv_attn_processors()
|
||||
|
||||
@staticmethod
|
||||
def _get_qwen3_prompt_embeds(
    text_encoder: Qwen3ForCausalLM,
    tokenizer: Qwen2TokenizerFast,
    prompt: str | list[str],
    dtype: torch.dtype | None = None,
    device: torch.device | None = None,
    max_sequence_length: int = 512,
    hidden_states_layers: list[int] = (9, 18, 27),
):
    """
    Encode prompts with the text encoder and concatenate selected hidden-state layers per token.

    Each prompt is wrapped as a single user chat message, rendered with the tokenizer's chat template, and
    tokenized with padding/truncation to `max_sequence_length`. The hidden states listed in `hidden_states_layers`
    are stacked and merged along the feature dimension.

    Args:
        text_encoder: Model called with `output_hidden_states=True`.
        tokenizer: Tokenizer providing `apply_chat_template` and tokenization.
        prompt: One prompt or a list of prompts.
        dtype: Output dtype; defaults to `text_encoder.dtype`.
        device: Device for inputs and output; defaults to `text_encoder.device`.
        max_sequence_length: Padded/truncated token length.
        hidden_states_layers: Indices into `output.hidden_states` to use.

    Returns:
        `torch.Tensor` of shape `(batch, seq_len, len(hidden_states_layers) * hidden_dim)`.
    """
    if dtype is None:
        dtype = text_encoder.dtype
    if device is None:
        device = text_encoder.device
    if isinstance(prompt, str):
        prompt = [prompt]

    # Tokenize prompts one at a time; each encoding is padded to the same fixed length.
    encodings = []
    for single_prompt in prompt:
        chat_text = tokenizer.apply_chat_template(
            [{"role": "user", "content": single_prompt}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=False,
        )
        encodings.append(
            tokenizer(
                chat_text,
                return_tensors="pt",
                padding="max_length",
                truncation=True,
                max_length=max_sequence_length,
            )
        )

    input_ids = torch.cat([encoding["input_ids"] for encoding in encodings], dim=0).to(device)
    attention_mask = torch.cat([encoding["attention_mask"] for encoding in encodings], dim=0).to(device)

    encoder_output = text_encoder(
        input_ids=input_ids,
        attention_mask=attention_mask,
        output_hidden_states=True,
        use_cache=False,
    )

    # Keep only the requested intermediate layers: (batch, num_layers, seq_len, hidden_dim).
    selected_layers = [encoder_output.hidden_states[layer] for layer in hidden_states_layers]
    stacked = torch.stack(selected_layers, dim=1).to(dtype=dtype, device=device)

    # Merge the layer axis into the feature axis so each token carries all selected layers.
    batch_size, num_layers, seq_len, hidden_dim = stacked.shape
    return stacked.permute(0, 2, 1, 3).reshape(batch_size, seq_len, num_layers * hidden_dim)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_text_ids
|
||||
def _prepare_text_ids(
|
||||
x: torch.Tensor, # (B, L, D) or (L, D)
|
||||
t_coord: torch.Tensor | None = None,
|
||||
):
|
||||
B, L, _ = x.shape
|
||||
out_ids = []
|
||||
|
||||
for i in range(B):
|
||||
t = torch.arange(1) if t_coord is None else t_coord[i]
|
||||
h = torch.arange(1)
|
||||
w = torch.arange(1)
|
||||
l = torch.arange(L)
|
||||
|
||||
coords = torch.cartesian_prod(t, h, w, l)
|
||||
out_ids.append(coords)
|
||||
|
||||
return torch.stack(out_ids)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_latent_ids
|
||||
def _prepare_latent_ids(
|
||||
latents: torch.Tensor, # (B, C, H, W)
|
||||
):
|
||||
r"""
|
||||
Generates 4D position coordinates (T, H, W, L) for latent tensors.
|
||||
|
||||
Args:
|
||||
latents (torch.Tensor):
|
||||
Latent tensor of shape (B, C, H, W)
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
Position IDs tensor of shape (B, H*W, 4) All batches share the same coordinate structure: T=0,
|
||||
H=[0..H-1], W=[0..W-1], L=0
|
||||
"""
|
||||
|
||||
batch_size, _, height, width = latents.shape
|
||||
|
||||
t = torch.arange(1) # [0] - time dimension
|
||||
h = torch.arange(height)
|
||||
w = torch.arange(width)
|
||||
l = torch.arange(1) # [0] - layer dimension
|
||||
|
||||
# Create position IDs: (H*W, 4)
|
||||
latent_ids = torch.cartesian_prod(t, h, w, l)
|
||||
|
||||
# Expand to batch: (B, H*W, 4)
|
||||
latent_ids = latent_ids.unsqueeze(0).expand(batch_size, -1, -1)
|
||||
|
||||
return latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._prepare_image_ids
|
||||
def _prepare_image_ids(
|
||||
image_latents: list[torch.Tensor], # [(1, C, H, W), (1, C, H, W), ...]
|
||||
scale: int = 10,
|
||||
):
|
||||
r"""
|
||||
Generates 4D time-space coordinates (T, H, W, L) for a sequence of image latents.
|
||||
|
||||
This function creates a unique coordinate for every pixel/patch across all input latent with different
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
image_latents (list[torch.Tensor]):
|
||||
A list of image latent feature tensors, typically of shape (C, H, W).
|
||||
scale (int, optional):
|
||||
A factor used to define the time separation (T-coordinate) between latents. T-coordinate for the i-th
|
||||
latent is: 'scale + scale * i'. Defaults to 10.
|
||||
|
||||
Returns:
|
||||
torch.Tensor:
|
||||
The combined coordinate tensor. Shape: (1, N_total, 4) Where N_total is the sum of (H * W) for all
|
||||
input latents.
|
||||
|
||||
Coordinate Components (Dimension 4):
|
||||
- T (Time): The unique index indicating which latent image the coordinate belongs to.
|
||||
- H (Height): The row index within that latent image.
|
||||
- W (Width): The column index within that latent image.
|
||||
- L (Seq. Length): A sequence length dimension, which is always fixed at 0 (size 1)
|
||||
"""
|
||||
|
||||
if not isinstance(image_latents, list):
|
||||
raise ValueError(f"Expected `image_latents` to be a list, got {type(image_latents)}.")
|
||||
|
||||
# create time offset for each reference image
|
||||
t_coords = [scale + scale * t for t in torch.arange(0, len(image_latents))]
|
||||
t_coords = [t.view(-1) for t in t_coords]
|
||||
|
||||
image_latent_ids = []
|
||||
for x, t in zip(image_latents, t_coords):
|
||||
x = x.squeeze(0)
|
||||
_, height, width = x.shape
|
||||
|
||||
x_ids = torch.cartesian_prod(t, torch.arange(height), torch.arange(width), torch.arange(1))
|
||||
image_latent_ids.append(x_ids)
|
||||
|
||||
image_latent_ids = torch.cat(image_latent_ids, dim=0)
|
||||
image_latent_ids = image_latent_ids.unsqueeze(0)
|
||||
|
||||
return image_latent_ids
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._patchify_latents
|
||||
def _patchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4)
|
||||
latents = latents.reshape(batch_size, num_channels_latents * 4, height // 2, width // 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpatchify_latents
|
||||
def _unpatchify_latents(latents):
|
||||
batch_size, num_channels_latents, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), 2, 2, height, width)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3)
|
||||
latents = latents.reshape(batch_size, num_channels_latents // (2 * 2), height * 2, width * 2)
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._pack_latents
|
||||
def _pack_latents(latents):
|
||||
"""
|
||||
pack latents: (batch_size, num_channels, height, width) -> (batch_size, height * width, num_channels)
|
||||
"""
|
||||
|
||||
batch_size, num_channels, height, width = latents.shape
|
||||
latents = latents.reshape(batch_size, num_channels, height * width).permute(0, 2, 1)
|
||||
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
|
||||
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
|
||||
"""
|
||||
using position ids to scatter tokens into place
|
||||
"""
|
||||
x_list = []
|
||||
for data, pos in zip(x, x_ids):
|
||||
_, ch = data.shape # noqa: F841
|
||||
h_ids = pos[:, 1].to(torch.int64)
|
||||
w_ids = pos[:, 2].to(torch.int64)
|
||||
|
||||
h = torch.max(h_ids) + 1
|
||||
w = torch.max(w_ids) + 1
|
||||
|
||||
flat_ids = h_ids * w + w_ids
|
||||
|
||||
out = torch.zeros((h * w, ch), device=data.device, dtype=data.dtype)
|
||||
out.scatter_(0, flat_ids.unsqueeze(1).expand(-1, ch), data)
|
||||
|
||||
# reshape from (H * W, C) to (H, W, C) and permute to (C, H, W)
|
||||
|
||||
out = out.view(h, w, ch).permute(2, 0, 1)
|
||||
x_list.append(out)
|
||||
|
||||
return torch.stack(x_list, dim=0)
|
||||
|
||||
def _set_kv_attn_processors(self):
    """Install a fresh KV-cache-aware attention processor on every transformer block.

    Blocks in `transformer_blocks` receive a `Flux2KVAttnProcessor`; blocks in `single_transformer_blocks`
    receive a `Flux2KVParallelSelfAttnProcessor`.
    """
    processor_by_group = (
        ("transformer_blocks", Flux2KVAttnProcessor),
        ("single_transformer_blocks", Flux2KVParallelSelfAttnProcessor),
    )
    for group_name, processor_cls in processor_by_group:
        for block in getattr(self.transformer, group_name):
            # Each block gets its own processor instance.
            block.attn.set_processor(processor_cls())
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
device: torch.device | None = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
max_sequence_length: int = 512,
|
||||
text_encoder_out_layers: tuple[int] = (9, 18, 27),
|
||||
):
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt is None:
|
||||
prompt = ""
|
||||
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt_embeds = self._get_qwen3_prompt_embeds(
|
||||
text_encoder=self.text_encoder,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
hidden_states_layers=text_encoder_out_layers,
|
||||
)
|
||||
|
||||
batch_size, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
text_ids = self._prepare_text_ids(prompt_embeds)
|
||||
text_ids = text_ids.to(device)
|
||||
return prompt_embeds, text_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._encode_vae_image
|
||||
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """
    Encode an image batch with the VAE and normalize the patchified latents.

    The latents are taken as the mode of the VAE posterior (`sample_mode="argmax"`), folded into 2x2 patches,
    and normalized with the VAE's batch-norm running statistics.

    Args:
        image (`torch.Tensor`): 4-dimensional image tensor.
        generator (`torch.Generator`): Forwarded to `retrieve_latents`.

    Returns:
        `torch.Tensor`: Normalized, patchified image latents on the same device and dtype as the VAE output.

    Raises:
        ValueError: If `image` is not 4-dimensional.
    """
    if image.ndim != 4:
        raise ValueError(f"Expected image dims 4, got {image.ndim}.")

    image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="argmax")
    image_latents = self._patchify_latents(image_latents)

    # Move BOTH statistics to the latents' device/dtype. Previously only the mean was moved, so a
    # running variance stored in another dtype silently promoted the result, and one on another
    # device raised. This matches how the decode path in `__call__` handles the same statistics.
    latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(image_latents.device, image_latents.dtype)
    latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
        image_latents.device, image_latents.dtype
    )
    image_latents = (image_latents - latents_bn_mean) / latents_bn_std

    return image_latents
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_latents
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size,
|
||||
num_latents_channels,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator: torch.Generator,
|
||||
latents: torch.Tensor | None = None,
|
||||
):
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_latents_channels * 4, height // 2, width // 2)
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
latent_ids = self._prepare_latent_ids(latents)
|
||||
latent_ids = latent_ids.to(device)
|
||||
|
||||
latents = self._pack_latents(latents) # [B, C, H, W] -> [B, H*W, C]
|
||||
return latents, latent_ids
|
||||
|
||||
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline.prepare_image_latents
|
||||
def prepare_image_latents(
    self,
    images: list[torch.Tensor],
    batch_size,
    generator: torch.Generator,
    device,
    dtype,
):
    """
    Encode the reference images and pack them into one token sequence with position ids.

    Args:
        images: Preprocessed reference images, one tensor per image.
        batch_size: Number of times the reference sequence is repeated along the batch axis.
        generator: Forwarded to the VAE encoding helper.
        device, dtype: Device and dtype the images are moved to before encoding.

    Returns:
        `(image_latents, image_latent_ids)`: all reference tokens concatenated along the sequence axis and
        repeated to `batch_size`, with matching 4D position ids on `device`.
    """
    encoded_latents = [
        self._encode_vae_image(image=image.to(device=device, dtype=dtype), generator=generator)
        for image in images
    ]

    image_latent_ids = self._prepare_image_ids(encoded_latents)

    # Pack each latent to (tokens, channels), dropping its batch axis, then join all
    # reference images along the token axis and restore a single batch axis.
    packed_tokens = [self._pack_latents(latent).squeeze(0) for latent in encoded_latents]
    image_latents = torch.cat(packed_tokens, dim=0).unsqueeze(0)

    image_latents = image_latents.repeat(batch_size, 1, 1)
    image_latent_ids = image_latent_ids.repeat(batch_size, 1, 1).to(device)

    return image_latents, image_latent_ids
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
):
|
||||
if (
|
||||
height is not None
|
||||
and height % (self.vae_scale_factor * 2) != 0
|
||||
or width is not None
|
||||
and width % (self.vae_scale_factor * 2) != 0
|
||||
):
|
||||
logger.warning(
|
||||
f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
elif prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
@property
|
||||
def attention_kwargs(self):
    """Return the `attention_kwargs` stored by the most recent `__call__`."""
    return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
    """Return the number of timesteps recorded by the most recent `__call__`."""
    return self._num_timesteps
|
||||
|
||||
@property
|
||||
def current_timestep(self):
    """Return the timestep currently being denoised; `__call__` resets it to `None` before and after the loop."""
    return self._current_timestep
|
||||
|
||||
@property
|
||||
def interrupt(self):
    """Return the interrupt flag; while it is truthy the denoising loop skips its remaining steps."""
    return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
    self,
    image: list[PIL.Image.Image] | PIL.Image.Image | None = None,
    prompt: str | list[str] = None,
    height: int | None = None,
    width: int | None = None,
    num_inference_steps: int = 4,
    sigmas: list[float] | None = None,
    num_images_per_prompt: int = 1,
    generator: torch.Generator | list[torch.Generator] | None = None,
    latents: torch.Tensor | None = None,
    prompt_embeds: torch.Tensor | None = None,
    output_type: str = "pil",
    return_dict: bool = True,
    attention_kwargs: dict[str, Any] | None = None,
    callback_on_step_end: Callable[[int, int, dict], None] | None = None,
    callback_on_step_end_tensor_inputs: list[str] = ["latents"],
    max_sequence_length: int = 512,
    text_encoder_out_layers: tuple[int] = (9, 18, 27),
):
    r"""
    Function invoked when calling the pipeline for generation.

    Args:
        image (`PIL.Image.Image` or `List[PIL.Image.Image]`, *optional*):
            Reference image(s) for conditioning. On the first denoising step, reference tokens are included in the
            forward pass and their attention K/V are cached. On subsequent steps, the cached K/V are reused without
            recomputing. Images larger than 1024 * 1024 pixels are resized to that area, and each image is cropped
            to a multiple of `vae_scale_factor * 2`.
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. Exactly one of `prompt` and `prompt_embeds` must
            be provided.
        height (`int`, *optional*):
            The height in pixels of the generated image. Defaults to the processed reference image height when
            reference images are given, otherwise to `default_sample_size * vae_scale_factor`.
        width (`int`, *optional*):
            The width in pixels of the generated image. Same defaulting as `height`.
        num_inference_steps (`int`, *optional*, defaults to 4):
            The number of denoising steps.
        sigmas (`List[float]`, *optional*):
            Custom sigmas for the denoising schedule. Ignored when the scheduler config sets `use_flow_sigmas`.
        num_images_per_prompt (`int`, *optional*, defaults to 1):
            The number of images to generate per prompt.
        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
            Generator(s) for deterministic generation.
        latents (`torch.Tensor`, *optional*):
            Pre-generated noisy latents.
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings.
        output_type (`str`, *optional*, defaults to `"pil"`):
            Output format passed to the image processor, e.g. `"pil"` or `"np"`. With `"latent"` the unpatchified
            latents are returned without VAE decoding.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a `Flux2PipelineOutput` or a plain tuple.
        attention_kwargs (`dict`, *optional*):
            Extra kwargs passed to attention processors.
        callback_on_step_end (`Callable`, *optional*):
            Callback invoked at the end of each denoising step as `callback_on_step_end(self, i, t,
            callback_kwargs)`. It must return a dict; `"latents"` and `"prompt_embeds"` entries replace the
            current values.
        callback_on_step_end_tensor_inputs (`List`, *optional*):
            Tensor inputs for the callback function. Names must be listed in `_callback_tensor_inputs`.
        max_sequence_length (`int`, defaults to 512):
            Maximum sequence length for the prompt.
        text_encoder_out_layers (`tuple[int]`):
            Layer indices for text encoder hidden state extraction.

    Examples:

    Returns:
        [`~pipelines.flux2.Flux2PipelineOutput`] or `tuple`.
    """

    # 1. Check inputs
    self.check_inputs(
        prompt=prompt,
        height=height,
        width=width,
        prompt_embeds=prompt_embeds,
        callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
    )

    # Per-call state exposed through the read-only properties.
    self._attention_kwargs = attention_kwargs
    self._current_timestep = None
    self._interrupt = False

    # 2. Define call parameters
    if prompt is not None and isinstance(prompt, str):
        batch_size = 1
    elif prompt is not None and isinstance(prompt, list):
        batch_size = len(prompt)
    else:
        batch_size = prompt_embeds.shape[0]

    device = self._execution_device

    # 3. prepare text embeddings
    prompt_embeds, text_ids = self.encode_prompt(
        prompt=prompt,
        prompt_embeds=prompt_embeds,
        device=device,
        num_images_per_prompt=num_images_per_prompt,
        max_sequence_length=max_sequence_length,
        text_encoder_out_layers=text_encoder_out_layers,
    )

    # 4. process images
    if image is not None and not isinstance(image, list):
        image = [image]

    condition_images = None
    if image is not None:
        for img in image:
            self.image_processor.check_image_input(img)

        condition_images = []
        for img in image:
            image_width, image_height = img.size
            # Cap the reference image area at 1024 * 1024 pixels.
            if image_width * image_height > 1024 * 1024:
                img = self.image_processor._resize_to_target_area(img, 1024 * 1024)
                image_width, image_height = img.size

            # Round both sides down to a multiple of the VAE scale factor times the 2x2 patch size.
            multiple_of = self.vae_scale_factor * 2
            image_width = (image_width // multiple_of) * multiple_of
            image_height = (image_height // multiple_of) * multiple_of
            img = self.image_processor.preprocess(img, height=image_height, width=image_width, resize_mode="crop")
            condition_images.append(img)
            # The first processed reference image sets the output size unless the caller gave one.
            height = height or image_height
            width = width or image_width

    height = height or self.default_sample_size * self.vae_scale_factor
    width = width or self.default_sample_size * self.vae_scale_factor

    # 5. prepare latent variables
    # The transformer consumes 2x2-patchified latents, hence in_channels // 4 raw latent channels.
    num_channels_latents = self.transformer.config.in_channels // 4
    latents, latent_ids = self.prepare_latents(
        batch_size=batch_size * num_images_per_prompt,
        num_latents_channels=num_channels_latents,
        height=height,
        width=width,
        dtype=prompt_embeds.dtype,
        device=device,
        generator=generator,
        latents=latents,
    )

    image_latents = None
    image_latent_ids = None
    if condition_images is not None:
        image_latents, image_latent_ids = self.prepare_image_latents(
            images=condition_images,
            batch_size=batch_size * num_images_per_prompt,
            generator=generator,
            device=device,
            dtype=self.vae.dtype,
        )

    # 6. Prepare timesteps
    sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
    # Schedulers configured with `use_flow_sigmas` build their own sigmas, so any sigmas are dropped here.
    if hasattr(self.scheduler.config, "use_flow_sigmas") and self.scheduler.config.use_flow_sigmas:
        sigmas = None
    image_seq_len = latents.shape[1]
    mu = compute_empirical_mu(image_seq_len=image_seq_len, num_steps=num_inference_steps)
    timesteps, num_inference_steps = retrieve_timesteps(
        self.scheduler,
        num_inference_steps,
        device,
        sigmas=sigmas,
        mu=mu,
    )
    num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
    self._num_timesteps = len(timesteps)

    # 7. Denoising loop with KV caching
    # Step 0 with ref images: forward_kv_extract (full pass, cache ref K/V)
    # Steps 1+: forward_kv_cached (reuse cached ref K/V)
    # No ref images: standard forward
    self.scheduler.set_begin_index(0)
    kv_cache = None

    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            # Once interrupted, every remaining step is skipped.
            if self.interrupt:
                continue

            self._current_timestep = t
            timestep = t.expand(latents.shape[0]).to(latents.dtype)

            if i == 0 and image_latents is not None:
                # Step 0: include ref tokens, extract KV cache
                latent_model_input = torch.cat([image_latents, latents], dim=1).to(self.transformer.dtype)
                latent_image_ids = torch.cat([image_latent_ids, latent_ids], dim=1)

                noise_pred, kv_cache = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_image_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    kv_cache_mode="extract",
                    num_ref_tokens=image_latents.shape[1],
                )

            elif kv_cache is not None:
                # Steps 1+: use cached ref KV, no ref tokens in input
                noise_pred = self.transformer(
                    hidden_states=latents.to(self.transformer.dtype),
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                    kv_cache=kv_cache,
                    kv_cache_mode="cached",
                )[0]

            else:
                # No reference images: standard forward
                noise_pred = self.transformer(
                    hidden_states=latents.to(self.transformer.dtype),
                    timestep=timestep / 1000,
                    guidance=None,
                    encoder_hidden_states=prompt_embeds,
                    txt_ids=text_ids,
                    img_ids=latent_ids,
                    joint_attention_kwargs=self.attention_kwargs,
                    return_dict=False,
                )[0]

            # compute the previous noisy sample x_t -> x_t-1
            latents_dtype = latents.dtype
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            # Restore the pre-step dtype, but only when an MPS backend is available.
            if latents.dtype != latents_dtype:
                if torch.backends.mps.is_available():
                    latents = latents.to(latents_dtype)

            if callback_on_step_end is not None:
                callback_kwargs = {}
                for k in callback_on_step_end_tensor_inputs:
                    # Requested tensors are looked up by name among this function's locals.
                    callback_kwargs[k] = locals()[k]
                callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                # The callback may replace the latents and/or the prompt embeddings.
                latents = callback_outputs.pop("latents", latents)
                prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)

            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                progress_bar.update()

            if XLA_AVAILABLE:
                xm.mark_step()

    # Clean up KV cache
    if kv_cache is not None:
        kv_cache.clear()

    self._current_timestep = None

    # Scatter the packed tokens back to a (B, C, H, W) grid using their position ids.
    latents = self._unpack_latents_with_ids(latents, latent_ids)

    # Undo the batch-norm normalization with the VAE's running statistics.
    latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
    latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
        latents.device, latents.dtype
    )
    latents = latents * latents_bn_std + latents_bn_mean
    latents = self._unpatchify_latents(latents)
    if output_type == "latent":
        # Return the unpatchified latents without VAE decoding.
        image = latents
    else:
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)

    # Offload all models
    self.maybe_free_model_hooks()

    if not return_dict:
        return (image,)

    return Flux2PipelineOutput(images=image)
|
||||
@@ -28,7 +28,7 @@ else:
|
||||
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
|
||||
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -44,7 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_ltx2_condition import LTX2ConditionPipeline
|
||||
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
|
||||
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
|
||||
from .vocoder import LTX2Vocoder
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -9,6 +11,79 @@ from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
def per_layer_masked_mean_norm(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
"""
|
||||
Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states.
|
||||
Respects the padding of the hidden states.
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
|
||||
def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """
    Applies RMS normalization (without learnable scale) along dimension 2 of `text_encoder_hidden_states`.

    Args:
        text_encoder_hidden_states (`torch.Tensor`):
            Hidden states to normalize. The root mean square is taken over dimension 2, independently for every index
            of the remaining dimensions.
        eps (`float`, *optional*, defaults to `1e-6`):
            Small positive value added to the mean square before the reciprocal square root.

    Returns:
        `torch.Tensor`: Tensor with the same shape as the input, divided by its RMS along dimension 2.
    """
    mean_square = text_encoder_hidden_states.pow(2).mean(dim=2, keepdim=True)
    inverse_rms = (mean_square + eps).rsqrt()
    return text_encoder_hidden_states * inverse_rms
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
@@ -106,6 +181,7 @@ class LTX2TransformerBlock1d(nn.Module):
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
rope_type: str = "interleaved",
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -115,8 +191,9 @@ class LTX2TransformerBlock1d(nn.Module):
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
@@ -160,6 +237,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -188,6 +266,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=gated_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -260,24 +339,36 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
caption_channels: int,
|
||||
text_proj_in_factor: int,
|
||||
video_connector_num_attention_heads: int,
|
||||
video_connector_attention_head_dim: int,
|
||||
video_connector_num_layers: int,
|
||||
video_connector_num_learnable_registers: int | None,
|
||||
audio_connector_num_attention_heads: int,
|
||||
audio_connector_attention_head_dim: int,
|
||||
audio_connector_num_layers: int,
|
||||
audio_connector_num_learnable_registers: int | None,
|
||||
connector_rope_base_seq_len: int,
|
||||
rope_theta: float,
|
||||
rope_double_precision: bool,
|
||||
causal_temporal_positioning: bool,
|
||||
caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size
|
||||
text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B
|
||||
video_connector_num_attention_heads: int = 30,
|
||||
video_connector_attention_head_dim: int = 128,
|
||||
video_connector_num_layers: int = 2,
|
||||
video_connector_num_learnable_registers: int | None = 128,
|
||||
video_gated_attn: bool = False,
|
||||
audio_connector_num_attention_heads: int = 30,
|
||||
audio_connector_attention_head_dim: int = 128,
|
||||
audio_connector_num_layers: int = 2,
|
||||
audio_connector_num_learnable_registers: int | None = 128,
|
||||
audio_gated_attn: bool = False,
|
||||
connector_rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
per_modality_projections: bool = False,
|
||||
video_hidden_dim: int = 4096,
|
||||
audio_hidden_dim: int = 2048,
|
||||
proj_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
|
||||
text_encoder_dim = caption_channels * text_proj_in_factor
|
||||
if per_modality_projections:
|
||||
self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias)
|
||||
self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias)
|
||||
else:
|
||||
self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias)
|
||||
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
@@ -288,6 +379,7 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
gated_attention=video_gated_attn,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
@@ -299,26 +391,86 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
gated_attention=audio_gated_attn,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
|
||||
):
|
||||
# Convert to additive attention mask, if necessary
|
||||
if not additive_mask:
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
self,
|
||||
text_encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text
|
||||
embeddings for the LTX-2.X DiT models.
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
Args:
|
||||
text_encoder_hidden_states (`torch.Tensor`):
|
||||
Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len,
|
||||
caption_channels, text_proj_in_factor)` or 3D with the last two dimensions flattened.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
||||
Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked
|
||||
positions.
|
||||
padding_side (`str`, *optional*, defaults to `"left"`):
|
||||
The padding side used by the text encoder's tokenizer (either `"left"` or `"right"`). Defaults to
|
||||
`"left"` as this is what the default Gemma3-12B text encoder uses. Only used if
|
||||
`per_modality_projections` is `False` (LTX-2.0 models).
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scale factor for masked mean/range normalization. Only used if `per_modality_projections` is `False`
|
||||
(LTX-2.0 models).
|
||||
"""
|
||||
if text_encoder_hidden_states.ndim == 3:
|
||||
# Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor]
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1))
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
|
||||
if self.config.per_modality_projections:
|
||||
# LTX-2.3
|
||||
norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3)
|
||||
bool_mask = attention_mask.bool().unsqueeze(-1)
|
||||
norm_text_encoder_hidden_states = torch.where(
|
||||
bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states)
|
||||
)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
|
||||
# Rescale norms with respect to video and audio dims for feature extractors
|
||||
video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels)
|
||||
video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
|
||||
audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels)
|
||||
audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
# Per-Modality Feature extractors
|
||||
video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
|
||||
audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
|
||||
else:
|
||||
# LTX-2.0
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
norm_text_encoder_hidden_states = per_layer_masked_mean_norm(
|
||||
text_hidden_states=text_encoder_hidden_states,
|
||||
sequence_lengths=sequence_lengths,
|
||||
device=text_encoder_hidden_states.device,
|
||||
padding_side=padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
|
||||
text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states)
|
||||
video_text_emb_proj = text_emb_proj
|
||||
audio_text_emb_proj = text_emb_proj
|
||||
|
||||
# Convert to additive attention mask for connectors
|
||||
text_dtype = video_text_emb_proj.dtype
|
||||
attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype)
|
||||
attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
add_attn_mask = attention_mask * torch.finfo(text_dtype).max
|
||||
|
||||
video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask)
|
||||
|
||||
# Convert video attn mask to binary (multiplicative) mask and mask video text embedding
|
||||
binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64)
|
||||
binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * binary_attn_mask
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1)
|
||||
|
||||
@@ -195,7 +195,8 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
rational_spatial_scale: float | None = 2.0,
|
||||
rational_spatial_scale: float = 2.0,
|
||||
use_rational_resampler: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -220,7 +221,7 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_spatial_scale is not None:
|
||||
if use_rational_resampler:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
|
||||
@@ -31,7 +31,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -209,7 +209,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
||||
_optional_components = []
|
||||
_optional_components = ["processor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -221,7 +221,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
processor: Gemma3Processor | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -234,6 +235,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
scheduler=scheduler,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
@@ -268,73 +270,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
@@ -387,16 +322,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -494,6 +420,46 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance_prompt(
|
||||
self,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
max_new_tokens: int = 512,
|
||||
seed: int = 10,
|
||||
generation_kwargs: dict[str, Any] | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
):
|
||||
"""
|
||||
Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a
|
||||
`transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
if generation_kwargs is None:
|
||||
# Set to default generation kwargs
|
||||
generation_kwargs = {"do_sample": True, "temperature": 0.7}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": f"user prompt: {prompt}"},
|
||||
]
|
||||
template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
model_inputs = self.processor(text=template, images=None, return_tensors="pt").to(device)
|
||||
self.text_encoder.to(device)
|
||||
|
||||
# `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
|
||||
# so manually apply a seed for reproducible generation.
|
||||
torch.manual_seed(seed)
|
||||
generated_sequences = self.text_encoder.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**generation_kwargs,
|
||||
) # tensor of shape [batch_size, seq_len]
|
||||
|
||||
generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
|
||||
enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
return enhanced_prompt
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -504,6 +470,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -547,6 +516,12 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
|
||||
raise ValueError(
|
||||
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
|
||||
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
# Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
|
||||
@@ -757,9 +732,41 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
def stg_scale(self):
    """Spatio-Temporal Guidance (STG) scale for the video modality, as stored in `_stg_scale`."""
    return self._stg_scale
|
||||
|
||||
@property
def modality_scale(self):
    """Modality isolation guidance scale for the video modality, as stored in `_modality_scale`."""
    return self._modality_scale
|
||||
|
||||
@property
def audio_guidance_scale(self):
    """Classifier-free guidance scale for the audio modality, as stored in `_audio_guidance_scale`."""
    return self._audio_guidance_scale
|
||||
|
||||
@property
def audio_guidance_rescale(self):
    """Guidance rescale factor for the audio modality, as stored in `_audio_guidance_rescale`."""
    return self._audio_guidance_rescale
|
||||
|
||||
@property
def audio_stg_scale(self):
    """Spatio-Temporal Guidance (STG) scale for the audio modality, as stored in `_audio_stg_scale`."""
    return self._audio_stg_scale
|
||||
|
||||
@property
def audio_modality_scale(self):
    """Modality isolation guidance scale for the audio modality, as stored in `_audio_modality_scale`."""
    return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
def do_spatio_temporal_guidance(self):
    """Whether Spatio-Temporal Guidance is active, i.e. the video or the audio STG scale is greater than `0.0`."""
    video_stg_enabled = self._stg_scale > 0.0
    return video_stg_enabled or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
def do_modality_isolation_guidance(self):
    """Whether modality isolation guidance is active, i.e. the video or the audio modality scale exceeds `1.0`."""
    video_modality_enabled = self._modality_scale > 1.0
    return video_modality_enabled or (self._audio_modality_scale > 1.0)
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -791,7 +798,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -803,6 +817,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
system_prompt: str | None = None,
|
||||
prompt_max_new_tokens: int = 512,
|
||||
prompt_enhancement_kwargs: dict[str, Any] | None = None,
|
||||
prompt_enhancement_seed: int = 10,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -841,13 +860,47 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16 of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continuing denoising.
|
||||
@@ -878,6 +931,24 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
system_prompt (`str`, *optional*, defaults to `None`):
|
||||
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
|
||||
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
|
||||
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
|
||||
performed.
|
||||
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
|
||||
The maximum number of new tokens to generate when performing prompt enhancement.
|
||||
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
|
||||
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
|
||||
`do_sample=True` and `temperature=0.7` will be used. See
|
||||
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
for more details.
|
||||
prompt_enhancement_seed (`int`, *optional*, defaults to `10`):
|
||||
Random seed for any random operations during prompt enhancement.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -910,6 +981,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -920,10 +996,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -939,6 +1026,16 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
if system_prompt is not None and prompt is not None:
|
||||
prompt = self.enhance_prompt(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_new_tokens=prompt_max_new_tokens,
|
||||
seed=prompt_enhancement_seed,
|
||||
generation_kwargs=prompt_enhancement_kwargs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -960,9 +1057,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1069,11 +1168,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1111,8 +1205,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
encoder_hidden_states=connector_prompt_embeds,
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
@@ -1120,7 +1217,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1128,24 +1228,152 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
video_guided_x0 = latents - noise_pred_video_g * self.scheduler.sigmas[i]
|
||||
video_cond_x0 = latents - noise_pred_video * self.scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
video_rescale = self.guidance_rescale
|
||||
cond_std = video_cond_x0.std(dim=list(range(1, video_cond_x0.ndim)), keepdim=True)
|
||||
guided_std = video_guided_x0.std(dim=list(range(1, video_guided_x0.ndim)), keepdim=True)
|
||||
rescale_factor = video_rescale * (cond_std / guided_std) + (1 - video_rescale)
|
||||
video_guided_x0 = video_guided_x0 * rescale_factor
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_video = (latents - video_guided_x0) / self.scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
audio_guided_x0 = audio_latents - noise_pred_audio_g * audio_scheduler.sigmas[i]
|
||||
audio_cond_x0 = audio_latents - noise_pred_audio * audio_scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
audio_rescale = self.audio_guidance_rescale
|
||||
cond_std = audio_cond_x0.std(dim=list(range(1, audio_cond_x0.ndim)), keepdim=True)
|
||||
guided_std = audio_guided_x0.std(dim=list(range(1, audio_guided_x0.ndim)), keepdim=True)
|
||||
rescale_factor = audio_rescale * (cond_std / guided_std) + (1 - audio_rescale)
|
||||
audio_guided_x0 = audio_guided_x0 * rescale_factor
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_audio = (audio_latents - audio_guided_x0) / audio_scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -254,7 +254,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -300,74 +300,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, num_layers)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1, num_layers)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
@@ -421,16 +353,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -541,6 +464,9 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask=None,
|
||||
latents=None,
|
||||
audio_latents=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -597,6 +523,12 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
f" using the `_unpack_audio_latents` method)."
|
||||
)
|
||||
|
||||
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
|
||||
raise ValueError(
|
||||
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
|
||||
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
@@ -992,9 +924,41 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def stg_scale(self):
|
||||
return self._stg_scale
|
||||
|
||||
@property
|
||||
def modality_scale(self):
|
||||
return self._modality_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_scale(self):
|
||||
return self._audio_guidance_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_rescale(self):
|
||||
return self._audio_guidance_rescale
|
||||
|
||||
@property
|
||||
def audio_stg_scale(self):
|
||||
return self._audio_stg_scale
|
||||
|
||||
@property
|
||||
def audio_modality_scale(self):
|
||||
return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
|
||||
def do_modality_isolation_guidance(self):
|
||||
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -1027,7 +991,14 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[float] | None = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float | None = None,
|
||||
num_videos_per_prompt: int | None = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -1039,6 +1010,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -1079,13 +1051,47 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16 of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
noise_scale (`float`, *optional*, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continuing denoising. If not set, will be inferred from the
|
||||
@@ -1117,6 +1123,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -1149,6 +1159,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
# Audio guidance parameters fall back to the corresponding video values only when they were not supplied
# (`None`). An explicit `0.0` (e.g. to disable audio STG or audio guidance rescaling) must be respected, so
# compare against `None` instead of relying on truthiness.
audio_guidance_scale = guidance_scale if audio_guidance_scale is None else audio_guidance_scale
audio_stg_scale = stg_scale if audio_stg_scale is None else audio_stg_scale
audio_modality_scale = modality_scale if audio_modality_scale is None else audio_modality_scale
audio_guidance_rescale = guidance_rescale if audio_guidance_rescale is None else audio_guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -1161,10 +1176,21 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
latents=latents,
|
||||
audio_latents=audio_latents,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -1208,9 +1234,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1301,11 +1329,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1344,8 +1367,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
@@ -1353,7 +1379,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1361,24 +1390,155 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
video_timestep = video_timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
video_guided_x0 = latents - noise_pred_video_g * self.scheduler.sigmas[i]
|
||||
video_cond_x0 = latents - noise_pred_video * self.scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
video_rescale = self.guidance_rescale
|
||||
cond_std = video_cond_x0.std(dim=list(range(1, video_cond_x0.ndim)), keepdim=True)
|
||||
guided_std = video_guided_x0.std(dim=list(range(1, video_guided_x0.ndim)), keepdim=True)
|
||||
rescale_factor = video_rescale * (cond_std / guided_std) + (1 - video_rescale)
|
||||
video_guided_x0 = video_guided_x0 * rescale_factor
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_video = (latents - video_guided_x0) / self.scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
audio_guided_x0 = audio_latents - noise_pred_audio_g * audio_scheduler.sigmas[i]
|
||||
audio_cond_x0 = audio_latents - noise_pred_audio * audio_scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
audio_rescale = self.audio_guidance_rescale
|
||||
cond_std = audio_cond_x0.std(dim=list(range(1, audio_cond_x0.ndim)), keepdim=True)
|
||||
guided_std = audio_guided_x0.std(dim=list(range(1, audio_guided_x0.ndim)), keepdim=True)
|
||||
rescale_factor = audio_rescale * (cond_std / guided_std) + (1 - audio_rescale)
|
||||
audio_guided_x0 = audio_guided_x0 * rescale_factor
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_audio = (audio_latents - audio_guided_x0) / audio_scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG
|
||||
bsz = noise_pred_video.size(0)
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
@@ -32,7 +32,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -212,7 +212,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
||||
_optional_components = []
|
||||
_optional_components = ["processor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -224,7 +224,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
processor: Gemma3Processor | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -237,6 +238,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
scheduler=scheduler,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
@@ -271,74 +273,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
@@ -392,16 +326,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -500,6 +425,53 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@torch.no_grad()
def enhance_prompt(
    self,
    image: PipelineImageInput,
    prompt: str,
    system_prompt: str,
    max_new_tokens: int = 512,
    seed: int = 10,
    generation_kwargs: dict[str, Any] | None = None,
    device: str | torch.device | None = None,
):
    """
    Enhances the supplied `prompt` by generating a new prompt from it, the conditioning `image` and a system
    prompt, using the current text encoder (default is a `transformers.Gemma3ForConditionalGeneration` model).

    Args:
        image (`PipelineImageInput`):
            The image passed to the processor together with the chat template.
        prompt (`str`):
            The raw user prompt. It is inserted into the user message as `User Raw Input Prompt: {prompt}.`.
        system_prompt (`str`):
            The content of the system message that instructs the text encoder how to rewrite the prompt.
        max_new_tokens (`int`, *optional*, defaults to `512`):
            The maximum number of new tokens to generate.
        seed (`int`, *optional*, defaults to `10`):
            Seed applied with `torch.manual_seed` right before generation. Note that this reseeds PyTorch's global
            random number generator.
        generation_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
            Additional keyword arguments for `self.text_encoder.generate`. If `None`, `do_sample=True` and
            `temperature=0.7` are used.
        device (`str` or `torch.device`, *optional*, defaults to `None`):
            The device to run generation on. If `None`, `self._execution_device` is used.

    Returns:
        The output of the processor tokenizer's `batch_decode` on the newly generated tokens (the prompt tokens
        are stripped and special tokens are skipped), i.e. one decoded string per generated sequence.

    Raises:
        `ValueError`: If the pipeline has no `processor` component.
    """
    # `processor` is an optional pipeline component, so fail with a clear message instead of an `AttributeError`
    # on `None` when prompt enhancement is requested without it.
    if getattr(self, "processor", None) is None:
        raise ValueError(
            "Prompt enhancement requires the `processor` component, but this pipeline was created without one."
        )

    device = device or self._execution_device
    if generation_kwargs is None:
        # Set to default generation kwargs
        generation_kwargs = {"do_sample": True, "temperature": 0.7}

    messages = [
        {"role": "system", "content": system_prompt},
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
            ],
        },
    ]
    template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = self.processor(text=template, images=image, return_tensors="pt").to(device)
    self.text_encoder.to(device)

    # `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
    # so manually apply a seed for reproducible generation.
    torch.manual_seed(seed)
    generated_sequences = self.text_encoder.generate(
        **model_inputs,
        max_new_tokens=max_new_tokens,
        **generation_kwargs,
    )  # tensor of shape [batch_size, seq_len]

    # Drop the leading prompt tokens of each sequence so that only the newly generated tokens are decoded
    generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
    enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    return enhanced_prompt
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
@@ -511,6 +483,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -554,6 +529,12 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
|
||||
raise ValueError(
|
||||
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
|
||||
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
@@ -811,9 +792,41 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
def stg_scale(self):
    """The video Spatio-Temporal Guidance (STG) scale stored as `_stg_scale` by `__call__`."""
    return self._stg_scale
|
||||
|
||||
@property
def modality_scale(self):
    """The video modality isolation guidance scale stored as `_modality_scale` by `__call__`."""
    return self._modality_scale
|
||||
|
||||
@property
def audio_guidance_scale(self):
    """The audio classifier-free guidance scale stored as `_audio_guidance_scale` by `__call__`."""
    return self._audio_guidance_scale
|
||||
|
||||
@property
def audio_guidance_rescale(self):
    """The audio guidance rescale factor stored as `_audio_guidance_rescale` by `__call__`."""
    return self._audio_guidance_rescale
|
||||
|
||||
@property
def audio_stg_scale(self):
    """The audio Spatio-Temporal Guidance (STG) scale stored as `_audio_stg_scale` by `__call__`."""
    return self._audio_stg_scale
|
||||
|
||||
@property
def audio_modality_scale(self):
    """The audio modality isolation guidance scale stored as `_audio_modality_scale` by `__call__`."""
    return self._audio_modality_scale
|
||||
|
||||
@property
def do_classifier_free_guidance(self):
    """
    Whether classifier-free guidance is active, i.e. whether the video `guidance_scale` or the audio
    `audio_guidance_scale` is greater than `1.0`.
    """
    # Video and audio have separate CFG scales, so CFG must be enabled as soon as either modality requests it;
    # checking only the video scale would leave audio-only guidance switched off.
    return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
def do_spatio_temporal_guidance(self):
    """Whether STG is active, i.e. whether the video or the audio STG scale is greater than `0.0`."""
    video_stg_enabled = self._stg_scale > 0.0
    return video_stg_enabled or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
def do_modality_isolation_guidance(self):
    """
    Whether modality isolation guidance is active, i.e. whether the video or the audio modality scale is greater
    than `1.0`.
    """
    video_isolation_enabled = self._modality_scale > 1.0
    return video_isolation_enabled or (self._audio_modality_scale > 1.0)
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -846,7 +859,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -858,6 +878,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
system_prompt: str | None = None,
|
||||
prompt_max_new_tokens: int = 512,
|
||||
prompt_enhancement_kwargs: dict[str, Any] | None = None,
|
||||
prompt_enhancement_seed: int = 10,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -898,13 +923,47 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). `guidance_rescale` is defined as `φ` in equation 16 of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continuing denoising.
|
||||
@@ -935,6 +994,24 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
system_prompt (`str`, *optional*, defaults to `None`):
|
||||
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
|
||||
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
|
||||
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
|
||||
performed.
|
||||
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
|
||||
The maximum number of new tokens to generate when performing prompt enhancement.
|
||||
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
|
||||
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
|
||||
`do_sample=True` and `temperature=0.7` will be used. See
|
||||
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
for more details.
|
||||
prompt_enhancement_seed (`int`, *optional*, defaults to `10`):
|
||||
Random seed for any random operations during prompt enhancement.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated video. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -967,6 +1044,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -977,10 +1059,21 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -996,6 +1089,17 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
if system_prompt is not None and prompt is not None:
|
||||
prompt = self.enhance_prompt(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_new_tokens=prompt_max_new_tokens,
|
||||
seed=prompt_enhancement_seed,
|
||||
generation_kwargs=prompt_enhancement_kwargs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -1017,9 +1121,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1134,11 +1240,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1177,8 +1278,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
@@ -1186,7 +1290,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1194,24 +1301,155 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
)
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
video_timestep = video_timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
video_guided_x0 = latents - noise_pred_video_g * self.scheduler.sigmas[i]
|
||||
video_cond_x0 = latents - noise_pred_video * self.scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
video_rescale = self.guidance_rescale
|
||||
cond_std = video_cond_x0.std(dim=list(range(1, video_cond_x0.ndim)), keepdim=True)
|
||||
guided_std = video_guided_x0.std(dim=list(range(1, video_guided_x0.ndim)), keepdim=True)
|
||||
rescale_factor = video_rescale * (cond_std / guided_std) + (1 - video_rescale)
|
||||
video_guided_x0 = video_guided_x0 * rescale_factor
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_video = (latents - video_guided_x0) / self.scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
# Convert from velocity to sample (x0) prediction
|
||||
audio_guided_x0 = audio_latents - noise_pred_audio_g * audio_scheduler.sigmas[i]
|
||||
audio_cond_x0 = audio_latents - noise_pred_audio * audio_scheduler.sigmas[i]
|
||||
|
||||
# Apply guidance rescaling in sample (x0) space, following original code
|
||||
audio_rescale = self.audio_guidance_rescale
|
||||
cond_std = audio_cond_x0.std(dim=list(range(1, audio_cond_x0.ndim)), keepdim=True)
|
||||
guided_std = audio_guided_x0.std(dim=list(range(1, audio_guided_x0.ndim)), keepdim=True)
|
||||
rescale_factor = audio_rescale * (cond_std / guided_std) + (1 - audio_rescale)
|
||||
audio_guided_x0 = audio_guided_x0 * rescale_factor
|
||||
|
||||
# Convert back to velocity space for scheduler
|
||||
noise_pred_audio = (audio_latents - audio_guided_x0) / audio_scheduler.sigmas[i]
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred_video = self._unpack_latents(
|
||||
|
||||
@@ -8,6 +8,209 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
    """
    Creates a Kaiser sinc kernel for low-pass filtering.

    Args:
        cutoff (`float`):
            Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist
            frequency).
        half_width (`float`):
            Used to determine the Kaiser window's beta parameter.
        kernel_size (`int`):
            Size of the Kaiser window (and ultimately the Kaiser sinc kernel).

    Returns:
        `torch.Tensor` of shape `(kernel_size,)`:
            The Kaiser sinc kernel, always floating point. It is normalized to sum to 1, except when `cutoff` is 0,
            in which case it is all zeros.
    """
    delta_f = 4 * half_width
    half_size = kernel_size // 2
    amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
    # The Kaiser window's beta is chosen piecewise from `amplitude`.
    if amplitude > 50.0:
        beta = 0.1102 * (amplitude - 8.7)
    elif amplitude >= 21.0:
        beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
    else:
        beta = 0.0

    window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)

    # Build the time axis in the window's floating dtype for both parities. The odd branch would otherwise be an
    # integer tensor, which made the `cutoff == 0` kernel integer-typed.
    if kernel_size % 2 == 0:
        # Even kernels are centered between two samples, hence the half-sample offset.
        time = torch.arange(-half_size, half_size, dtype=window.dtype) + 0.5
    else:
        time = torch.arange(kernel_size, dtype=window.dtype) - half_size

    if cutoff == 0.0:
        kernel = torch.zeros_like(time)
    else:
        time = 2 * cutoff * time
        # Normalized sinc; the value at 0 is defined as 1 to avoid the 0 / 0 division.
        sinc = torch.where(
            time == 0,
            torch.ones_like(time),
            torch.sin(math.pi * time) / math.pi / time,
        )
        kernel = 2 * cutoff * window * sinc
        # Normalize so that the kernel sums to 1.
        kernel = kernel / kernel.sum()
    return kernel
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
"""1D low-pass filter for antialias downsampling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
use_padding: bool = True,
|
||||
padding_mode: str = "replicate",
|
||||
persistent: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = kernel_size or int(6 * ratio // 2) * 2
|
||||
self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1
|
||||
self.pad_right = self.kernel_size // 2
|
||||
self.use_padding = use_padding
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
cutoff = 0.5 / ratio
|
||||
half_width = 0.6 / ratio
|
||||
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
|
||||
self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x expected shape: [batch_size, num_channels, hidden_dim]
|
||||
num_channels = x.shape[1]
|
||||
if self.use_padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels)
|
||||
return x_filtered
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
window_type: str = "kaiser",
|
||||
padding_mode: str = "replicate",
|
||||
persistent: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
if window_type == "hann":
|
||||
rolloff = 0.99
|
||||
lowpass_filter_width = 6
|
||||
width = math.ceil(lowpass_filter_width / rolloff)
|
||||
self.kernel_size = 2 * width * ratio + 1
|
||||
self.pad = width
|
||||
self.pad_left = 2 * width * ratio
|
||||
self.pad_right = self.kernel_size - ratio
|
||||
|
||||
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
|
||||
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
|
||||
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
|
||||
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
|
||||
else:
|
||||
# Kaiser sinc filter is BigVGAN default
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2
|
||||
self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2
|
||||
|
||||
sinc_filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
kernel_size=self.kernel_size,
|
||||
)
|
||||
|
||||
self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x expected shape: [batch_size, num_channels, hidden_dim]
|
||||
num_channels = x.shape[1]
|
||||
x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode)
|
||||
low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1)
|
||||
x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels)
|
||||
return x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
|
||||
class AntiAliasAct1d(nn.Module):
    """
    Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples
    to avoid aliasing.

    Args:
        act_fn (`str` or `nn.Module`):
            The activation to wrap. A module instance is used as-is. A string selects the activation to build:
            `"snakebeta"` builds a `SnakeBeta` module, `"snake"` builds a `SnakeBeta` module with `use_beta=False`,
            and any other string builds a `nn.LeakyReLU`.
        ratio (`int`, defaults to `2`):
            Up/downsampling factor used around the activation.
        kernel_size (`int`, defaults to `12`):
            Kernel size of the up/downsampling low-pass filters.
        **kwargs:
            Keyword arguments for the activation constructor. Only used when `act_fn` is a string.
    """

    def __init__(
        self,
        act_fn: str | nn.Module,
        ratio: int = 2,
        kernel_size: int = 12,
        **kwargs,
    ):
        super().__init__()
        self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size)
        if isinstance(act_fn, str):
            if act_fn == "snakebeta":
                act_fn = SnakeBeta(**kwargs)
            elif act_fn == "snake":
                # "snake" has no separate beta parameter (matching how ResBlock builds it). An explicit `use_beta`
                # in kwargs still takes precedence.
                act_fn = SnakeBeta(**{"use_beta": False, **kwargs})
            else:
                act_fn = nn.LeakyReLU(**kwargs)
        self.act = act_fn
        self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x
|
||||
|
||||
|
||||
class SnakeBeta(nn.Module):
    """
    Implements the Snake and SnakeBeta activations, which help with learning periodic patterns.

    The output is `x + sin(alpha * x) ** 2 / (scale + eps)`, where `scale` is the separate `beta` parameter when
    `use_beta=True` and `alpha` otherwise.

    Args:
        channels (`int`):
            Number of channels; `alpha` (and `beta`) hold one value per channel.
        alpha (`float`, defaults to `1.0`):
            Initial value of the parameters when `logscale=False`. With `logscale=True` the parameters start at zero
            (i.e. an effective value of `exp(0) = 1`) and this argument is not used.
        eps (`float`, defaults to `1e-9`):
            Added to the denominator for numerical stability.
        trainable_params (`bool`, defaults to `True`):
            Whether the parameters require gradients.
        logscale (`bool`, defaults to `True`):
            Whether the parameters are stored in log space and exponentiated in the forward pass.
        use_beta (`bool`, defaults to `True`):
            Whether to create and use a separate `beta` parameter for the magnitude.
    """

    def __init__(
        self,
        channels: int,
        alpha: float = 1.0,
        eps: float = 1e-9,
        trainable_params: bool = True,
        logscale: bool = True,
        use_beta: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.logscale = logscale
        self.use_beta = use_beta

        alpha_init = torch.zeros(channels) if logscale else torch.ones(channels) * alpha
        self.alpha = nn.Parameter(alpha_init, requires_grad=trainable_params)
        if use_beta:
            beta_init = torch.zeros(channels) if logscale else torch.ones(channels) * alpha
            self.beta = nn.Parameter(beta_init, requires_grad=trainable_params)

    def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
        # View the per-channel parameters so they broadcast along `channel_dim` only.
        param_shape = [1] * hidden_states.ndim
        param_shape[channel_dim] = -1

        frequency = self.alpha.view(param_shape)
        if self.logscale:
            frequency = torch.exp(frequency)

        if self.use_beta:
            magnitude = self.beta.view(param_shape)
            if self.logscale:
                magnitude = torch.exp(magnitude)
        else:
            magnitude = frequency

        periodic = torch.sin(hidden_states * frequency).pow(2)
        return hidden_states + (1.0 / (magnitude + self.eps)) * periodic
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -15,12 +218,15 @@ class ResBlock(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilations: tuple[int, ...] = (1, 3, 5),
|
||||
act_fn: str = "leaky_relu",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = False,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
padding_mode: str = "same",
|
||||
):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
@@ -28,6 +234,18 @@ class ResBlock(nn.Module):
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
self.acts1 = nn.ModuleList()
|
||||
for _ in range(len(self.convs1)):
|
||||
if act_fn == "snakebeta":
|
||||
act = SnakeBeta(channels, use_beta=True)
|
||||
elif act_fn == "snake":
|
||||
act = SnakeBeta(channels, use_beta=False)
|
||||
else:
|
||||
act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
|
||||
|
||||
if antialias:
|
||||
act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
|
||||
self.acts1.append(act)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
@@ -35,12 +253,24 @@ class ResBlock(nn.Module):
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
# One activation per second-stage conv, built the same way as `self.acts1`.
self.acts2 = nn.ModuleList()
for _ in range(len(self.convs2)):
    if act_fn == "snakebeta":
        act = SnakeBeta(channels, use_beta=True)
    elif act_fn == "snake":
        act = SnakeBeta(channels, use_beta=False)
    else:
        # Assign to `act`, not `act_fn`. Overwriting `act_fn` left `act` pointing at the last module built for
        # `self.acts1`, which then got shared (and, with `antialias=True`, wrapped a second time).
        act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

    if antialias:
        act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
    self.acts2.append(act)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2):
|
||||
xt = act1(x)
|
||||
xt = conv1(xt)
|
||||
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
|
||||
xt = act2(xt)
|
||||
xt = conv2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
@@ -61,7 +291,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
upsample_factors: list[int] = [6, 5, 2, 2, 2],
|
||||
resnet_kernel_sizes: list[int] = [3, 7, 11],
|
||||
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
act_fn: str = "leaky_relu",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = False,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
final_act_fn: str | None = "tanh", # tanh, clamp, None
|
||||
final_bias: bool = True,
|
||||
output_sampling_rate: int = 24000,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -69,7 +305,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
self.resnets_per_upsample = len(resnet_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.total_upsample_factor = math.prod(upsample_factors)
|
||||
self.act_fn = act_fn
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
self.final_act_fn = final_act_fn
|
||||
|
||||
if self.num_upsample_layers != len(upsample_factors):
|
||||
raise ValueError(
|
||||
@@ -83,6 +321,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
|
||||
)
|
||||
|
||||
supported_act_fns = ["snakebeta", "snake", "leaky_relu"]
|
||||
if self.act_fn not in supported_act_fns:
|
||||
raise ValueError(
|
||||
f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are "
|
||||
f"{supported_act_fns}."
|
||||
)
|
||||
|
||||
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsamplers = nn.ModuleList()
|
||||
@@ -103,15 +348,27 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
|
||||
self.resnets.append(
|
||||
ResBlock(
|
||||
output_channels,
|
||||
kernel_size,
|
||||
channels=output_channels,
|
||||
kernel_size=kernel_size,
|
||||
dilations=dilations,
|
||||
act_fn=act_fn,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
antialias=antialias,
|
||||
antialias_ratio=antialias_ratio,
|
||||
antialias_kernel_size=antialias_kernel_size,
|
||||
)
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
|
||||
if act_fn == "snakebeta" or act_fn == "snake":
|
||||
# Always use antialiasing
|
||||
act_out = SnakeBeta(channels=output_channels, use_beta=True)
|
||||
self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
|
||||
elif act_fn == "leaky_relu":
|
||||
# NOTE: does NOT use self.negative_slope, following the original code
|
||||
self.act_out = nn.LeakyReLU()
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -139,7 +396,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for i in range(self.num_upsample_layers):
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
if self.act_fn == "leaky_relu":
|
||||
# Other activations are inside each upsampling block
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = self.upsamplers[i](hidden_states)
|
||||
|
||||
# Run all resnets in parallel on hidden_states
|
||||
@@ -149,10 +408,190 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_states = torch.mean(resnet_outputs, dim=0)
|
||||
|
||||
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
|
||||
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
|
||||
hidden_states = self.act_out(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
if self.final_act_fn == "tanh":
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
elif self.final_act_fn == "clamp":
|
||||
hidden_states = torch.clamp(hidden_states, -1, 1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CausalSTFT(nn.Module):
    """
    Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases
    multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact
    buffers should be loaded from the checkpoint in bfloat16.

    The `forward_basis` and `inverse_basis` buffers are zero-initialized placeholders of shape
    `(2 * (filter_length // 2 + 1), 1, filter_length)`.

    Args:
        filter_length (`int`, defaults to `512`):
            Length of each basis filter; determines the number of frequency bins (`filter_length // 2 + 1`).
        hop_length (`int`, defaults to `80`):
            Stride between consecutive frames.
        window_length (`int`, defaults to `512`):
            Window length; together with `hop_length` it determines the amount of left padding.
    """

    def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512):
        super().__init__()
        self.hop_length = hop_length
        self.window_length = window_length

        # The first half of the rows is read as the real part and the second half as the imaginary part in `forward`.
        basis_shape = ((filter_length // 2 + 1) * 2, 1, filter_length)
        self.register_buffer("forward_basis", torch.zeros(basis_shape), persistent=True)
        self.register_buffer("inverse_basis", torch.zeros(basis_shape), persistent=True)

    def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Returns the `(magnitude, phase)` of the causal STFT of `waveform`."""
        if waveform.ndim == 2:
            # [B, num_samples] -> [B, num_channels, num_samples]
            waveform = waveform.unsqueeze(1)

        # Causal: pad on the left only.
        causal_pad = max(0, self.window_length - self.hop_length)
        padded = F.pad(waveform, (causal_pad, 0))

        spec = F.conv1d(padded, self.forward_basis, stride=self.hop_length, padding=0)
        half = spec.shape[1] // 2
        real = spec[:, :half]
        imag = spec[:, half:]

        magnitude = torch.sqrt(real**2 + imag**2)
        # The phase is computed in float32 and cast back to the spectrogram's dtype.
        phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype)
        return magnitude, phase
|
||||
|
||||
|
||||
class MelSTFT(nn.Module):
    """
    Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be
    loaded from the checkpoint in bfloat16.

    The `mel_basis` buffer is a zero-initialized placeholder of shape `(num_mel_channels, filter_length // 2 + 1)`.

    Args:
        filter_length (`int`, defaults to `512`):
            Filter length of the underlying `CausalSTFT`.
        hop_length (`int`, defaults to `80`):
            Hop length of the underlying `CausalSTFT`.
        window_length (`int`, defaults to `512`):
            Window length of the underlying `CausalSTFT`.
        num_mel_channels (`int`, defaults to `64`):
            Number of mel bins.
    """

    def __init__(
        self,
        filter_length: int = 512,
        hop_length: int = 80,
        window_length: int = 512,
        num_mel_channels: int = 64,
    ):
        super().__init__()
        self.stft_fn = CausalSTFT(filter_length, hop_length, window_length)

        mel_basis_shape = (num_mel_channels, filter_length // 2 + 1)
        self.register_buffer("mel_basis", torch.zeros(mel_basis_shape), persistent=True)

    def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns `(log_mel, magnitude, phase, energy)` for `waveform`."""
        magnitude, phase = self.stft_fn(waveform)
        # Norm of the magnitude spectrum over dim 1 (the frequency bins).
        energy = torch.norm(magnitude, dim=1)

        mel_filterbank = self.mel_basis.to(magnitude.dtype)
        mel = torch.matmul(mel_filterbank, magnitude)
        # Clamp before the log so that zero mel values do not produce -inf.
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy
|
||||
|
||||
|
||||
class LTX2VocoderWithBWE(ModelMixin, ConfigMixin):
    """
    LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the
    BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same
    architecture as the original vocoder.

    The un-prefixed arguments configure the stage 1 vocoder and the `bwe_`-prefixed arguments configure the BWE
    generator; both are instances of `LTX2Vocoder`. `filter_length`, `hop_length`, `window_length` and
    `num_mel_channels` configure the `MelSTFT` that re-analyses the stage 1 waveform. `input_sampling_rate` is the
    sampling rate of the stage 1 output and `output_sampling_rate` the sampling rate of the final waveform; the skip
    path resampler uses the integer ratio `output_sampling_rate // input_sampling_rate`.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 128,
        hidden_channels: int = 1536,
        out_channels: int = 2,
        upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4],
        upsample_factors: list[int] = [5, 2, 2, 2, 2, 2],
        resnet_kernel_sizes: list[int] = [3, 7, 11],
        resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        act_fn: str = "snakebeta",
        leaky_relu_negative_slope: float = 0.1,
        antialias: bool = True,
        antialias_ratio: int = 2,
        antialias_kernel_size: int = 12,
        final_act_fn: str | None = None,
        final_bias: bool = False,
        bwe_in_channels: int = 128,
        bwe_hidden_channels: int = 512,
        bwe_out_channels: int = 2,
        bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4],
        bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2],
        bwe_resnet_kernel_sizes: list[int] = [3, 7, 11],
        bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        bwe_act_fn: str = "snakebeta",
        bwe_leaky_relu_negative_slope: float = 0.1,
        bwe_antialias: bool = True,
        bwe_antialias_ratio: int = 2,
        bwe_antialias_kernel_size: int = 12,
        bwe_final_act_fn: str | None = None,
        bwe_final_bias: bool = False,
        filter_length: int = 512,
        hop_length: int = 80,
        window_length: int = 512,
        num_mel_channels: int = 64,
        input_sampling_rate: int = 16000,
        output_sampling_rate: int = 48000,
    ):
        super().__init__()

        # Stage 1: mel spectrogram -> waveform at `input_sampling_rate`.
        self.vocoder = LTX2Vocoder(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            out_channels=out_channels,
            upsample_kernel_sizes=upsample_kernel_sizes,
            upsample_factors=upsample_factors,
            resnet_kernel_sizes=resnet_kernel_sizes,
            resnet_dilations=resnet_dilations,
            act_fn=act_fn,
            leaky_relu_negative_slope=leaky_relu_negative_slope,
            antialias=antialias,
            antialias_ratio=antialias_ratio,
            antialias_kernel_size=antialias_kernel_size,
            final_act_fn=final_act_fn,
            final_bias=final_bias,
            output_sampling_rate=input_sampling_rate,
        )
        # Stage 2: mel spectrogram of the stage 1 waveform -> residual waveform at `output_sampling_rate`.
        self.bwe_generator = LTX2Vocoder(
            in_channels=bwe_in_channels,
            hidden_channels=bwe_hidden_channels,
            out_channels=bwe_out_channels,
            upsample_kernel_sizes=bwe_upsample_kernel_sizes,
            upsample_factors=bwe_upsample_factors,
            resnet_kernel_sizes=bwe_resnet_kernel_sizes,
            resnet_dilations=bwe_resnet_dilations,
            act_fn=bwe_act_fn,
            leaky_relu_negative_slope=bwe_leaky_relu_negative_slope,
            antialias=bwe_antialias,
            antialias_ratio=bwe_antialias_ratio,
            antialias_kernel_size=bwe_antialias_kernel_size,
            final_act_fn=bwe_final_act_fn,
            final_bias=bwe_final_bias,
            output_sampling_rate=output_sampling_rate,
        )

        self.mel_stft = MelSTFT(
            filter_length=filter_length,
            hop_length=hop_length,
            window_length=window_length,
            num_mel_channels=num_mel_channels,
        )

        # Skip path: resamples the stage 1 waveform by the integer sampling-rate ratio.
        self.resampler = UpSample1d(
            ratio=output_sampling_rate // input_sampling_rate,
            window_type="hann",
            persistent=False,
        )

    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """
        Decode a mel spectrogram into a waveform at `output_sampling_rate`.

        Args:
            mel_spec (`torch.Tensor`):
                Mel spectrogram consumed by the stage 1 vocoder.

        Returns:
            `torch.Tensor`: Waveform clamped to `[-1, 1]` and trimmed along the last dim to
            `num_samples * output_sampling_rate // input_sampling_rate` samples, where `num_samples` is the length
            of the (unpadded) stage 1 vocoder output.
        """
        # 1. Run stage 1 vocoder to get low sampling rate waveform
        x = self.vocoder(mel_spec)
        _, num_channels, num_samples = x.shape

        # Pad to exact multiple of hop_length for exact mel frame count. `hop_length` is read from the registered
        # config in both places so the remainder and the pad amount always agree.
        hop_length = self.config.hop_length
        remainder = num_samples % hop_length
        if remainder != 0:
            x = F.pad(x, (0, hop_length - remainder))

        # 2. Compute mel spectrogram on vocoder output (channels are folded into the batch dim for the STFT)
        mel, _, _, _ = self.mel_stft(x.flatten(0, 1))
        mel = mel.unflatten(0, (-1, num_channels))

        # 3. Run bandwidth extender (BWE) on new mel spectrogram
        mel_for_bwe = mel.transpose(2, 3)  # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins]
        residual = self.bwe_generator(mel_for_bwe)

        # 4. Residual connection with resampler
        skip = self.resampler(x)
        waveform = torch.clamp(residual + skip, -1, 1)
        # Trim the samples that correspond to the right padding added above.
        output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate
        waveform = waveform[..., :output_samples]
        return waveform
|
||||
|
||||
@@ -36,7 +36,7 @@ from typing import Any, Callable
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..utils import is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
from ..utils import deprecate, is_torch_available, is_torchao_available, is_torchao_version, logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -844,6 +844,8 @@ class QuantoConfig(QuantizationConfigMixin):
|
||||
modules_to_not_convert: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
deprecation_message = "`QuantoConfig` is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoConfig", "1.0.0", deprecation_message)
|
||||
self.quant_method = QuantizationMethod.QUANTO
|
||||
self.weights_dtype = weights_dtype
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from diffusers.utils.import_utils import is_optimum_quanto_version
|
||||
|
||||
from ...utils import (
|
||||
deprecate,
|
||||
get_module_from_name,
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
@@ -42,6 +43,9 @@ class QuantoQuantizer(DiffusersQuantizer):
|
||||
super().__init__(quantization_config, **kwargs)
|
||||
|
||||
def validate_environment(self, *args, **kwargs):
|
||||
deprecation_message = "The Quanto quantizer is deprecated and will be removed in version 1.0.0."
|
||||
deprecate("QuantoQuantizer", "1.0.0", deprecation_message)
|
||||
|
||||
if not is_optimum_quanto_available():
|
||||
raise ImportError(
|
||||
"Loading an optimum-quanto quantized model requires optimum-quanto library (`pip install optimum-quanto`)"
|
||||
|
||||
@@ -1202,6 +1202,21 @@ class EasyAnimatePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinKVPipeline(metaclass=DummyObject):
    # Placeholder object for `Flux2KleinKVPipeline`. Construction and both classmethod loaders only call
    # `requires_backends` with the backends listed below; presumably that helper raises when a backend is
    # missing (its implementation is not part of this block).
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class Flux2KleinPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -60,12 +60,7 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
model.eval()
|
||||
|
||||
# Move inputs to device
|
||||
inputs_on_device = {}
|
||||
for key, value in inputs_dict.items():
|
||||
if isinstance(value, torch.Tensor):
|
||||
inputs_on_device[key] = value.to(device)
|
||||
else:
|
||||
inputs_on_device[key] = value
|
||||
inputs_on_device = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
|
||||
|
||||
# Enable context parallelism
|
||||
cp_config = ContextParallelConfig(**cp_dict)
|
||||
@@ -89,6 +84,59 @@ def _context_parallel_worker(rank, world_size, master_port, model_class, init_di
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
def _custom_mesh_worker(
    rank,
    world_size,
    master_port,
    model_class,
    init_dict,
    cp_dict,
    mesh_shape,
    mesh_dim_names,
    inputs_dict,
    return_dict,
):
    """
    Per-process entry point for context parallel tests that pass a user-provided custom DeviceMesh.

    Builds `model_class(**init_dict)` on `cuda:{rank}`, enables parallelism with a `ContextParallelConfig` built from
    `cp_dict` plus a mesh of shape `mesh_shape` / names `mesh_dim_names`, and runs one forward pass on `inputs_dict`.
    Only rank 0 reports back through `return_dict`: `status` is `"success"` (with `output_shape`) or `"error"` (with
    `error` holding the exception message).
    """
    try:
        # Rendezvous settings read by `init_process_group`.
        os.environ.update(
            {
                "MASTER_ADDR": "localhost",
                "MASTER_PORT": str(master_port),
                "RANK": str(rank),
                "WORLD_SIZE": str(world_size),
            }
        )
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)

        torch.cuda.set_device(rank)
        worker_device = torch.device(f"cuda:{rank}")

        model = model_class(**init_dict)
        model.to(worker_device)
        model.eval()

        # Tensors are moved to this worker's device; every other input is passed through untouched.
        device_inputs = {}
        for name, value in inputs_dict.items():
            device_inputs[name] = value.to(worker_device) if isinstance(value, torch.Tensor) else value

        # DeviceMesh must be created after init_process_group, inside each worker process.
        device_mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda", mesh_shape=mesh_shape, mesh_dim_names=mesh_dim_names
        )
        parallel_config = ContextParallelConfig(**cp_dict, mesh=device_mesh)
        model.enable_parallelism(config=parallel_config)

        with torch.no_grad():
            first_output = model(**device_inputs, return_dict=False)[0]

        if rank == 0:
            return_dict["status"] = "success"
            return_dict["output_shape"] = list(first_output.shape)
    except Exception as err:
        if rank == 0:
            return_dict["status"] = "error"
            return_dict["error"] = str(err)
    finally:
        if dist.is_initialized():
            dist.destroy_process_group()
|
||||
|
||||
|
||||
@is_context_parallel
|
||||
@require_torch_multi_accelerator
|
||||
class ContextParallelTesterMixin:
|
||||
@@ -126,3 +174,48 @@ class ContextParallelTesterMixin:
|
||||
assert return_dict.get("status") == "success", (
|
||||
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
    "cp_type,mesh_shape,mesh_dim_names",
    [
        ("ring_degree", (2, 1, 1), ("ring", "ulysses", "fsdp")),
        ("ulysses_degree", (1, 2, 1), ("ring", "ulysses", "fsdp")),
    ],
    ids=["ring-3d-fsdp", "ulysses-3d-fsdp"],
)
def test_context_parallel_custom_mesh(self, cp_type, mesh_shape, mesh_dim_names):
    """
    Spawn two `_custom_mesh_worker` processes and check that context parallel inference with a custom 3D mesh
    reports success. Skipped when `torch.distributed` is unavailable or the model class defines no `_cp_plan`.
    """
    if not torch.distributed.is_available():
        pytest.skip("torch.distributed is not available.")

    if getattr(self.model_class, "_cp_plan", None) is None:
        pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")

    world_size = 2
    init_dict = self.get_init_dict()

    # Inputs are handed to the spawned workers on CPU; each worker moves them to its own device.
    inputs_dict = {}
    for name, value in self.get_dummy_inputs().items():
        inputs_dict[name] = value.cpu() if isinstance(value, torch.Tensor) else value

    cp_dict = {cp_type: world_size}

    master_port = _find_free_port()
    # Manager-backed dict so that the rank 0 worker can report its status to this process.
    return_dict = mp.Manager().dict()

    worker_args = (
        world_size,
        master_port,
        self.model_class,
        init_dict,
        cp_dict,
        mesh_shape,
        mesh_dim_names,
        inputs_dict,
        return_dict,
    )
    mp.spawn(_custom_mesh_worker, args=worker_args, nprocs=world_size, join=True)

    assert return_dict.get("status") == "success", (
        f"Custom mesh context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
    )
|
||||
|
||||
174
tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py
Normal file
174
tests/pipelines/flux2/test_pipeline_flux2_klein_kv.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import Qwen2TokenizerFast, Qwen3Config, Qwen3ForCausalLM
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLFlux2,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
Flux2KleinKVPipeline,
|
||||
Flux2Transformer2DModel,
|
||||
)
|
||||
|
||||
from ...testing_utils import torch_device
|
||||
from ..test_pipelines_common import PipelineTesterMixin, check_qkv_fused_layers_exist
|
||||
|
||||
|
||||
class Flux2KleinKVPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast tests for `Flux2KleinKVPipeline` built from tiny, randomly initialized components."""

    pipeline_class = Flux2KleinKVPipeline
    params = frozenset(["prompt", "height", "width", "prompt_embeds", "image"])
    batch_params = frozenset(["prompt"])

    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

    supports_dduf = False

    def get_dummy_components(self, num_layers: int = 1, num_single_layers: int = 1):
        """
        Build the pipeline components (scheduler, text encoder, tokenizer, transformer, VAE) at minimal size.

        The seed is reset to 0 before each randomly initialized model, so the statement order below matters for
        reproducibility.
        """
        torch.manual_seed(0)
        transformer = Flux2Transformer2DModel(
            patch_size=1,
            in_channels=4,
            num_layers=num_layers,
            num_single_layers=num_single_layers,
            attention_head_dim=16,
            num_attention_heads=2,
            joint_attention_dim=16,
            timestep_guidance_channels=256,
            axes_dims_rope=[4, 4, 4, 4],
            guidance_embeds=False,
        )

        # Create minimal Qwen3 config
        config = Qwen3Config(
            intermediate_size=16,
            hidden_size=16,
            num_hidden_layers=2,
            num_attention_heads=2,
            num_key_value_heads=2,
            vocab_size=151936,
            max_position_embeddings=512,
        )
        torch.manual_seed(0)
        text_encoder = Qwen3ForCausalLM(config)

        # Use a simple tokenizer for testing
        tokenizer = Qwen2TokenizerFast.from_pretrained(
            "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration"
        )

        torch.manual_seed(0)
        vae = AutoencoderKLFlux2(
            sample_size=32,
            in_channels=3,
            out_channels=3,
            down_block_types=("DownEncoderBlock2D",),
            up_block_types=("UpDecoderBlock2D",),
            block_out_channels=(4,),
            layers_per_block=1,
            latent_channels=1,
            norm_num_groups=1,
            use_quant_conv=False,
            use_post_quant_conv=False,
        )

        scheduler = FlowMatchEulerDiscreteScheduler()

        return {
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            "transformer": transformer,
            "vae": vae,
        }

    def get_dummy_inputs(self, device, seed=0):
        """Return pipeline call kwargs with a seeded generator, a blank 64x64 RGB image and an 8x8 output size."""
        # On MPS the global seeded generator is used; everywhere else a dedicated CPU generator is created.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device="cpu").manual_seed(seed)

        inputs = {
            "prompt": "a dog is dancing",
            "image": Image.new("RGB", (64, 64)),
            "generator": generator,
            "num_inference_steps": 2,
            "height": 8,
            "width": 8,
            "max_sequence_length": 64,
            "output_type": "np",
            "text_encoder_out_layers": (1,),
        }
        return inputs

    def test_fused_qkv_projections(self):
        """Outputs must match before fusing, after fusing and after unfusing the transformer QKV projections."""
        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        # Baseline run with unfused projections.
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        original_image_slice = image[0, -3:, -3:, -1]

        pipe.transformer.fuse_qkv_projections()
        self.assertTrue(
            check_qkv_fused_layers_exist(pipe.transformer, ["to_qkv"]),
            ("Something wrong with the fused attention layers. Expected all the attention projections to be fused."),
        )

        # Fresh inputs (and therefore a freshly seeded generator) for every run.
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_fused = image[0, -3:, -3:, -1]

        pipe.transformer.unfuse_qkv_projections()
        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice_disabled = image[0, -3:, -3:, -1]

        self.assertTrue(
            np.allclose(original_image_slice, image_slice_fused, atol=1e-3, rtol=1e-3),
            ("Fusion of QKV projections shouldn't affect the outputs."),
        )
        self.assertTrue(
            np.allclose(image_slice_fused, image_slice_disabled, atol=1e-3, rtol=1e-3),
            ("Outputs, with QKV projection fusion enabled, shouldn't change when fused QKV projections are disabled."),
        )
        # NOTE(review): this comparison uses a looser tolerance (1e-2) than the two above (1e-3) — confirm intended.
        self.assertTrue(
            np.allclose(original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2),
            ("Original outputs should match when fused QKV projections are disabled."),
        )

    def test_image_output_shape(self):
        """Requested sizes are expected to be rounded down to a multiple of `pipe.vae_scale_factor * 2`."""
        pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device)
        inputs = self.get_dummy_inputs(torch_device)

        height_width_pairs = [(32, 32), (72, 57)]
        for height, width in height_width_pairs:
            expected_height = height - height % (pipe.vae_scale_factor * 2)
            expected_width = width - width % (pipe.vae_scale_factor * 2)

            inputs.update({"height": height, "width": width})
            image = pipe(**inputs).images[0]
            output_height, output_width, _ = image.shape
            self.assertEqual(
                (output_height, output_width),
                (expected_height, expected_width),
                f"Output shape {image.shape} does not match expected shape {(expected_height, expected_width)}",
            )

    def test_without_image(self):
        """The pipeline must also run when no `image` input is passed."""
        device = "cpu"
        pipe = self.pipeline_class(**self.get_dummy_components()).to(device)
        inputs = self.get_dummy_inputs(device)
        del inputs["image"]
        image = pipe(**inputs).images
        self.assertEqual(image.shape, (1, 8, 8, 3))

    @unittest.skip("Needs to be revisited")
    def test_encode_prompt_works_in_isolation(self):
        pass
|
||||
@@ -171,6 +171,7 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
"processor": None,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
@@ -171,6 +171,7 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
"tokenizer": tokenizer,
|
||||
"connectors": connectors,
|
||||
"vocoder": vocoder,
|
||||
"processor": None,
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
Reference in New Issue
Block a user