Compare commits

..

11 Commits

Author SHA1 Message Date
Sayak Paul
a9855c4204 [tests] fix audioldm2 tests. (#13293)
fix audioldm2 tests.
2026-03-20 20:53:21 +05:30
Sayak Paul
0b35834351 [core] fa4 support. (#13280)
* start fa4 support.

* up

* specify minimum version
2026-03-20 17:28:09 +05:30
Sayak Paul
522b523e40 [ci] hoping to fix is_flaky with wanvace. (#13294)
* hoping to fix is_flaky with wanvace.

* revert changes in src/diffusers/utils/testing_utils.py and propagate them to tests/testing_utils.py.

* up
2026-03-20 16:02:16 +05:30
Dhruv Nair
e9b9f25f67 [CI] Update transformer version in release tests (#13296)
update
2026-03-20 11:40:06 +05:30
Dhruv Nair
32b4cfc81c [Modular] Test for catching dtype and device issues with AutoModel type hints (#13287)
* update

* update

* update
2026-03-20 10:36:03 +05:30
YiYi Xu
a13e5cf9fc [agents]support skills (#13269)
* support skills

* update

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* update baSeed on new best practice

* Update .ai/skills/parity-testing/pitfalls.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* update

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-19 18:07:41 -10:00
dg845
072d15ee42 Add Support for LTX-2.3 Models (#13217)
* Initial implementation of perturbed attn processor for LTX 2.3

* Update DiT block for LTX 2.3 + add self_attention_mask

* Add flag to control using perturbed attn processor for now

* Add support for new video upsampling blocks used by LTX-2.3

* Support LTX-2.3 Big-VGAN V2-style vocoder

* Initial implementation of LTX-2.3 vocoder with bandwidth extender

* Initial support for LTX-2.3 per-modality feature extractor

* Refactor so that text connectors own all text encoder hidden_states normalization logic

* Fix some bugs for inference

* Fix LTX-2.X DiT block forward pass

* Support prompt timestep embeds and prompt cross attn modulation

* Add LTX-2.3 configs to conversion script

* Support converting LTX-2.3 DiT checkpoints

* Support converting LTX-2.3 Video VAE checkpoints

* Support converting LTX-2.3 Vocoder with bandwidth extender

* Support converting LTX-2.3 text connectors

* Don't convert any upsamplers for now

* Support self attention mask for LTX2Pipeline

* Fix some inference bugs

* Support self attn mask and sigmas for LTX-2.3 I2V, Cond pipelines

* Support STG and modality isolation guidance for LTX-2.3

* make style and make quality

* Make audio guidance values default to video values by default

* Update to LTX-2.3 style guidance rescaling

* Support cross timesteps for LTX-2.3 cross attention modulation

* Fix RMS norm bug for LTX-2.3 text connectors

* Perform guidance rescale in sample (x0) space following original code

* Support LTX-2.3 Latent Spatial Upsampler model

* Support LTX-2.3 distilled LoRA

* Support LTX-2.3 Distilled checkpoint

* Support LTX-2.3 prompt enhancement

* Make LTX-2.X processor non-required so that tests pass

* Fix test_components_function tests for LTX2 T2V and I2V

* Fix LTX-2.3 Video VAE configuration bug causing pixel jitter

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Refactor LTX-2.X Video VAE upsampler block init logic

* Refactor LTX-2.X guidance rescaling to use rescale_noise_cfg

* Use generator initial seed to control prompt enhancement if available

* Remove self attention mask logic as it is not used in any current pipelines

* Commit fixes suggested by claude code (guidance in sample (x0) space, denormalize after timestep conditioning)

* Use constant shift following original code

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-19 14:58:29 -07:00
kaixuanliu
67613369bb fix: 'PaintByExampleImageEncoder' object has no attribute 'all_tied_w… (#13252)
* fix: 'PaintByExampleImageEncoder' object has no attribute 'all_tied_weights_keys'

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* also fix LDMBertModel

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-03-18 17:55:08 -10:00
Shenghai Yuan
0c01a4b5e2 [Helios] Remove lru_cache for better AoTI compatibility and cleaner code (#13282)
fix: drop lru_cache for better AoTI compatibility
2026-03-18 23:41:58 +05:30
kaixuanliu
8e4b5607ed skip invalid test case for helios pipeline (#13218)
* skip invalid test case for helio pipeline

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

* update skip reason

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>

---------

Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2026-03-17 20:58:35 -10:00
Junsong Chen
c6f72ad2f6 add ltx2 vae in sana-video; (#13229)
* add ltx2 vae in sana-video;

* add ltx vae in conversion script;

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/sana_video/pipeline_sana_video.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* condition `vae_scale_factor_xxx` related settings on VAE types;

* make the mean/std depends on vae class;

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-03-17 18:09:52 -10:00
34 changed files with 3396 additions and 670 deletions

View File

@@ -24,54 +24,10 @@ Strive to write code as simple and explicit as possible.
### Models
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Try to not introduce graph breaks as much as possible for better compatibility with `torch.compile`. For example, DO NOT arbitrarily insert operations from NumPy in the forward implementations.
- Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations and any other patterns that can break `torch.compile` compatibility with `fullgraph=True`.
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
```python
# transformer_mymodel.py
## Skills
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
### Pipeline
- All pipelines must inherit from `DiffusionPipeline`. Consult implementations in `src/diffusers/pipelines` in case you need references.
- DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline` which will be a part of the core codebase (`src`).
### Tests
- Slow tests gated with `@slow` and `RUN_SLOW=1`
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).

View File

@@ -0,0 +1,167 @@
---
name: integrating-models
description: >
Use when adding a new model or pipeline to diffusers, setting up file
structure for a new model, converting a pipeline to modular format, or
converting weights for a new version of an already-supported model.
---
## Goal
Integrate a new model into diffusers end-to-end. The overall flow:
1. **Gather info** — ask the user for the reference repo, setup guide, a runnable inference script, and other objectives such as standard vs modular.
2. **Confirm the plan** — once you have everything, tell the user exactly what you'll do: e.g. "I'll integrate model X with pipeline Y into diffusers based on your script. I'll run parity tests (model-level and pipeline-level) using the `parity-testing` skill to verify numerical correctness against the reference."
3. **Implement** — write the diffusers code (model, pipeline, scheduler if needed), convert weights, register in `__init__.py`.
4. **Parity test** — use the `parity-testing` skill to verify component and e2e parity against the reference implementation.
5. **Deliver a unit test** — provide a self-contained test script that runs the diffusers implementation, checks numerical output (np allclose), and saves an image/video for visual verification. This is what the user runs to confirm everything works.
Work one workflow at a time — get it to full parity before moving on.
## Setup — gather before starting
Before writing any code, gather info in this order:
1. **Reference repo** — ask for the github link. If they've already set it up locally, ask for the path. Otherwise, ask what setup steps are needed (install deps, download checkpoints, set env vars, etc.) and run through them before proceeding.
2. **Inference script** — ask for a runnable end-to-end script for a basic workflow first (e.g. T2V). Then ask what other workflows they want to support (I2V, V2V, etc.) and agree on the full implementation order together.
3. **Standard vs modular** — standard pipelines, modular, or both?
Use `AskUserQuestion` with structured choices for step 3 when the options are known.
## Standard Pipeline Integration
### File structure for a new model
```
src/diffusers/
models/transformers/transformer_<model>.py # The core model
schedulers/scheduling_<model>.py # If model needs a custom scheduler
pipelines/<model>/
__init__.py
pipeline_<model>.py # Main pipeline
pipeline_<model>_<variant>.py # Variant pipelines (e.g. pyramid, distilled)
pipeline_output.py # Output dataclass
loaders/lora_pipeline.py # LoRA mixin (add to existing file)
tests/
models/transformers/test_models_transformer_<model>.py
pipelines/<model>/test_<model>.py
lora/test_lora_layers_<model>.py
docs/source/en/api/
pipelines/<model>.md
models/<model>_transformer3d.md # or appropriate name
```
### Integration checklist
- [ ] Implement transformer model with `from_pretrained` support
- [ ] Implement or reuse scheduler
- [ ] Implement pipeline(s) with `__call__` method
- [ ] Add LoRA support if applicable
- [ ] Register all classes in `__init__.py` files (lazy imports)
- [ ] Write unit tests (model, pipeline, LoRA)
- [ ] Write docs
- [ ] Run `make style` and `make quality`
- [ ] Test parity with reference implementation (see `parity-testing` skill)
### Attention pattern
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
```python
# transformer_mymodel.py
class MyModelAttnProcessor:
_attention_backend = None
_parallel_config = None
def __call__(self, attn, hidden_states, attention_mask=None, ...):
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
# reshape, apply rope, etc.
hidden_states = dispatch_attention_fn(
query, key, value,
attn_mask=attention_mask,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
return attn.to_out[0](hidden_states)
class MyModelAttention(nn.Module, AttentionModuleMixin):
_default_processor_cls = MyModelAttnProcessor
_available_processors = [MyModelAttnProcessor]
def __init__(self, query_dim, heads=8, dim_head=64, ...):
super().__init__()
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
self.set_processor(MyModelAttnProcessor())
def forward(self, hidden_states, attention_mask=None, **kwargs):
return self.processor(self, hidden_states, attention_mask, **kwargs)
```
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
### Implementation rules
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) which will be a part of the core codebase (`src`).
### Test setup
- Slow tests gated with `@slow` and `RUN_SLOW=1`
- All model-level tests must use the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes initially to write the tests. Any additional tests should be added after discussions with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
### Common diffusers conventions
- Pipelines inherit from `DiffusionPipeline`
- Models use `ModelMixin` with `register_to_config` for config serialization
- Schedulers use `SchedulerMixin` with `ConfigMixin`
- Use `@torch.no_grad()` on pipeline `__call__`
- Support `output_type="latent"` for skipping VAE decode
- Support `generator` parameter for reproducibility
- Use `self.progress_bar(timesteps)` for progress tracking
## Gotchas
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
---
## Modular Pipeline Conversion
See [modular-conversion.md](modular-conversion.md) for the full guide on converting standard pipelines to modular format, including block types, build order, guider abstraction, and conversion checklist.
---
## Weight Conversion Tips
<!-- TODO: Add concrete examples as we encounter them. Common patterns to watch for:
- Fused QKV weights that need splitting into separate Q, K, V
- Scale/shift ordering differences (reference stores [shift, scale], diffusers expects [scale, shift])
- Weight transpositions (linear stored as transposed conv, or vice versa)
- Interleaved head dimensions that need reshaping
- Bias terms absorbed into different layers
Add each with a before/after code snippet showing the conversion. -->

View File

@@ -0,0 +1,152 @@
# Modular Pipeline Conversion Reference
## When to use
Modular pipelines break a monolithic `__call__` into composable blocks. Convert when:
- The model supports multiple workflows (T2V, I2V, V2V, etc.)
- Users need to swap guidance strategies (CFG, CFG-Zero*, PAG)
- You want to share blocks across pipeline variants
## File structure
```
src/diffusers/modular_pipelines/<model>/
__init__.py # Lazy imports
modular_pipeline.py # Pipeline class (tiny, mostly config)
encoders.py # Text encoder + image/video VAE encoder blocks
before_denoise.py # Pre-denoise setup blocks
denoise.py # The denoising loop blocks
decoders.py # VAE decode block
modular_blocks_<model>.py # Block assembly (AutoBlocks)
```
## Block types decision tree
```
Is this a single operation?
YES -> ModularPipelineBlocks (leaf block)
Does it run multiple blocks in sequence?
YES -> SequentialPipelineBlocks
Does it iterate (e.g. chunk loop)?
YES -> LoopSequentialPipelineBlocks
Does it choose ONE block based on which input is present?
Is the selection 1:1 with trigger inputs?
YES -> AutoPipelineBlocks (simple trigger mapping)
NO -> ConditionalPipelineBlocks (custom select_block method)
```
## Build order (easiest first)
1. `decoders.py` -- Takes latents, runs VAE decode, returns images/videos
2. `encoders.py` -- Takes prompt, returns prompt_embeds. Add image/video VAE encoder if needed
3. `before_denoise.py` -- Timesteps, latent prep, noise setup. Each logical operation = one block
4. `denoise.py` -- The hardest. Convert guidance to guider abstraction
## Key pattern: Guider abstraction
Original pipeline has guidance baked in:
```python
for i, t in enumerate(timesteps):
noise_pred = self.transformer(latents, prompt_embeds, ...)
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(latents, negative_prompt_embeds, ...)
noise_pred = noise_uncond + scale * (noise_pred - noise_uncond)
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
```
Modular pipeline separates concerns:
```python
guider_inputs = {
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
}
for i, t in enumerate(timesteps):
components.guider.set_state(step=i, num_inference_steps=num_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
for batch in guider_state:
components.guider.prepare_models(components.transformer)
cond_kwargs = {k: getattr(batch, k) for k in guider_inputs}
context_name = getattr(batch, components.guider._identifier_key)
with components.transformer.cache_context(context_name):
batch.noise_pred = components.transformer(
hidden_states=latents, timestep=timestep,
return_dict=False, **cond_kwargs, **shared_kwargs,
)[0]
components.guider.cleanup_models(components.transformer)
noise_pred = components.guider(guider_state)[0]
latents = components.scheduler.step(noise_pred, t, latents, generator=generator)[0]
```
## Key pattern: Chunk loops for video models
Use `LoopSequentialPipelineBlocks` for outer loop:
```python
class ChunkDenoiseStep(LoopSequentialPipelineBlocks):
block_classes = [PrepareChunkStep, NoiseGenStep, DenoiseInnerStep, UpdateStep]
```
Note: blocks inside `LoopSequentialPipelineBlocks` receive `(components, block_state, k)` where `k` is the loop iteration index.
## Key pattern: Workflow selection
```python
class AutoDenoise(ConditionalPipelineBlocks):
block_classes = [V2VDenoiseStep, I2VDenoiseStep, T2VDenoiseStep]
block_trigger_inputs = ["video_latents", "image_latents"]
default_block_name = "text2video"
```
## Standard InputParam/OutputParam templates
```python
# Inputs
InputParam.template("prompt") # str, required
InputParam.template("negative_prompt") # str, optional
InputParam.template("image") # PIL.Image, optional
InputParam.template("generator") # torch.Generator, optional
InputParam.template("num_inference_steps") # int, default=50
InputParam.template("latents") # torch.Tensor, optional
# Outputs
OutputParam.template("prompt_embeds")
OutputParam.template("negative_prompt_embeds")
OutputParam.template("image_latents")
OutputParam.template("latents")
OutputParam.template("videos")
OutputParam.template("images")
```
## ComponentSpec patterns
```python
# Heavy models - loaded from pretrained
ComponentSpec("transformer", YourTransformerModel)
ComponentSpec("vae", AutoencoderKL)
# Lightweight objects - created inline from config
ComponentSpec(
"guider",
ClassifierFreeGuidance,
config=FrozenDict({"guidance_scale": 7.5}),
default_creation_method="from_config"
)
```
## Conversion checklist
- [ ] Read original pipeline's `__call__` end-to-end, map stages
- [ ] Write test scripts (reference + target) with identical seeds
- [ ] Create file structure under `modular_pipelines/<model>/`
- [ ] Write decoder block (simplest)
- [ ] Write encoder blocks (text, image, video)
- [ ] Write before_denoise blocks (timesteps, latent prep, noise)
- [ ] Write denoise block with guider abstraction (hardest)
- [ ] Create pipeline class with `default_blocks_name`
- [ ] Assemble blocks in `modular_blocks_<model>.py`
- [ ] Wire up `__init__.py` with lazy imports
- [ ] Run `make style` and `make quality`
- [ ] Test all workflows for parity with reference

View File

@@ -0,0 +1,170 @@
---
name: testing-parity
description: >
Use when debugging or verifying numerical parity between pipeline
implementations (e.g., research repo vs diffusers, standard vs modular).
Also relevant when outputs look wrong — washed out, pixelated, or have
visual artifacts — as these are usually parity bugs.
---
## Setup — gather before starting
Before writing any test code, gather:
1. **Which two implementations** are being compared (e.g. research repo → diffusers, standard → modular, or research → modular). Use `AskUserQuestion` with structured choices if not already clear.
2. **Two equivalent runnable scripts** — one for each implementation, both expected to produce identical output given the same inputs. These scripts define what "parity" means concretely.
When invoked from the `model-integration` skill, you already have context: the reference script comes from step 2 of setup, and the diffusers script is the one you just wrote. You just need to make sure both scripts are runnable and use the same inputs/seed/params.
## Test strategy
**Component parity (CPU/float32) -- always run, as you build.**
Test each component before assembling the pipeline. This is the foundation -- if individual pieces are wrong, the pipeline can't be right. Each component in isolation, strict max_diff < 1e-3.
Test freshly converted checkpoints and saved checkpoints.
- **Fresh**: convert from checkpoint weights, compare against reference (catches conversion bugs)
- **Saved**: load from saved model on disk, compare against reference (catches stale saves)
Keep component test scripts around -- you will need to re-run them during pipeline debugging with different inputs or config values.
Template -- one self-contained script per component, reference and diffusers side-by-side:
```python
@torch.inference_mode()
def test_my_component(mode="fresh", model_path=None):
# 1. Deterministic input
gen = torch.Generator().manual_seed(42)
x = torch.randn(1, 3, 64, 64, generator=gen, dtype=torch.float32)
# 2. Reference: load from checkpoint, run, free
ref_model = ReferenceModel.from_config(config)
ref_model.load_state_dict(load_weights("prefix"), strict=True)
ref_model = ref_model.float().eval()
ref_out = ref_model(x).clone()
del ref_model
# 3. Diffusers: fresh (convert weights) or saved (from_pretrained)
if mode == "fresh":
diff_model = convert_my_component(load_weights("prefix"))
else:
diff_model = DiffusersModel.from_pretrained(model_path, torch_dtype=torch.float32)
diff_model = diff_model.float().eval()
diff_out = diff_model(x)
del diff_model
# 4. Compare in same script -- no saving to disk
max_diff = (ref_out - diff_out).abs().max().item()
assert max_diff < 1e-3, f"FAIL: max_diff={max_diff:.2e}"
```
Key points: (a) both reference and diffusers component in one script -- never split into separate scripts that save/load intermediates, (b) deterministic input via seeded generator, (c) load one model at a time to fit in CPU RAM, (d) `.clone()` the reference output before deleting the model.
**E2E visual (GPU/bfloat16) -- once the pipeline is assembled.**
Both pipelines generate independently with identical seeds/params. Save outputs and compare visually. If outputs look identical, you're done -- no need for deeper testing.
**Pipeline stage tests -- only if E2E fails and you need to isolate the bug.**
If the user already suspects where divergence is, start there. Otherwise, work through stages in order.
First, **match noise generation**: the way initial noise/latents are constructed (seed handling, generator, randn call order) often differs between the two scripts. If the noise doesn't match, nothing downstream will match. Check how noise is initialized in the diffusers script — if it doesn't match the reference, temporarily change it to match. Note what you changed so it can be reverted after parity is confirmed.
For small models, run on CPU/float32 for strict comparison. For large models (e.g. 22B params), CPU/float32 is impractical -- use GPU/bfloat16 with `enable_model_cpu_offload()` and relax tolerances (max_diff < 1e-1 for bfloat16 is typical for passing tests; cosine similarity > 0.9999 is a good secondary check).
Test encode and decode stages first -- they're simpler and bugs there are easier to fix. Only debug the denoising loop if encode and decode both pass.
The challenge: pipelines are monolithic `__call__` methods -- you can't just call "the encode part". See [checkpoint-mechanism.md](checkpoint-mechanism.md) for the checkpoint class that lets you stop, save, or inject tensors at named locations inside the pipeline.
**Stage test order — encode, decode, then denoise:**
- **`encode`** (test first): Stop both pipelines at `"preloop"`. Compare **every single variable** that will be consumed by the denoising loop -- not just latents and sigmas, but also prompt embeddings, attention masks, positional coordinates, connector outputs, and any conditioning inputs.
- **`decode`** (test second, before denoise): Run the reference pipeline fully -- checkpoint the post-loop latents AND let it finish to get the decoded output. Then feed those same post-loop latents through the diffusers pipeline's decode path. Compare both numerically AND visually.
- **`denoise`** (test last): Run both pipelines with realistic `num_steps` (e.g. 30) so the scheduler computes correct sigmas/timesteps, but stop after 2 loop iterations using `after_step_1`. Don't set `num_steps=2` -- that produces unrealistic sigma schedules.
```python
# Encode stage -- stop before the loop, compare ALL inputs:
ref_ckpts = {"preloop": Checkpoint(save=True, stop=True)}
run_reference_pipeline(ref_ckpts)
ref_data = ref_ckpts["preloop"].data
diff_ckpts = {"preloop": Checkpoint(save=True, stop=True)}
run_diffusers_pipeline(diff_ckpts)
diff_data = diff_ckpts["preloop"].data
# Compare EVERY variable consumed by the denoise loop:
compare_tensors("latents", ref_data["latents"], diff_data["latents"])
compare_tensors("sigmas", ref_data["sigmas"], diff_data["sigmas"])
compare_tensors("prompt_embeds", ref_data["prompt_embeds"], diff_data["prompt_embeds"])
# ... every single tensor the transformer forward() will receive
```
**E2E-injected visual test**: Once you've identified a suspected root cause using stage tests, confirm it with an e2e-injected run -- inject the known-good tensor from reference and generate a full video. If the output looks identical to reference, you've confirmed the root cause.
## Debugging technique: Injection for root-cause isolation
When stage tests show divergence, **inject a known-good tensor from one pipeline into the other** to test whether the remaining code is correct.
The principle: if you suspect input X is the root cause of divergence in stage S:
1. Run the reference pipeline and capture X
2. Run the diffusers pipeline but **replace** its X with the reference's X (via checkpoint load)
3. Compare outputs of stage S
If outputs now match: X was the root cause. If they still diverge: the bug is in the stage logic itself, not in X.
| What you're testing | What you inject | Where you inject |
|---|---|---|
| Is the decode stage correct? | Post-loop latents from reference | Before decode |
| Is the denoise loop correct? | Pre-loop latents from reference | Before the loop |
| Is step N correct? | Post-step-(N-1) latents from reference | Before step N |
**Per-step accumulation tracing**: When injection confirms the loop is correct but you want to understand *how* a small initial difference compounds, capture `after_step_{i}` for every step and plot the max_diff curve. A healthy curve stays bounded; an exponential blowup in later steps points to an amplification mechanism (see Pitfall #13 in [pitfalls.md](pitfalls.md)).
## Debugging technique: Visual comparison via frame extraction
For video pipelines, numerical metrics alone can be misleading. Extract and view individual frames:
```python
import numpy as np
from PIL import Image
def extract_frames(video_np, frame_indices):
"""video_np: (frames, H, W, 3) float array in [0, 1]"""
for idx in frame_indices:
frame = (video_np[idx] * 255).clip(0, 255).astype(np.uint8)
img = Image.fromarray(frame)
img.save(f"frame_{idx}.png")
# Compare specific frames from both pipelines
extract_frames(ref_video, [0, 60, 120])
extract_frames(diff_video, [0, 60, 120])
```
## Testing rules
1. **Never use reference code in the diffusers test path.** Each side must use only its own code.
2. **Never monkey-patch model internals in tests.** Do not replace `model.forward` or patch internal methods.
3. **Debugging instrumentation must be non-destructive.** Checkpoint captures for debugging are fine, but must not alter control flow or outputs.
4. **Prefer CPU/float32 for numerical comparison when practical.** Float32 avoids bfloat16 precision noise that obscures real bugs. But for large models (22B+), GPU/bfloat16 with `enable_model_cpu_offload()` is necessary -- use relaxed tolerances and cosine similarity as a secondary metric.
5. **Test both fresh conversion AND saved model.** Fresh catches conversion logic bugs; saved catches stale/corrupted weights from previous runs.
6. **Diff configs before debugging.** Before investigating any divergence, dump and compare all config values. A 30-second config diff prevents hours of debugging based on wrong assumptions.
7. **Never modify cached/downloaded model configs directly.** Don't edit files in `~/.cache/huggingface/`. Instead, save to a local directory or open a PR on the upstream repo.
8. **Compare ALL loop inputs in the encode test.** The preloop checkpoint must capture every single tensor the transformer forward() will receive.
## Comparison utilities
```python
def compare_tensors(name: str, a: torch.Tensor, b: torch.Tensor, tol: float = 1e-3) -> bool:
if a.shape != b.shape:
print(f" FAIL {name}: shape mismatch {a.shape} vs {b.shape}")
return False
diff = (a.float() - b.float()).abs()
max_diff = diff.max().item()
mean_diff = diff.mean().item()
cos = torch.nn.functional.cosine_similarity(
a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0)
).item()
passed = max_diff < tol
print(f" {'PASS' if passed else 'FAIL'} {name}: max={max_diff:.2e}, mean={mean_diff:.2e}, cos={cos:.5f}")
return passed
```
Cosine similarity is especially useful for GPU/bfloat16 tests where max_diff can be noisy -- `cos > 0.9999` is a strong signal even when max_diff exceeds tolerance.
## Gotchas
See [pitfalls.md](pitfalls.md) for the full list of gotchas to watch for during parity testing.

View File

@@ -0,0 +1,103 @@
# Checkpoint Mechanism for Stage Testing
## Overview
Pipelines are monolithic `__call__` methods -- you can't just call "the encode part". The checkpoint mechanism lets you stop, save, or inject tensors at named locations inside the pipeline.
## The Checkpoint class
Add a `_checkpoints` argument to both the diffusers pipeline and the reference implementation.
```python
@dataclass
class Checkpoint:
save: bool = False # capture variables into ckpt.data
stop: bool = False # halt pipeline after this point
load: bool = False # inject ckpt.data into local variables
data: dict = field(default_factory=dict)
```
## Pipeline instrumentation
The pipeline accepts an optional `dict[str, Checkpoint]`. Place checkpoint calls at boundaries between pipeline stages -- after each encoder, before the denoising loop (capture all loop inputs), after each loop iteration, after the loop (capture final latents before decode).
```python
def __call__(self, prompt, ..., _checkpoints=None):
# --- text encoding ---
prompt_embeds = self.text_encoder(prompt)
_maybe_checkpoint(_checkpoints, "text_encoding", {
"prompt_embeds": prompt_embeds,
})
# --- prepare latents, sigmas, positions ---
latents = self.prepare_latents(...)
sigmas = self.scheduler.sigmas
# ...
_maybe_checkpoint(_checkpoints, "preloop", {
"latents": latents,
"sigmas": sigmas,
"prompt_embeds": prompt_embeds,
"prompt_attention_mask": prompt_attention_mask,
"video_coords": video_coords,
# capture EVERYTHING the loop needs -- every tensor the transformer
# forward() receives. Missing even one variable here means you can't
# tell if it's the source of divergence during denoise debugging.
})
# --- denoising loop ---
for i, t in enumerate(timesteps):
noise_pred = self.transformer(latents, t, prompt_embeds, ...)
latents = self.scheduler.step(noise_pred, t, latents)[0]
_maybe_checkpoint(_checkpoints, f"after_step_{i}", {
"latents": latents,
})
_maybe_checkpoint(_checkpoints, "post_loop", {
"latents": latents,
})
# --- decode ---
video = self.vae.decode(latents)
return video
```
## The helper function
Each `_maybe_checkpoint` call does three things based on the Checkpoint's flags: `save` captures the local variables into `ckpt.data`, `load` injects pre-populated `ckpt.data` back into local variables, `stop` halts execution (raises an exception caught at the top level).
```python
def _maybe_checkpoint(checkpoints, name, data):
if not checkpoints:
return
ckpt = checkpoints.get(name)
if ckpt is None:
return
if ckpt.save:
ckpt.data.update(data)
if ckpt.stop:
raise PipelineStop # caught at __call__ level, returns None
```
## Injection support
Add `load` support at each checkpoint where you might want to inject:
```python
_maybe_checkpoint(_checkpoints, "preloop", {"latents": latents, ...})
# Load support: replace local variables with injected data
if _checkpoints:
ckpt = _checkpoints.get("preloop")
if ckpt is not None and ckpt.load:
latents = ckpt.data["latents"].to(device=device, dtype=latents.dtype)
```
## Key insight
The checkpoint dict is passed into the pipeline and mutated in-place. After the pipeline returns (or stops early), you read back `ckpt.data` to get the captured tensors. Both pipelines save under their own key names, so the test maps between them (e.g. reference `"video_state.latent"` -> diffusers `"latents"`).
## Memory management for large models
For large models, free the source pipeline's GPU memory before loading the target pipeline. Clone injected tensors to CPU, delete everything else, then run the target with `enable_model_cpu_offload()`.

View File

@@ -0,0 +1,116 @@
# Complete Pitfalls Reference
## 1. Global CPU RNG
`MultivariateNormal.sample()` uses the global CPU RNG, not `torch.Generator`. Must call `torch.manual_seed(seed)` before each pipeline run. A `generator=` kwarg won't help.
## 2. Timestep dtype
Many transformers expect `int64` timesteps. `get_timestep_embedding` casts to float, so `745.3` and `745` produce different embeddings. Match the reference's casting.
## 3. Guidance parameter mapping
Parameter names may differ: reference `zero_steps=1` (meaning `i <= 1`, 2 steps) vs target `zero_init_steps=2` (meaning `step < 2`, same thing). Check exact semantics.
## 4. `patch_size` in noise generation
If noise generation depends on `patch_size` (e.g. `sample_block_noise`), it must be passed through. Missing it changes noise spatial structure.
## 5. Variable shadowing in nested loops
Nested loops (stages -> chunks -> timesteps) can shadow variable names. If outer loop uses `latents` and inner loop also assigns to `latents`, scoping must match the reference.
## 6. Float precision differences -- don't dismiss them
Target may compute in float32 where reference used bfloat16. Small per-element diffs (1e-3 to 1e-2) *look* harmless but can compound catastrophically over iterative processes like denoising loops (see Pitfalls #11 and #13). Before dismissing a precision difference: (a) check whether it feeds into an iterative process, (b) if so, trace the accumulation curve over all iterations to see if it stays bounded or grows exponentially. Only truly non-iterative precision diffs (e.g. in a single-pass encoder) are safe to accept.
## 7. Scheduler state reset between stages
Some schedulers accumulate state (e.g. `model_outputs` in UniPC) that must be cleared between stages.
## 8. Component access
Standard: `self.transformer`. Modular: `components.transformer`. Missing this causes AttributeError.
## 9. Guider state across stages
In multi-stage denoising, the guider's internal state (e.g. `zero_init_steps`) may need save/restore between stages.
## 10. Model storage location
NEVER store converted models in `/tmp/` -- temporary directories get wiped on restart. Always save converted checkpoints under a persistent path in the project repo (e.g. `models/ltx23-diffusers/`).
## 11. Noise dtype mismatch (causes washed-out output)
Reference code often generates noise in float32 then casts to model dtype (bfloat16) before storing:
```python
noise = torch.randn(..., dtype=torch.float32, generator=gen)
noise = noise.to(dtype=model_dtype) # bfloat16 -- values get quantized
```
Diffusers pipelines may keep latents in float32 throughout the loop. The per-element difference is only ~1.5e-02, but this compounds over 30 denoising steps via 1/sigma amplification (Pitfall #13) and produces completely washed-out output.
**Fix**: Match the reference -- generate noise in the model's working dtype:
```python
latent_dtype = self.transformer.dtype # e.g. bfloat16
latents = self.prepare_latents(..., dtype=latent_dtype, ...)
```
**Detection**: Encode stage test shows initial latent max_diff of exactly ~1.5e-02. This specific magnitude is the signature of float32->bfloat16 quantization error.
## 12. RoPE position dtype
RoPE cosine/sine values are sensitive to position coordinate dtype. If reference uses bfloat16 positions but diffusers uses float32, the RoPE output diverges significantly (max_diff up to 2.0). Different modalities may use different position dtypes (e.g. video bfloat16, audio float32) -- check the reference carefully.
## 13. 1/sigma error amplification in Euler denoising
In Euler/flow-matching, the velocity formula divides by sigma: `v = (latents - pred_x0) / sigma`. As sigma shrinks from ~1.0 (step 0) to ~0.001 (step 29), errors are amplified up to 1000x. A 1.5e-02 init difference grows linearly through mid-steps, then exponentially in final steps, reaching max_diff ~6.0. This is why dtype mismatches (Pitfalls #11, #12) that seem tiny at init produce visually broken output. Use per-step accumulation tracing to diagnose.
## 14. Config value assumptions -- always diff, never assume
When debugging parity, don't assume config values match code defaults. The published model checkpoint may override defaults with different values. A wrong assumption about a single config field can send you down hours of debugging in the wrong direction.
**The pattern that goes wrong:**
1. You see `param_x` has default `1` in the code
2. The reference code also uses `param_x` with a default of `1`
3. You assume both sides use `1` and apply a "fix" based on that
4. But the actual checkpoint config has `param_x: 1000`, and so does the published diffusers config
5. Your "fix" now *creates* divergence instead of fixing it
**Prevention -- config diff first:**
```python
# Reference: read from checkpoint metadata (no model loading needed)
from safetensors import safe_open
import json
ref_config = json.loads(safe_open(checkpoint_path, framework="pt").metadata()["config"])
# Diffusers: read from model config
from diffusers import MyModel
diff_model = MyModel.from_pretrained(model_path, subfolder="transformer")
diff_config = dict(diff_model.config)
# Compare all values
for key in sorted(set(list(ref_config.get("transformer", {}).keys()) + list(diff_config.keys()))):
ref_val = ref_config.get("transformer", {}).get(key, "MISSING")
diff_val = diff_config.get(key, "MISSING")
if ref_val != diff_val:
print(f" DIFF {key}: ref={ref_val}, diff={diff_val}")
```
Run this **before** writing any hooks, analysis code, or fixes. It takes 30 seconds and catches wrong assumptions immediately.
**When debugging divergence -- trace values, don't reason about them:**
If two implementations diverge, hook the actual intermediate values at the point of divergence rather than reading code to figure out what the values "should" be. Code analysis builds on assumptions; value tracing reveals facts.
## 15. Decoder config mismatch (causes pixelated artifacts)
The upstream model config may have wrong values for decoder-specific parameters (e.g. `upsample_residual`, `upsample_type`). These control whether the decoder uses skip connections in upsampling -- getting them wrong produces severe pixelation or blocky artifacts.
**Detection**: Feed identical post-loop latents through both decoders. If max pixel diff is large (PSNR < 40 dB) on CPU/float32, it's a real bug, not precision noise. Trace through decoder blocks (conv_in -> mid_block -> up_blocks) to find where divergence starts.
**Fix**: Correct the config value. Don't edit cached files in `~/.cache/huggingface/` -- either save to a local model directory or open a PR on the upstream repo (see Testing Rule #7).
## 16. Incomplete injection tests -- inject ALL variables or the test is invalid
When doing injection tests (feeding reference tensors into the diffusers pipeline), you must inject **every** divergent input, including sigmas/timesteps. A common mistake: the preloop checkpoint saves sigmas but the injection code only loads latents and embeddings. The test then runs with different sigma schedules, making it impossible to isolate the real cause.
**Prevention**: After writing injection code, verify by listing every variable the injected stage consumes and checking each one is either (a) injected from reference, or (b) confirmed identical between pipelines.
## 17. bf16 connector/encoder divergence -- don't chase it
When running on GPU/bfloat16, multi-layer encoders (e.g. 8-layer connector transformers) accumulate bf16 rounding noise that looks alarming (max_diff 0.3-2.7). Before investigating, re-run the component test on CPU/float32. If it passes (max_diff < 1e-4), the divergence is pure precision noise, not a code bug. Don't spend hours tracing through layers -- confirm on CPU/float32 and move on.
## 18. Stale test fixtures
When using saved tensors for cross-pipeline comparison, always ensure both sets of tensors were captured from the same run configuration (same seed, same config, same code version). Mixing fixtures from different runs (e.g. reference tensors from yesterday, diffusers tensors from today after a code change) creates phantom divergence that wastes debugging time. Regenerate both sides in a single test script execution.

View File

@@ -4,6 +4,7 @@
name: (Release) Fast GPU Tests on main
on:
workflow_dispatch:
push:
branches:
- "v*.*.*-release"
@@ -33,6 +34,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -74,6 +76,7 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -125,6 +128,7 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
@@ -175,6 +179,7 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
@@ -232,6 +237,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -274,6 +280,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -316,6 +323,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |

4
.gitignore vendored
View File

@@ -182,4 +182,6 @@ wandb
# AI agent generated symlinks
/AGENTS.md
/CLAUDE.md
/CLAUDE.md
/.agents/skills
/.claude/skills

View File

@@ -103,9 +103,16 @@ post-patch:
codex:
ln -snf .ai/AGENTS.md AGENTS.md
mkdir -p .agents
rm -rf .agents/skills
ln -snf ../.ai/skills .agents/skills
claude:
ln -snf .ai/AGENTS.md CLAUDE.md
mkdir -p .claude
rm -rf .claude/skills
ln -snf ../.ai/skills .claude/skills
clean-ai:
rm -f AGENTS.md CLAUDE.md
rm -rf .agents/skills .claude/skills

View File

@@ -572,9 +572,9 @@ For documentation strings, 🧨 Diffusers follows the [Google style](https://goo
The repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks.
- **Source of truth** — edit `.ai/AGENTS.md` (and any future `.ai/skills/`)
- **Don't edit** generated root-level `AGENTS.md` or `CLAUDE.md` — they are symlinks
- **Source of truth** — edit files under `.ai/` (`AGENTS.md` for coding guidelines, `skills/` for on-demand task knowledge)
- **Don't edit** generated root-level `AGENTS.md`, `CLAUDE.md`, or `.agents/skills`/`.claude/skills` — they are symlinks
- Setup commands:
- `make codex` — symlink for OpenAI Codex
- `make claude` — symlink for Claude Code
- `make clean-ai` — remove generated symlinks
- `make codex` — symlink guidelines + skills for OpenAI Codex
- `make claude` — symlink guidelines + skills for Claude Code
- `make clean-ai` — remove all generated symlinks

View File

@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |

View File

@@ -7,7 +7,7 @@ import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Processor
from diffusers import (
AutoencoderKLLTX2Audio,
@@ -17,7 +17,7 @@ from diffusers import (
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder, LTX2VocoderWithBWE
from diffusers.utils.import_utils import is_accelerate_available
@@ -44,6 +44,12 @@ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
"k_norm": "norm_k",
}
LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT = {
**LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT,
"audio_prompt_adaln_single": "audio_prompt_adaln",
"prompt_adaln_single": "prompt_adaln",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
@@ -72,6 +78,13 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_3_VIDEO_VAE_RENAME_DICT = {
**LTX_2_0_VIDEO_VAE_RENAME_DICT,
# Decoder extra blocks
"up_blocks.7": "up_blocks.3.upsamplers.0",
"up_blocks.8": "up_blocks.3",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
@@ -84,10 +97,34 @@ LTX_2_0_VOCODER_RENAME_DICT = {
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
LTX_2_3_VOCODER_RENAME_DICT = {
# Handle upsamplers ("ups" --> "upsamplers") due to name clash
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
"act_post": "act_out",
"downsample.lowpass": "downsample",
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# LTX-2.3 uses per-modality embedding projections
"text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in",
"text_embedding_projection.video_aggregate_embed": "video_text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
@@ -129,23 +166,24 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str
return
def convert_ltx2_3_vocoder_upsamplers(key: str, state_dict: dict[str, Any]) -> None:
# Skip if not a weight, bias
if ".weight" not in key and ".bias" not in key:
return
if ".ups." in key:
new_key = key.replace(".ups.", ".upsamplers.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
@@ -155,13 +193,19 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP = {
".ups.": convert_ltx2_3_vocoder_upsamplers,
}
LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(state_dict: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"text_embedding_projection",
"connectors.",
"video_connector",
"audio_connector",
@@ -225,7 +269,7 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"model_id": "Lightricks/LTX-2",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
@@ -238,6 +282,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"gated_attn": False,
"cross_attn_mod": False,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
@@ -249,6 +295,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"audio_gated_attn": False,
"audio_cross_attn_mod": False,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
@@ -263,10 +311,62 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
"use_prompt_embeddings": True,
"perturbed_attn": False,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
elif version == "2.3":
config = {
"model_id": "Lightricks/LTX-2.3",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"gated_attn": True,
"cross_attn_mod": True,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"audio_gated_attn": True,
"audio_cross_attn_mod": True,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
"use_prompt_embeddings": False,
"perturbed_attn": True,
},
}
rename_dict = LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
@@ -293,7 +393,7 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
}
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"model_id": "Lightricks/LTX-2",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
@@ -301,20 +401,52 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"video_gated_attn": False,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"audio_gated_attn": False,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
"per_modality_projections": False,
"proj_bias": False,
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
elif version == "2.3":
config = {
"model_id": "Lightricks/LTX-2.3",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 32,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 8,
"video_connector_num_learnable_registers": 128,
"video_gated_attn": True,
"audio_connector_num_attention_heads": 32,
"audio_connector_attention_head_dim": 64,
"audio_connector_num_layers": 8,
"audio_connector_num_learnable_registers": 128,
"audio_gated_attn": True,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
"per_modality_projections": True,
"video_hidden_dim": 4096,
"audio_hidden_dim": 2048,
"proj_bias": True,
},
}
rename_dict = LTX_2_3_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
@@ -416,7 +548,7 @@ def get_ltx2_video_vae_config(
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"model_id": "Lightricks/LTX-2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
@@ -435,6 +567,7 @@ def get_ltx2_video_vae_config(
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_type": ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": timestep_conditioning,
@@ -451,6 +584,44 @@ def get_ltx2_video_vae_config(
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.3":
config = {
"model_id": "Lightricks/LTX-2.3",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 1024),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 512, 1024),
"layers_per_block": (4, 6, 4, 2, 2),
"decoder_layers_per_block": (4, 6, 4, 2, 2),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True, True),
"decoder_inject_noise": (False, False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"),
"upsample_residual": (False, False, False, False),
"upsample_factor": (2, 2, 1, 2),
"timestep_conditioning": timestep_conditioning,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "zeros",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_3_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
@@ -485,7 +656,7 @@ def convert_ltx2_video_vae(
def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"model_id": "Lightricks/LTX-2",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
@@ -508,6 +679,31 @@ def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, A
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
elif version == "2.3":
config = {
"model_id": "Lightricks/LTX-2.3",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
}, # Same config as LTX-2.0
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
@@ -540,7 +736,7 @@ def convert_ltx2_audio_vae(original_state_dict: dict[str, Any], version: str) ->
def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"model_id": "Lightricks/LTX-2",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
@@ -549,21 +745,71 @@ def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"act_fn": "leaky_relu",
"leaky_relu_negative_slope": 0.1,
"antialias": False,
"final_act_fn": "tanh",
"final_bias": True,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
elif version == "2.3":
config = {
"model_id": "Lightricks/LTX-2.3",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1536,
"out_channels": 2,
"upsample_kernel_sizes": [11, 4, 4, 4, 4, 4],
"upsample_factors": [5, 2, 2, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"act_fn": "snakebeta",
"leaky_relu_negative_slope": 0.1,
"antialias": True,
"antialias_ratio": 2,
"antialias_kernel_size": 12,
"final_act_fn": None,
"final_bias": False,
"bwe_in_channels": 128,
"bwe_hidden_channels": 512,
"bwe_out_channels": 2,
"bwe_upsample_kernel_sizes": [12, 11, 4, 4, 4],
"bwe_upsample_factors": [6, 5, 2, 2, 2],
"bwe_resnet_kernel_sizes": [3, 7, 11],
"bwe_resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"bwe_act_fn": "snakebeta",
"bwe_leaky_relu_negative_slope": 0.1,
"bwe_antialias": True,
"bwe_antialias_ratio": 2,
"bwe_antialias_kernel_size": 12,
"bwe_final_act_fn": None,
"bwe_final_bias": False,
"filter_length": 512,
"hop_length": 80,
"window_length": 512,
"num_mel_channels": 64,
"input_sampling_rate": 16000,
"output_sampling_rate": 48000,
},
}
rename_dict = LTX_2_3_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: dict[str, Any], version: str) -> dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
diffusers_config = config["diffusers_config"]
if version == "2.3":
vocoder_cls = LTX2VocoderWithBWE
else:
vocoder_cls = LTX2Vocoder
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
vocoder = vocoder_cls.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
@@ -594,6 +840,18 @@ def get_ltx2_spatial_latent_upsampler_config(version: str):
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
"use_rational_resampler": True,
}
elif version == "2.3":
config = {
"in_channels": 128,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
"use_rational_resampler": False,
}
else:
raise ValueError(f"Unsupported version: {version}")
@@ -651,13 +909,17 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: dict[str, Any], prefi
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
model_state_dict[param_name.removeprefix(prefix)] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
connector_prefixes = ["text_embedding_projection"]
for param_name, param in combined_ckpt.items():
for prefix in connector_prefixes:
if param_name.startswith(prefix):
# Check to make sure we're not overwriting an existing key
if param_name not in model_state_dict:
model_state_dict[param_name] = combined_ckpt[param_name]
return model_state_dict
@@ -686,7 +948,7 @@ def get_args():
"--version",
type=str,
default="2.0",
choices=["test", "2.0"],
choices=["test", "2.0", "2.3"],
help="Version of the LTX 2.0 model",
)
@@ -748,6 +1010,11 @@ def get_args():
action="store_true",
help="Whether to save a latent upsampling pipeline",
)
parser.add_argument(
"--add_processor",
action="store_true",
help="Whether to add a Gemma3Processor to the pipeline for prompt enhancement.",
)
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
@@ -756,6 +1023,12 @@ def get_args():
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument(
"--upsample_output_path",
type=str,
default=None,
help="Path where converted upsampling pipeline should be saved",
)
return parser.parse_args()
@@ -787,7 +1060,7 @@ def main(args):
args.audio_vae,
args.dit,
args.vocoder,
args.text_encoder,
args.connectors,
args.full_pipeline,
args.upsample_pipeline,
]
@@ -852,7 +1125,12 @@ def main(args):
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
if args.add_processor:
processor = Gemma3Processor.from_pretrained(args.text_encoder_model_id)
if not args.full_pipeline:
processor.save_pretrained(os.path.join(args.output_path, "processor"))
if args.latent_upsampler or args.upsample_pipeline:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
@@ -866,14 +1144,26 @@ def main(args):
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
if args.full_pipeline:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
is_distilled_ckpt = "distilled" in args.combined_filename
if is_distilled_ckpt:
# Disable dynamic shifting and terminal shift so that distilled sigmas are used as-is
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=False,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=None,
)
else:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTX2Pipeline(
scheduler=scheduler,
@@ -891,10 +1181,12 @@ def main(args):
if args.upsample_pipeline:
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
pipe.save_pretrained(
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
)
# As two diffusers pipelines cannot be in the same directory, save the upsampling pipeline to its own directory
if args.upsample_output_path:
upsample_output_path = args.upsample_output_path
else:
upsample_output_path = args.output_path
pipe.save_pretrained(upsample_output_path, safe_serialization=True, max_shard_size="5GB")
if __name__ == "__main__":

View File

@@ -12,6 +12,7 @@ from termcolor import colored
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import (
AutoencoderKLLTX2Video,
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
@@ -24,7 +25,10 @@ from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available else nullcontext
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
ckpt_ids = [
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
@@ -92,12 +96,22 @@ def main(args):
if args.video_size == 480:
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
patch_size = (1, 2, 2)
in_channels = 16
out_channels = 16
elif args.video_size == 720:
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
sample_size = 22 # DC-AE-V: 32xp1 downsample factor
patch_size = (1, 1, 1)
in_channels = 32
out_channels = 32
else:
raise ValueError(f"Video size {args.video_size} is not supported.")
if args.vae_type == "ltx2":
sample_size = 22
patch_size = (1, 1, 1)
in_channels = 128
out_channels = 128
for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
@@ -182,8 +196,8 @@ def main(args):
# Transformer
with CTX():
transformer_kwargs = {
"in_channels": 16,
"out_channels": 16,
"in_channels": in_channels,
"out_channels": out_channels,
"num_attention_heads": 20,
"attention_head_dim": 112,
"num_layers": 20,
@@ -235,9 +249,12 @@ def main(args):
else:
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
# VAE
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
if args.vae_type == "ltx2":
vae_path = args.vae_path or "Lightricks/LTX-2"
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
else:
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
# Text Encoder
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
@@ -314,7 +331,23 @@ if __name__ == "__main__":
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument(
"--vae_type",
default="wan",
type=str,
choices=["wan", "ltx2"],
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
)
parser.add_argument(
"--vae_path",
default=None,
type=str,
required=False,
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
)
parser.add_argument(
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")

View File

@@ -2156,6 +2156,9 @@ def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_pref
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
"q_norm": "norm_q",
"k_norm": "norm_k",
# LTX-2.3
"audio_prompt_adaln_single": "audio_prompt_adaln",
"prompt_adaln_single": "prompt_adaln",
}
else:
rename_dict = {"aggregate_embed": "text_proj_in"}

View File

@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
FLASH_HUB = "flash_hub"
FLASH_VARLEN = "flash_varlen"
FLASH_VARLEN_HUB = "flash_varlen_hub"
FLASH_4_HUB = "flash_4_hub"
_FLASH_3 = "_flash_3"
_FLASH_VARLEN_3 = "_flash_varlen_3"
_FLASH_3_HUB = "_flash_3_hub"
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
function_attr="sageattn",
version=1,
),
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
repo_id="kernels-staging/flash-attn4",
function_attr="flash_attn_func",
version=0,
),
}
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
AttentionBackendName._FLASH_3_HUB,
AttentionBackendName._FLASH_3_VARLEN_HUB,
AttentionBackendName.SAGE_HUB,
AttentionBackendName.FLASH_4_HUB,
]:
if not is_kernels_available():
raise RuntimeError(
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
return (out, lse) if return_lse else out
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_4_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=False,
)
def _flash_attention_4_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
scale: float | None = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
)
if isinstance(out, tuple):
return (out[0], out[1]) if return_lse else out[0]
return out
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_VARLEN_3,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -237,7 +237,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
class LTXVideoDownsampler3d(nn.Module):
class LTX2VideoDownsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
@@ -285,10 +285,11 @@ class LTXVideoDownsampler3d(nn.Module):
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
class LTXVideoUpsampler3d(nn.Module):
class LTX2VideoUpsampler3d(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int | None = None,
stride: int | tuple[int, int, int] = 1,
residual: bool = False,
upscale_factor: int = 1,
@@ -300,7 +301,8 @@ class LTXVideoUpsampler3d(nn.Module):
self.residual = residual
self.upscale_factor = upscale_factor
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
out_channels = out_channels or in_channels
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
self.conv = LTX2VideoCausalConv3d(
in_channels=in_channels,
@@ -408,7 +410,7 @@ class LTX2VideoDownBlock3D(nn.Module):
)
elif downsample_type == "spatial":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(1, 2, 2),
@@ -417,7 +419,7 @@ class LTX2VideoDownBlock3D(nn.Module):
)
elif downsample_type == "temporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 1, 1),
@@ -426,7 +428,7 @@ class LTX2VideoDownBlock3D(nn.Module):
)
elif downsample_type == "spatiotemporal":
self.downsamplers.append(
LTXVideoDownsampler3d(
LTX2VideoDownsampler3d(
in_channels=in_channels,
out_channels=out_channels,
stride=(2, 2, 2),
@@ -580,6 +582,7 @@ class LTX2VideoUpBlock3d(nn.Module):
resnet_eps: float = 1e-6,
resnet_act_fn: str = "swish",
spatio_temporal_scale: bool = True,
upsample_type: str = "spatiotemporal",
inject_noise: bool = False,
timestep_conditioning: bool = False,
upsample_residual: bool = False,
@@ -609,16 +612,23 @@ class LTX2VideoUpBlock3d(nn.Module):
self.upsamplers = None
if spatio_temporal_scale:
self.upsamplers = nn.ModuleList(
[
LTXVideoUpsampler3d(
out_channels * upscale_factor,
stride=(2, 2, 2),
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
]
self.upsamplers = nn.ModuleList()
if upsample_type == "spatial":
upsample_stride = (1, 2, 2)
elif upsample_type == "temporal":
upsample_stride = (2, 1, 1)
elif upsample_type == "spatiotemporal":
upsample_stride = (2, 2, 2)
self.upsamplers.append(
LTX2VideoUpsampler3d(
in_channels=out_channels * upscale_factor,
stride=upsample_stride,
residual=upsample_residual,
upscale_factor=upscale_factor,
spatial_padding_mode=spatial_padding_mode,
)
)
resnets = []
@@ -716,7 +726,7 @@ class LTX2VideoEncoder3d(nn.Module):
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
@@ -726,6 +736,9 @@ class LTX2VideoEncoder3d(nn.Module):
spatial_padding_mode: str = "zeros",
):
super().__init__()
num_encoder_blocks = len(layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
self.patch_size = patch_size
self.patch_size_t = patch_size_t
@@ -860,19 +873,27 @@ class LTX2VideoDecoder3d(nn.Module):
in_channels: int = 128,
out_channels: int = 3,
block_out_channels: tuple[int, ...] = (256, 512, 1024),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
patch_size: int = 4,
patch_size_t: int = 1,
resnet_norm_eps: float = 1e-6,
is_causal: bool = False,
inject_noise: tuple[bool, ...] = (False, False, False),
inject_noise: bool | tuple[bool, ...] = (False, False, False),
timestep_conditioning: bool = False,
upsample_residual: tuple[bool, ...] = (True, True, True),
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[bool, ...] = (2, 2, 2),
spatial_padding_mode: str = "reflect",
) -> None:
super().__init__()
num_decoder_blocks = len(layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
if isinstance(inject_noise, bool):
inject_noise = (inject_noise,) * num_decoder_blocks
if isinstance(upsample_residual, bool):
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
self.patch_size = patch_size
self.patch_size_t = patch_size_t
@@ -917,6 +938,7 @@ class LTX2VideoDecoder3d(nn.Module):
num_layers=layers_per_block[i + 1],
resnet_eps=resnet_norm_eps,
spatio_temporal_scale=spatio_temporal_scaling[i],
upsample_type=upsample_type[i],
inject_noise=inject_noise[i + 1],
timestep_conditioning=timestep_conditioning,
upsample_residual=upsample_residual[i],
@@ -1058,11 +1080,12 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024),
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False),
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
upsample_residual: tuple[bool, ...] = (True, True, True),
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
upsample_factor: tuple[int, ...] = (2, 2, 2),
timestep_conditioning: bool = False,
patch_size: int = 4,
@@ -1077,6 +1100,16 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
temporal_compression_ratio: int = None,
) -> None:
super().__init__()
num_encoder_blocks = len(layers_per_block)
num_decoder_blocks = len(decoder_layers_per_block)
if isinstance(spatio_temporal_scaling, bool):
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
if isinstance(decoder_spatio_temporal_scaling, bool):
decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1)
if isinstance(decoder_inject_noise, bool):
decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks
if isinstance(upsample_residual, bool):
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
self.encoder = LTX2VideoEncoder3d(
in_channels=in_channels,
@@ -1098,6 +1131,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
block_out_channels=decoder_block_out_channels,
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
layers_per_block=decoder_layers_per_block,
upsample_type=upsample_type,
patch_size=patch_size,
patch_size_t=patch_size_t,
resnet_norm_eps=resnet_norm_eps,

View File

@@ -13,7 +13,6 @@
# limitations under the License.
import math
from functools import lru_cache
from typing import Any
import torch
@@ -343,7 +342,6 @@ class HeliosRotaryPosEmbed(nn.Module):
return freqs.cos(), freqs.sin()
@torch.no_grad()
@lru_cache(maxsize=32)
def _get_spatial_meshgrid(self, height, width, device_str):
device = torch.device(device_str)
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)

View File

@@ -178,6 +178,10 @@ class LTX2AudioVideoAttnProcessor:
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@@ -212,6 +216,112 @@ class LTX2AudioVideoAttnProcessor:
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LTX2PerturbedAttnProcessor:
r"""
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if is_torch_version("<", "2.0"):
raise ValueError(
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
)
def __call__(
self,
attn: "LTX2Attention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool | None = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.to_gate_logits is not None:
# Calculate gate logits on original hidden_states
gate_logits = attn.to_gate_logits(hidden_states)
value = attn.to_v(encoder_hidden_states)
if all_perturbed is None:
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
if all_perturbed:
# Skip attention, use the value projection value
hidden_states = value
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
if query_rotary_emb is not None:
if attn.rope_type == "interleaved":
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
key = apply_interleaved_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
elif attn.rope_type == "split":
query = apply_split_rotary_emb(query, query_rotary_emb)
key = apply_split_rotary_emb(
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(query.dtype)
if perturbation_mask is not None:
value = value.flatten(2, 3)
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
if attn.to_gate_logits is not None:
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
hidden_states = hidden_states * gates.unsqueeze(-1)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
@@ -224,7 +334,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
"""
_default_processor_cls = LTX2AudioVideoAttnProcessor
_available_processors = [LTX2AudioVideoAttnProcessor]
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
def __init__(
self,
@@ -240,6 +350,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
norm_eps: float = 1e-6,
norm_elementwise_affine: bool = True,
rope_type: str = "interleaved",
apply_gated_attention: bool = False,
processor=None,
):
super().__init__()
@@ -266,6 +377,12 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
self.to_out.append(torch.nn.Dropout(dropout))
if apply_gated_attention:
# Per head gate values
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
else:
self.to_gate_logits = None
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
@@ -321,6 +438,10 @@ class LTX2VideoTransformerBlock(nn.Module):
audio_num_attention_heads: int,
audio_attention_head_dim,
audio_cross_attention_dim: int,
video_gated_attn: bool = False,
video_cross_attn_adaln: bool = False,
audio_gated_attn: bool = False,
audio_cross_attn_adaln: bool = False,
qk_norm: str = "rms_norm_across_heads",
activation_fn: str = "gelu-approximate",
attention_bias: bool = True,
@@ -328,9 +449,16 @@ class LTX2VideoTransformerBlock(nn.Module):
eps: float = 1e-6,
elementwise_affine: bool = False,
rope_type: str = "interleaved",
perturbed_attn: bool = False,
):
super().__init__()
self.perturbed_attn = perturbed_attn
if perturbed_attn:
attn_processor_cls = LTX2PerturbedAttnProcessor
else:
attn_processor_cls = LTX2AudioVideoAttnProcessor
# 1. Self-Attention (video and audio)
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
self.attn1 = LTX2Attention(
@@ -343,6 +471,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -356,6 +486,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 2. Prompt Cross-Attention
@@ -370,6 +502,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
@@ -383,6 +517,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
@@ -398,6 +534,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=video_gated_attn,
processor=attn_processor_cls(),
)
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
@@ -412,6 +550,8 @@ class LTX2VideoTransformerBlock(nn.Module):
out_bias=attention_out_bias,
qk_norm=qk_norm,
rope_type=rope_type,
apply_gated_attention=audio_gated_attn,
processor=attn_processor_cls(),
)
# 4. Feedforward layers
@@ -422,14 +562,36 @@ class LTX2VideoTransformerBlock(nn.Module):
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
# 5. Per-Layer Modulation Parameters
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
# 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q
self.video_cross_attn_adaln = video_cross_attn_adaln
self.audio_cross_attn_adaln = audio_cross_attn_adaln
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
# Prompt cross-attn (attn2) additional modulation params
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
if self.cross_attn_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim))
# Per-layer a2v, v2a Cross-Attention mod params
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
@staticmethod
def get_mod_params(
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
) -> tuple[torch.Tensor, ...]:
num_ada_params = scale_shift_table.shape[0]
ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.shape[1], num_ada_params, -1
)
ada_params = ada_values.unbind(dim=2)
return ada_params
def forward(
self,
hidden_states: torch.Tensor,
@@ -442,143 +604,181 @@ class LTX2VideoTransformerBlock(nn.Module):
temb_ca_audio_scale_shift: torch.Tensor,
temb_ca_gate: torch.Tensor,
temb_ca_audio_gate: torch.Tensor,
temb_prompt: torch.Tensor | None = None,
temb_prompt_audio: torch.Tensor | None = None,
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
self_attention_mask: torch.Tensor | None = None,
audio_self_attention_mask: torch.Tensor | None = None,
a2v_cross_attention_mask: torch.Tensor | None = None,
v2a_cross_attention_mask: torch.Tensor | None = None,
use_a2v_cross_attention: bool = True,
use_v2a_cross_attention: bool = True,
perturbation_mask: torch.Tensor | None = None,
all_perturbed: bool | None = None,
) -> torch.Tensor:
batch_size = hidden_states.size(0)
# 1. Video and Audio Self-Attention
norm_hidden_states = self.norm1(hidden_states)
# 1.1. Video Self-Attention
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
if self.video_cross_attn_adaln:
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
num_ada_params = self.scale_shift_table.shape[0]
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
attn_hidden_states = self.attn1(
hidden_states=norm_hidden_states,
encoder_hidden_states=None,
query_rotary_emb=video_rotary_emb,
)
video_self_attn_args = {
"hidden_states": norm_hidden_states,
"encoder_hidden_states": None,
"query_rotary_emb": video_rotary_emb,
"attention_mask": self_attention_mask,
}
if self.perturbed_attn:
video_self_attn_args["perturbation_mask"] = perturbation_mask
video_self_attn_args["all_perturbed"] = all_perturbed
attn_hidden_states = self.attn1(**video_self_attn_args)
hidden_states = hidden_states + attn_hidden_states * gate_msa
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
batch_size, temb_audio.size(1), num_audio_ada_params, -1
)
# 1.2. Audio Self-Attention
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
audio_ada_values.unbind(dim=2)
audio_ada_params[:6]
)
if self.audio_cross_attn_adaln:
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
attn_audio_hidden_states = self.audio_attn1(
hidden_states=norm_audio_hidden_states,
encoder_hidden_states=None,
query_rotary_emb=audio_rotary_emb,
)
audio_self_attn_args = {
"hidden_states": norm_audio_hidden_states,
"encoder_hidden_states": None,
"query_rotary_emb": audio_rotary_emb,
"attention_mask": audio_self_attention_mask,
}
if self.perturbed_attn:
audio_self_attn_args["perturbation_mask"] = perturbation_mask
audio_self_attn_args["all_perturbed"] = all_perturbed
attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args)
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
# 2. Video and Audio Cross-Attention with the text embeddings
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
if self.cross_attn_adaln:
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
shift_text_kv, scale_text_kv = video_prompt_ada_params
audio_prompt_ada_params = self.get_mod_params(
self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size
)
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text)
norm_hidden_states = self.norm2(hidden_states)
if self.video_cross_attn_adaln:
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
if self.cross_attn_adaln:
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
attn_hidden_states = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
query_rotary_emb=None,
attention_mask=encoder_attention_mask,
)
if self.video_cross_attn_adaln:
attn_hidden_states = attn_hidden_states * gate_text_q
hidden_states = hidden_states + attn_hidden_states
# 2.2. Audio-Text Cross-Attention
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
if self.audio_cross_attn_adaln:
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
if self.cross_attn_adaln:
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
attn_audio_hidden_states = self.audio_attn2(
norm_audio_hidden_states,
encoder_hidden_states=audio_encoder_hidden_states,
query_rotary_emb=None,
attention_mask=audio_encoder_attention_mask,
)
if self.audio_cross_attn_adaln:
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
norm_hidden_states = self.audio_to_video_norm(hidden_states)
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
if use_a2v_cross_attention or use_v2a_cross_attention:
norm_hidden_states = self.audio_to_video_norm(hidden_states)
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
# Combine global and per-layer cross attention modulation parameters
# Video
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
# 3.1. Combine global and per-layer cross attention modulation parameters
# Video
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
video_ca_scale_shift_table = (
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
).unbind(dim=2)
video_ca_gate = (
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
).unbind(dim=2)
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
a2v_gate = video_ca_gate[0].squeeze(2)
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
a2v_gate = video_ca_gate_param[0].squeeze(2)
# Audio
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
# Audio
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
audio_ca_scale_shift_table = (
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
).unbind(dim=2)
audio_ca_gate = (
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
).unbind(dim=2)
audio_ca_ada_params = self.get_mod_params(
audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size
)
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
v2a_gate = audio_ca_gate[0].squeeze(2)
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
v2a_gate = audio_ca_gate_param[0].squeeze(2)
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
2
)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_a2v_ca_scale.squeeze(2)
) + audio_a2v_ca_shift.squeeze(2)
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
if use_a2v_cross_attention:
mod_norm_hidden_states = norm_hidden_states * (
1 + video_a2v_ca_scale.squeeze(2)
) + video_a2v_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_a2v_ca_scale.squeeze(2)
) + audio_a2v_ca_shift.squeeze(2)
a2v_attn_hidden_states = self.audio_to_video_attn(
mod_norm_hidden_states,
encoder_hidden_states=mod_norm_audio_hidden_states,
query_rotary_emb=ca_video_rotary_emb,
key_rotary_emb=ca_audio_rotary_emb,
attention_mask=a2v_cross_attention_mask,
)
a2v_attn_hidden_states = self.audio_to_video_attn(
mod_norm_hidden_states,
encoder_hidden_states=mod_norm_audio_hidden_states,
query_rotary_emb=ca_video_rotary_emb,
key_rotary_emb=ca_audio_rotary_emb,
attention_mask=a2v_cross_attention_mask,
)
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
2
)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_v2a_ca_scale.squeeze(2)
) + audio_v2a_ca_shift.squeeze(2)
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
if use_v2a_cross_attention:
mod_norm_hidden_states = norm_hidden_states * (
1 + video_v2a_ca_scale.squeeze(2)
) + video_v2a_ca_shift.squeeze(2)
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
1 + audio_v2a_ca_scale.squeeze(2)
) + audio_v2a_ca_shift.squeeze(2)
v2a_attn_hidden_states = self.video_to_audio_attn(
mod_norm_audio_hidden_states,
encoder_hidden_states=mod_norm_hidden_states,
query_rotary_emb=ca_audio_rotary_emb,
key_rotary_emb=ca_video_rotary_emb,
attention_mask=v2a_cross_attention_mask,
)
v2a_attn_hidden_states = self.video_to_audio_attn(
mod_norm_audio_hidden_states,
encoder_hidden_states=mod_norm_hidden_states,
query_rotary_emb=ca_audio_rotary_emb,
key_rotary_emb=ca_video_rotary_emb,
attention_mask=v2a_cross_attention_mask,
)
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
# 4. Feedforward
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
@@ -918,6 +1118,8 @@ class LTX2VideoTransformer3DModel(
pos_embed_max_pos: int = 20,
base_height: int = 2048,
base_width: int = 2048,
gated_attn: bool = False,
cross_attn_mod: bool = False,
audio_in_channels: int = 128, # Audio Arguments
audio_out_channels: int | None = 128,
audio_patch_size: int = 1,
@@ -929,6 +1131,8 @@ class LTX2VideoTransformer3DModel(
audio_pos_embed_max_pos: int = 20,
audio_sampling_rate: int = 16000,
audio_hop_length: int = 160,
audio_gated_attn: bool = False,
audio_cross_attn_mod: bool = False,
num_layers: int = 48, # Shared arguments
activation_fn: str = "gelu-approximate",
qk_norm: str = "rms_norm_across_heads",
@@ -943,6 +1147,8 @@ class LTX2VideoTransformer3DModel(
timestep_scale_multiplier: int = 1000,
cross_attn_timestep_scale_multiplier: int = 1000,
rope_type: str = "interleaved",
use_prompt_embeddings=True,
perturbed_attn: bool = False,
) -> None:
super().__init__()
@@ -956,17 +1162,25 @@ class LTX2VideoTransformer3DModel(
self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)
# 2. Prompt embeddings
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=audio_inner_dim
)
if use_prompt_embeddings:
# LTX-2.0; LTX-2.3 uses per-modality feature projections in the connector instead
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=caption_channels, hidden_size=audio_inner_dim
)
# 3. Timestep Modulation Params and Embedding
self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod # used by LTX-2.3
# 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding
# time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters
self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False)
video_time_emb_mod_params = 9 if cross_attn_mod else 6
audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6
self.time_embed = LTX2AdaLayerNormSingle(
inner_dim, num_mod_params=video_time_emb_mod_params, use_additional_conditions=False
)
self.audio_time_embed = LTX2AdaLayerNormSingle(
audio_inner_dim, num_mod_params=6, use_additional_conditions=False
audio_inner_dim, num_mod_params=audio_time_emb_mod_params, use_additional_conditions=False
)
# 3.2. Global Cross Attention Modulation Parameters
@@ -995,6 +1209,13 @@ class LTX2VideoTransformer3DModel(
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5)
# 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3)
if self.prompt_modulation:
self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False)
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
audio_inner_dim, num_mod_params=2, use_additional_conditions=False
)
# 4. Rotary Positional Embeddings (RoPE)
# Self-Attention
self.rope = LTX2AudioVideoRotaryPosEmbed(
@@ -1071,6 +1292,10 @@ class LTX2VideoTransformer3DModel(
audio_num_attention_heads=audio_num_attention_heads,
audio_attention_head_dim=audio_attention_head_dim,
audio_cross_attention_dim=audio_cross_attention_dim,
video_gated_attn=gated_attn,
video_cross_attn_adaln=cross_attn_mod,
audio_gated_attn=audio_gated_attn,
audio_cross_attn_adaln=audio_cross_attn_mod,
qk_norm=qk_norm,
activation_fn=activation_fn,
attention_bias=attention_bias,
@@ -1078,6 +1303,7 @@ class LTX2VideoTransformer3DModel(
eps=norm_eps,
elementwise_affine=norm_elementwise_affine,
rope_type=rope_type,
perturbed_attn=perturbed_attn,
)
for _ in range(num_layers)
]
@@ -1101,6 +1327,8 @@ class LTX2VideoTransformer3DModel(
audio_encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
audio_timestep: torch.LongTensor | None = None,
sigma: torch.Tensor | None = None,
audio_sigma: torch.Tensor | None = None,
encoder_attention_mask: torch.Tensor | None = None,
audio_encoder_attention_mask: torch.Tensor | None = None,
num_frames: int | None = None,
@@ -1110,6 +1338,10 @@ class LTX2VideoTransformer3DModel(
audio_num_frames: int | None = None,
video_coords: torch.Tensor | None = None,
audio_coords: torch.Tensor | None = None,
isolate_modalities: bool = False,
spatio_temporal_guidance_blocks: list[int] | None = None,
perturbation_mask: torch.Tensor | None = None,
use_cross_timestep: bool = False,
attention_kwargs: dict[str, Any] | None = None,
return_dict: bool = True,
) -> torch.Tensor:
@@ -1131,6 +1363,13 @@ class LTX2VideoTransformer3DModel(
audio_timestep (`torch.Tensor`, *optional*):
Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
params. This is only used by certain pipelines such as the I2V pipeline.
sigma (`torch.Tensor`, *optional*):
Input scaled timestep of shape (batch_size,). Used for video prompt cross attention modulation in
models such as LTX-2.3.
audio_sigma (`torch.Tensor`, *optional*):
Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in
models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to
the provided `sigma` value.
encoder_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
@@ -1152,6 +1391,21 @@ class LTX2VideoTransformer3DModel(
audio_coords (`torch.Tensor`, *optional*):
The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
`(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
isolate_modalities (`bool`, *optional*, defaults to `False`):
Whether to isolate each modality by turning off cross-modality (audio-to-video and video-to-audio)
cross attention (for all blocks). Use for modality guidance in LTX-2.3.
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
The transformer block indices at which to apply spatio-temporal guidance (STG), which shortcuts the
self-attention operations by simply using the values rather than the full scaled dot-product attention
(SDPA) operation. If `None` or empty, STG will not be applied to any block.
perturbation_mask (`torch.Tensor`, *optional*):
Perturbation mask for STG of shape `(batch_size,)` or `(batch_size, 1, 1)`. Should be 0 at batch
elements where STG should be applied and 1 elsewhere. If STG is being used but `peturbation_mask` is
not supplied, will default to applying STG (perturbing) all batch elements.
use_cross_timestep (`bool` *optional*, defaults to `False`):
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
`False` is the legacy LTX-2.0 behavior.
attention_kwargs (`dict[str, Any]`, *optional*):
Optional dict of keyword args to be passed to the attention processor.
return_dict (`bool`, *optional*, defaults to `True`):
@@ -1165,6 +1419,7 @@ class LTX2VideoTransformer3DModel(
"""
# Determine timestep for audio.
audio_timestep = audio_timestep if audio_timestep is not None else timestep
audio_sigma = audio_sigma if audio_sigma is not None else sigma
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
@@ -1223,14 +1478,28 @@ class LTX2VideoTransformer3DModel(
temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
if self.prompt_modulation:
# LTX-2.3
temb_prompt, _ = self.prompt_adaln(
sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
temb_prompt_audio, _ = self.audio_prompt_adaln(
audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype
)
temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1))
temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1))
else:
temb_prompt = temb_prompt_audio = None
# 3.2. Prepare global modality cross attention modulation parameters
video_ca_timestep = audio_sigma.flatten() if use_cross_timestep else timestep.flatten()
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
timestep.flatten(),
video_ca_timestep,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
video_ca_timestep * timestep_cross_attn_gate_scale_factor,
batch_size=batch_size,
hidden_dtype=hidden_states.dtype,
)
@@ -1239,13 +1508,14 @@ class LTX2VideoTransformer3DModel(
)
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
audio_ca_timestep = sigma.flatten() if use_cross_timestep else audio_timestep.flatten()
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
audio_timestep.flatten(),
audio_ca_timestep,
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
batch_size=batch_size,
hidden_dtype=audio_hidden_states.dtype,
)
@@ -1254,15 +1524,30 @@ class LTX2VideoTransformer3DModel(
)
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
# 4. Prepare prompt embeddings
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
# 4. Prepare prompt embeddings (LTX-2.0)
if self.config.use_prompt_embeddings:
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
audio_encoder_hidden_states = audio_encoder_hidden_states.view(
batch_size, -1, audio_hidden_states.size(-1)
)
# 5. Run transformer blocks
for block in self.transformer_blocks:
spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or []
if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None:
# If STG is being used and perturbation_mask is not supplied, default to perturbing all batch elements.
perturbation_mask = torch.zeros((batch_size,))
if perturbation_mask is not None and perturbation_mask.ndim == 1:
perturbation_mask = perturbation_mask[:, None, None] # unsqueeze to 3D to broadcast with hidden_states
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
stg_blocks = set(spatio_temporal_guidance_blocks)
for block_idx, block in enumerate(self.transformer_blocks):
block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None
block_all_perturbed = all_perturbed if block_idx in stg_blocks else False
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
block,
@@ -1276,12 +1561,22 @@ class LTX2VideoTransformer3DModel(
audio_cross_attn_scale_shift,
video_cross_attn_a2v_gate,
audio_cross_attn_v2a_gate,
temb_prompt,
temb_prompt_audio,
video_rotary_emb,
audio_rotary_emb,
video_cross_attn_rotary_emb,
audio_cross_attn_rotary_emb,
encoder_attention_mask,
audio_encoder_attention_mask,
None, # self_attention_mask
None, # audio_self_attention_mask
None, # a2v_cross_attention_mask
None, # v2a_cross_attention_mask
not isolate_modalities, # use_a2v_cross_attention
not isolate_modalities, # use_v2a_cross_attention
block_perturbation_mask,
block_all_perturbed,
)
else:
hidden_states, audio_hidden_states = block(
@@ -1295,12 +1590,22 @@ class LTX2VideoTransformer3DModel(
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
temb_ca_gate=video_cross_attn_a2v_gate,
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
temb_prompt=temb_prompt,
temb_prompt_audio=temb_prompt_audio,
video_rotary_emb=video_rotary_emb,
audio_rotary_emb=audio_rotary_emb,
ca_video_rotary_emb=video_cross_attn_rotary_emb,
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
audio_encoder_attention_mask=audio_encoder_attention_mask,
self_attention_mask=None,
audio_self_attention_mask=None,
a2v_cross_attention_mask=None,
v2a_cross_attention_mask=None,
use_a2v_cross_attention=not isolate_modalities,
use_v2a_cross_attention=not isolate_modalities,
perturbation_mask=block_perturbation_mask,
all_perturbed=block_all_perturbed,
)
# 6. Output layers (including unpatchification)

View File

@@ -324,17 +324,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
if hasattr(self.language_model, "_get_initial_cache_position"):
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
for _ in range(max_new_tokens):
# prepare model inputs

View File

@@ -720,6 +720,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
super().__init__(config)
self.model = LDMBertEncoder(config)
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
self.post_init()
def forward(
self,

View File

@@ -28,7 +28,7 @@ else:
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
_import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -44,7 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .pipeline_ltx2_condition import LTX2ConditionPipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
else:
import sys

View File

@@ -1,3 +1,5 @@
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -9,6 +11,79 @@ from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
def per_layer_masked_mean_norm(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: str | torch.device,
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
):
"""
Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states.
Respects the padding of the hidden states.
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True)
norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps)
return norm_text_encoder_hidden_states
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
@@ -106,6 +181,7 @@ class LTX2TransformerBlock1d(nn.Module):
activation_fn: str = "gelu-approximate",
eps: float = 1e-6,
rope_type: str = "interleaved",
apply_gated_attention: bool = False,
):
super().__init__()
@@ -115,8 +191,9 @@ class LTX2TransformerBlock1d(nn.Module):
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
rope_type=rope_type,
apply_gated_attention=apply_gated_attention,
processor=LTX2AudioVideoAttnProcessor(),
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
@@ -160,6 +237,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
rope_type: str = "interleaved",
gated_attention: bool = False,
):
super().__init__()
self.num_attention_heads = num_attention_heads
@@ -188,6 +266,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
rope_type=rope_type,
apply_gated_attention=gated_attention,
)
for _ in range(num_layers)
]
@@ -260,24 +339,36 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
@register_to_config
def __init__(
self,
caption_channels: int,
text_proj_in_factor: int,
video_connector_num_attention_heads: int,
video_connector_attention_head_dim: int,
video_connector_num_layers: int,
video_connector_num_learnable_registers: int | None,
audio_connector_num_attention_heads: int,
audio_connector_attention_head_dim: int,
audio_connector_num_layers: int,
audio_connector_num_learnable_registers: int | None,
connector_rope_base_seq_len: int,
rope_theta: float,
rope_double_precision: bool,
causal_temporal_positioning: bool,
caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size
text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B
video_connector_num_attention_heads: int = 30,
video_connector_attention_head_dim: int = 128,
video_connector_num_layers: int = 2,
video_connector_num_learnable_registers: int | None = 128,
video_gated_attn: bool = False,
audio_connector_num_attention_heads: int = 30,
audio_connector_attention_head_dim: int = 128,
audio_connector_num_layers: int = 2,
audio_connector_num_learnable_registers: int | None = 128,
audio_gated_attn: bool = False,
connector_rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
causal_temporal_positioning: bool = False,
rope_type: str = "interleaved",
per_modality_projections: bool = False,
video_hidden_dim: int = 4096,
audio_hidden_dim: int = 2048,
proj_bias: bool = False,
):
super().__init__()
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
text_encoder_dim = caption_channels * text_proj_in_factor
if per_modality_projections:
self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias)
self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias)
else:
self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias)
self.video_connector = LTX2ConnectorTransformer1d(
num_attention_heads=video_connector_num_attention_heads,
attention_head_dim=video_connector_attention_head_dim,
@@ -288,6 +379,7 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
gated_attention=video_gated_attn,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
@@ -299,26 +391,86 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
gated_attention=audio_gated_attn,
)
def forward(
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
):
# Convert to additive attention mask, if necessary
if not additive_mask:
text_dtype = text_encoder_hidden_states.dtype
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
self,
text_encoder_hidden_states: torch.Tensor,
attention_mask: torch.Tensor,
padding_side: str = "left",
scale_factor: int = 8,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text
embeddings for the LTX-2.X DiT models.
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
Args:
text_encoder_hidden_states (`torch.Tensor`)):
Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len,
caption_channels, text_proj_in_factor) or 3D with the last two dimensions flattened.
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked
positions.
padding_side (`str`, *optional*, defaults to `"left"`):
The padding side used by the text encoder's text encoder (either `"left"` or `"right"`). Defaults to
`"left"` as this is what the default Gemma3-12B text encoder uses. Only used if
`per_modality_projections` is `False` (LTX-2.0 models).
scale_factor (`int`, *optional*, defaults to `8`):
Scale factor for masked mean/range normalization. Only used if `per_modality_projections` is `False`
(LTX-2.0 models).
"""
if text_encoder_hidden_states.ndim == 3:
# Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor]
text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1))
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
if self.config.per_modality_projections:
# LTX-2.3
norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states)
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * attn_mask
new_attn_mask = attn_mask.squeeze(-1)
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3)
bool_mask = attention_mask.bool().unsqueeze(-1)
norm_text_encoder_hidden_states = torch.where(
bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states)
)
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
# Rescale norms with respect to video and audio dims for feature extractors
video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels)
video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels)
audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor
return video_text_embedding, audio_text_embedding, new_attn_mask
# Per-Modality Feature extractors
video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
else:
# LTX-2.0
sequence_lengths = attention_mask.sum(dim=-1)
norm_text_encoder_hidden_states = per_layer_masked_mean_norm(
text_hidden_states=text_encoder_hidden_states,
sequence_lengths=sequence_lengths,
device=text_encoder_hidden_states.device,
padding_side=padding_side,
scale_factor=scale_factor,
)
text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states)
video_text_emb_proj = text_emb_proj
audio_text_emb_proj = text_emb_proj
# Convert to additive attention mask for connectors
text_dtype = video_text_emb_proj.dtype
attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype)
attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
add_attn_mask = attention_mask * torch.finfo(text_dtype).max
video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask)
# Convert video attn mask to binary (multiplicative) mask and mask video text embedding
binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64)
binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * binary_attn_mask
audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask)
return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1)

View File

@@ -195,7 +195,8 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
rational_spatial_scale: float | None = 2.0,
rational_spatial_scale: float = 2.0,
use_rational_resampler: bool = True,
):
super().__init__()
@@ -220,7 +221,7 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_spatial_scale is not None:
if use_rational_resampler:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
else:
self.upsampler = torch.nn.Sequential(

View File

@@ -18,7 +18,7 @@ from typing import Any, Callable
import numpy as np
import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
@@ -31,7 +31,7 @@ from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from .vocoder import LTX2Vocoder
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
if is_torch_xla_available():
@@ -209,7 +209,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
"""
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
_optional_components = []
_optional_components = ["processor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
@@ -221,7 +221,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
connectors: LTX2TextConnectors,
transformer: LTX2VideoTransformer3DModel,
vocoder: LTX2Vocoder,
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
processor: Gemma3Processor | None = None,
):
super().__init__()
@@ -234,6 +235,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
transformer=transformer,
vocoder=vocoder,
scheduler=scheduler,
processor=processor,
)
self.vae_spatial_compression_ratio = (
@@ -268,73 +270,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
)
@staticmethod
def _pack_text_embeds(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: str | torch.device,
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
per-layer in a masked fashion (only over non-padded positions).
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
def _get_gemma_prompt_embeds(
self,
prompt: str | list[str],
@@ -387,16 +322,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = prompt_attention_mask.sum(dim=-1)
prompt_embeds = self._pack_text_embeds(
text_encoder_hidden_states,
sequence_lengths,
device=device,
padding_side=self.tokenizer.padding_side,
scale_factor=scale_factor,
)
prompt_embeds = prompt_embeds.to(dtype=dtype)
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
@@ -494,6 +420,50 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
@torch.no_grad()
def enhance_prompt(
self,
prompt: str,
system_prompt: str,
max_new_tokens: int = 512,
seed: int = 10,
generator: torch.Generator | None = None,
generation_kwargs: dict[str, Any] | None = None,
device: str | torch.device | None = None,
):
"""
Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a
`transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt.
"""
device = device or self._execution_device
if generation_kwargs is None:
# Set to default generation kwargs
generation_kwargs = {"do_sample": True, "temperature": 0.7}
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": f"user prompt: {prompt}"},
]
template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.processor(text=template, images=None, return_tensors="pt").to(device)
self.text_encoder.to(device)
# `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
# so manually apply a seed for reproducible generation.
if generator is not None:
# Overwrite seed to generator's initial seed
seed = generator.initial_seed()
torch.manual_seed(seed)
generated_sequences = self.text_encoder.generate(
**model_inputs,
max_new_tokens=max_new_tokens,
**generation_kwargs,
) # tensor of shape [batch_size, seq_len]
generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return enhanced_prompt
def check_inputs(
self,
prompt,
@@ -504,6 +474,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
spatio_temporal_guidance_blocks=None,
stg_scale=None,
audio_stg_scale=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
@@ -547,6 +520,12 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
f" {negative_prompt_attention_mask.shape}."
)
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
raise ValueError(
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
)
@staticmethod
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
# Unpacked latents of shape are [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
@@ -734,7 +713,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents = self._create_noised_state(latents, noise_scale, generator)
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
@@ -749,6 +727,24 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
latents = self._pack_audio_latents(latents)
return latents
def convert_velocity_to_x0(
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
) -> torch.Tensor:
if scheduler is None:
scheduler = self.scheduler
sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx]
return sample_x0
def convert_x0_to_velocity(
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
) -> torch.Tensor:
if scheduler is None:
scheduler = self.scheduler
sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx]
return sample_v
@property
def guidance_scale(self):
return self._guidance_scale
@@ -757,9 +753,41 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
def guidance_rescale(self):
return self._guidance_rescale
@property
def stg_scale(self):
return self._stg_scale
@property
def modality_scale(self):
return self._modality_scale
@property
def audio_guidance_scale(self):
return self._audio_guidance_scale
@property
def audio_guidance_rescale(self):
return self._audio_guidance_rescale
@property
def audio_stg_scale(self):
return self._audio_stg_scale
@property
def audio_modality_scale(self):
return self._audio_modality_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
@property
def do_spatio_temporal_guidance(self):
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
@property
def do_modality_isolation_guidance(self):
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
@property
def num_timesteps(self):
@@ -791,7 +819,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
sigmas: list[float] | None = None,
timesteps: list[int] = None,
guidance_scale: float = 4.0,
stg_scale: float = 0.0,
modality_scale: float = 1.0,
guidance_rescale: float = 0.0,
audio_guidance_scale: float | None = None,
audio_stg_scale: float | None = None,
audio_modality_scale: float | None = None,
audio_guidance_rescale: float | None = None,
spatio_temporal_guidance_blocks: list[int] | None = None,
noise_scale: float = 0.0,
num_videos_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
@@ -803,6 +838,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
negative_prompt_attention_mask: torch.Tensor | None = None,
decode_timestep: float | list[float] = 0.0,
decode_noise_scale: float | list[float] | None = None,
use_cross_timestep: bool = False,
system_prompt: str | None = None,
prompt_max_new_tokens: int = 512,
prompt_enhancement_kwargs: dict[str, Any] | None = None,
prompt_enhancement_seed: int = 10,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
@@ -841,13 +881,47 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
a separate value `audio_guidance_scale` for the audio modality).
stg_scale (`float`, *optional*, defaults to `0.0`):
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
where we move the sample away from a weak sample from a perturbed version of the denoising model.
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
means that STG is disabled.
modality_scale (`float`, *optional*, defaults to `1.0`):
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio)
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
disabled.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
using zero terminal SNR. Used for the video modality.
audio_guidance_scale (`float`, *optional* defaults to `None`):
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
`guidance_scale`.
audio_stg_scale (`float`, *optional*, defaults to `None`):
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
video value `stg_scale`.
audio_modality_scale (`float`, *optional*, defaults to `None`):
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
audio. If `None`, defaults to the video value `modality_scale`.
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
`guidance_rescale`.
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
and `[28]` is recommended for LTX-2.3.
noise_scale (`float`, *optional*, defaults to `0.0`):
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
the `latents` and `audio_latents` before continue denoising.
@@ -878,6 +952,24 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
use_cross_timestep (`bool` *optional*, defaults to `False`):
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
`False` is the legacy LTX-2.0 behavior.
system_prompt (`str`, *optional*, defaults to `None`):
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
performed.
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
The maximum number of new tokens to generate when performing prompt enhancement.
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
`do_sample=True` and `temperature=0.7` will be used. See
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
for more details.
prompt_enhancement_seed (`int`, *optional*, default to `10`):
Random seed for any random operations during prompt enhancement.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -910,6 +1002,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
audio_guidance_scale = audio_guidance_scale or guidance_scale
audio_stg_scale = audio_stg_scale or stg_scale
audio_modality_scale = audio_modality_scale or modality_scale
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
@@ -920,10 +1017,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
stg_scale=stg_scale,
audio_stg_scale=audio_stg_scale,
)
# Per-modality guidance scales (video, audio)
self._guidance_scale = guidance_scale
self._stg_scale = stg_scale
self._modality_scale = modality_scale
self._guidance_rescale = guidance_rescale
self._audio_guidance_scale = audio_guidance_scale
self._audio_stg_scale = audio_stg_scale
self._audio_modality_scale = audio_modality_scale
self._audio_guidance_rescale = audio_guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -939,6 +1047,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
device = self._execution_device
# 3. Prepare text embeddings
if system_prompt is not None and prompt is not None:
prompt = self.enhance_prompt(
prompt=prompt,
system_prompt=system_prompt,
max_new_tokens=prompt_max_new_tokens,
seed=prompt_enhancement_seed,
generator=generator,
generation_kwargs=prompt_enhancement_kwargs,
device=device,
)
(
prompt_embeds,
prompt_attention_mask,
@@ -960,9 +1079,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
if getattr(self, "tokenizer", None) is not None:
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
prompt_embeds, additive_attention_mask, additive_mask=True
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
)
# 4. Prepare latent variables
@@ -984,7 +1105,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
# video_sequence_length = latent_num_frames * latent_height * latent_width
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
@@ -1041,7 +1162,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_image_seq_len", 1024),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.95),
@@ -1069,11 +1190,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
rope_interpolation_scale = (
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
video_coords = self.transformer.rope.prepare_video_coords(
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
@@ -1111,6 +1227,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
encoder_hidden_states=connector_prompt_embeds,
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
@@ -1120,7 +1237,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
audio_num_frames=audio_num_frames,
video_coords=video_coords,
audio_coords=audio_coords,
# rope_interpolation_scale=rope_interpolation_scale,
isolate_modalities=False,
spatio_temporal_guidance_blocks=None,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
@@ -1128,24 +1248,155 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
noise_pred_audio = noise_pred_audio.float()
if self.do_classifier_free_guidance:
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
noise_pred_video_text - noise_pred_video_uncond
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
noise_pred_video_uncond_text = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_text, i, self.scheduler
)
# Use delta formulation as it works more nicely with multiple guidance terms
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
noise_pred_audio_uncond_text = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler
)
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
noise_pred_audio - noise_pred_audio_uncond_text
)
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
noise_pred_audio_text - noise_pred_audio_uncond
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
if i == 0:
# Only split values that remain constant throughout the loop once
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
video_pos_ids = video_coords.chunk(2, dim=0)[0]
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
# Split values that vary each denoising loop iteration
timestep = timestep.chunk(2, dim=0)[0]
else:
video_cfg_delta = audio_cfg_delta = 0
video_prompt_embeds = connector_prompt_embeds
audio_prompt_embeds = connector_audio_prompt_embeds
prompt_attn_mask = connector_attention_mask
video_pos_ids = video_coords
audio_pos_ids = audio_coords
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
if self.do_spatio_temporal_guidance:
with self.transformer.cache_context("uncond_stg"):
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
hidden_states=latents.to(dtype=prompt_embeds.dtype),
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
encoder_hidden_states=video_prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=prompt_attn_mask,
audio_encoder_attention_mask=prompt_attn_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_pos_ids,
audio_coords=audio_pos_ids,
isolate_modalities=False,
# Use STG at given blocks to perturb model
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
noise_pred_video_uncond_stg = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_stg, i, self.scheduler
)
noise_pred_audio_uncond_stg = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler
)
if self.guidance_rescale > 0:
# Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred_video = rescale_noise_cfg(
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
)
noise_pred_audio = rescale_noise_cfg(
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
else:
video_stg_delta = audio_stg_delta = 0
if self.do_modality_isolation_guidance:
with self.transformer.cache_context("uncond_modality"):
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
hidden_states=latents.to(dtype=prompt_embeds.dtype),
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
encoder_hidden_states=video_prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=prompt_attn_mask,
audio_encoder_attention_mask=prompt_attn_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_pos_ids,
audio_coords=audio_pos_ids,
# Turn off A2V and V2A cross attn to isolate video and audio modalities
isolate_modalities=True,
spatio_temporal_guidance_blocks=None,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
noise_pred_video_uncond_modality = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_modality, i, self.scheduler
)
noise_pred_audio_uncond_modality = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler
)
video_modality_delta = (self.modality_scale - 1) * (
noise_pred_video - noise_pred_video_uncond_modality
)
audio_modality_delta = (self.audio_modality_scale - 1) * (
noise_pred_audio - noise_pred_audio_uncond_modality
)
else:
video_modality_delta = audio_modality_delta = 0
# Now apply all guidance terms
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
# Apply LTX-2.X guidance rescaling
if self.guidance_rescale > 0:
noise_pred_video = rescale_noise_cfg(
noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale
)
else:
noise_pred_video = noise_pred_video_g
if self.audio_guidance_rescale > 0:
noise_pred_audio = rescale_noise_cfg(
noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale
)
else:
noise_pred_audio = noise_pred_audio_g
# Convert back to velocity for scheduler
noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler)
noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
@@ -1177,9 +1428,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
@@ -1187,6 +1435,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
if output_type == "latent":
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
audio = audio_latents
else:
@@ -1209,6 +1460,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.vae.dtype)
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)

View File

@@ -33,7 +33,7 @@ from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from .vocoder import LTX2Vocoder
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
if is_torch_xla_available():
@@ -254,7 +254,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
connectors: LTX2TextConnectors,
transformer: LTX2VideoTransformer3DModel,
vocoder: LTX2Vocoder,
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
):
super().__init__()
@@ -300,74 +300,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
)
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
def _pack_text_embeds(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: str | torch.device,
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
per-layer in a masked fashion (only over non-padded positions).
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
@@ -421,16 +353,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = prompt_attention_mask.sum(dim=-1)
prompt_embeds = self._pack_text_embeds(
text_encoder_hidden_states,
sequence_lengths,
device=device,
padding_side=self.tokenizer.padding_side,
scale_factor=scale_factor,
)
prompt_embeds = prompt_embeds.to(dtype=dtype)
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
@@ -541,6 +464,9 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
negative_prompt_attention_mask=None,
latents=None,
audio_latents=None,
spatio_temporal_guidance_blocks=None,
stg_scale=None,
audio_stg_scale=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
@@ -597,6 +523,12 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
f" using the `_unpack_audio_latents` method)."
)
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
raise ValueError(
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
)
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
@@ -984,6 +916,24 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
latents = self._pack_audio_latents(latents)
return latents
def convert_velocity_to_x0(
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
) -> torch.Tensor:
if scheduler is None:
scheduler = self.scheduler
sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx]
return sample_x0
def convert_x0_to_velocity(
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
) -> torch.Tensor:
if scheduler is None:
scheduler = self.scheduler
sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx]
return sample_v
@property
def guidance_scale(self):
return self._guidance_scale
@@ -992,9 +942,41 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
def guidance_rescale(self):
return self._guidance_rescale
@property
def stg_scale(self):
return self._stg_scale
@property
def modality_scale(self):
return self._modality_scale
@property
def audio_guidance_scale(self):
return self._audio_guidance_scale
@property
def audio_guidance_rescale(self):
return self._audio_guidance_rescale
@property
def audio_stg_scale(self):
return self._audio_stg_scale
@property
def audio_modality_scale(self):
return self._audio_modality_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
@property
def do_spatio_temporal_guidance(self):
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
@property
def do_modality_isolation_guidance(self):
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
@property
def num_timesteps(self):
@@ -1027,7 +1009,14 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
sigmas: list[float] | None = None,
timesteps: list[float] | None = None,
guidance_scale: float = 4.0,
stg_scale: float = 0.0,
modality_scale: float = 1.0,
guidance_rescale: float = 0.0,
audio_guidance_scale: float | None = None,
audio_stg_scale: float | None = None,
audio_modality_scale: float | None = None,
audio_guidance_rescale: float | None = None,
spatio_temporal_guidance_blocks: list[int] | None = None,
noise_scale: float | None = None,
num_videos_per_prompt: int | None = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
@@ -1039,6 +1028,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
negative_prompt_attention_mask: torch.Tensor | None = None,
decode_timestep: float | list[float] = 0.0,
decode_noise_scale: float | list[float] | None = None,
use_cross_timestep: bool = False,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
@@ -1079,13 +1069,47 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
a separate value `audio_guidance_scale` for the audio modality).
stg_scale (`float`, *optional*, defaults to `0.0`):
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
where we move the sample away from a weak sample from a perturbed version of the denoising model.
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
means that STG is disabled.
modality_scale (`float`, *optional*, defaults to `1.0`):
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio)
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
disabled.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
using zero terminal SNR. Used for the video modality.
audio_guidance_scale (`float`, *optional* defaults to `None`):
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
`guidance_scale`.
audio_stg_scale (`float`, *optional*, defaults to `None`):
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
video value `stg_scale`.
audio_modality_scale (`float`, *optional*, defaults to `None`):
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
audio. If `None`, defaults to the video value `modality_scale`.
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
`guidance_rescale`.
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
and `[28]` is recommended for LTX-2.3.
noise_scale (`float`, *optional*, defaults to `None`):
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
the `latents` and `audio_latents` before continue denoising. If not set, will be inferred from the
@@ -1117,6 +1141,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
use_cross_timestep (`bool` *optional*, defaults to `False`):
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
`False` is the legacy LTX-2.0 behavior.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1149,6 +1177,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
audio_guidance_scale = audio_guidance_scale or guidance_scale
audio_stg_scale = audio_stg_scale or stg_scale
audio_modality_scale = audio_modality_scale or modality_scale
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
@@ -1161,10 +1194,21 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
negative_prompt_attention_mask=negative_prompt_attention_mask,
latents=latents,
audio_latents=audio_latents,
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
stg_scale=stg_scale,
audio_stg_scale=audio_stg_scale,
)
# Per-modality guidance scales (video, audio)
self._guidance_scale = guidance_scale
self._stg_scale = stg_scale
self._modality_scale = modality_scale
self._guidance_rescale = guidance_rescale
self._audio_guidance_scale = audio_guidance_scale
self._audio_stg_scale = audio_stg_scale
self._audio_modality_scale = audio_modality_scale
self._audio_guidance_rescale = audio_guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -1208,9 +1252,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
if getattr(self, "tokenizer", None) is not None:
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
prompt_embeds, additive_attention_mask, additive_mask=True
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
)
# 4. Prepare latent variables
@@ -1222,7 +1268,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
"Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
)
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
video_sequence_length = latent_num_frames * latent_height * latent_width
# video_sequence_length = latent_num_frames * latent_height * latent_width
num_channels_latents = self.transformer.config.in_channels
latents, conditioning_mask, clean_latents = self.prepare_latents(
@@ -1272,7 +1318,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_image_seq_len", 1024),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.95),
@@ -1301,11 +1347,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
rope_interpolation_scale = (
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
video_coords = self.transformer.rope.prepare_video_coords(
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
@@ -1344,6 +1385,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
@@ -1353,7 +1395,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
audio_num_frames=audio_num_frames,
video_coords=video_coords,
audio_coords=audio_coords,
# rope_interpolation_scale=rope_interpolation_scale,
isolate_modalities=False,
spatio_temporal_guidance_blocks=None,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
@@ -1361,41 +1406,172 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
noise_pred_audio = noise_pred_audio.float()
if self.do_classifier_free_guidance:
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
noise_pred_video_text - noise_pred_video_uncond
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
noise_pred_video_uncond_text = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_text, i, self.scheduler
)
# Use delta formulation as it works more nicely with multiple guidance terms
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
noise_pred_audio_uncond_text = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler
)
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
noise_pred_audio - noise_pred_audio_uncond_text
)
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
noise_pred_audio_text - noise_pred_audio_uncond
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
if i == 0:
# Only split values that remain constant throughout the loop once
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
video_pos_ids = video_coords.chunk(2, dim=0)[0]
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
# Split values that vary each denoising loop iteration
timestep = timestep.chunk(2, dim=0)[0]
video_timestep = video_timestep.chunk(2, dim=0)[0]
else:
video_cfg_delta = audio_cfg_delta = 0
video_prompt_embeds = connector_prompt_embeds
audio_prompt_embeds = connector_audio_prompt_embeds
prompt_attn_mask = connector_attention_mask
video_pos_ids = video_coords
audio_pos_ids = audio_coords
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
if self.do_spatio_temporal_guidance:
with self.transformer.cache_context("uncond_stg"):
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
hidden_states=latents.to(dtype=prompt_embeds.dtype),
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
encoder_hidden_states=video_prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=prompt_attn_mask,
audio_encoder_attention_mask=prompt_attn_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_pos_ids,
audio_coords=audio_pos_ids,
isolate_modalities=False,
# Use STG at given blocks to perturb model
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
noise_pred_video_uncond_stg = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_stg, i, self.scheduler
)
noise_pred_audio_uncond_stg = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler
)
if self.guidance_rescale > 0:
# Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred_video = rescale_noise_cfg(
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
)
noise_pred_audio = rescale_noise_cfg(
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
else:
video_stg_delta = audio_stg_delta = 0
if self.do_modality_isolation_guidance:
with self.transformer.cache_context("uncond_modality"):
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
hidden_states=latents.to(dtype=prompt_embeds.dtype),
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
encoder_hidden_states=video_prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=prompt_attn_mask,
audio_encoder_attention_mask=prompt_attn_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_pos_ids,
audio_coords=audio_pos_ids,
# Turn off A2V and V2A cross attn to isolate video and audio modalities
isolate_modalities=True,
spatio_temporal_guidance_blocks=None,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
noise_pred_video_uncond_modality = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_modality, i, self.scheduler
)
noise_pred_audio_uncond_modality = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler
)
video_modality_delta = (self.modality_scale - 1) * (
noise_pred_video - noise_pred_video_uncond_modality
)
audio_modality_delta = (self.audio_modality_scale - 1) * (
noise_pred_audio - noise_pred_audio_uncond_modality
)
else:
video_modality_delta = audio_modality_delta = 0
# Now apply all guidance terms
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
# Apply LTX-2.X guidance rescaling
if self.guidance_rescale > 0:
noise_pred_video = rescale_noise_cfg(
noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale
)
else:
noise_pred_video = noise_pred_video_g
if self.audio_guidance_rescale > 0:
noise_pred_audio = rescale_noise_cfg(
noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale
)
else:
noise_pred_audio = noise_pred_audio_g
# NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG
bsz = noise_pred_video.size(0)
sigma = self.scheduler.sigmas[i]
# Convert the noise_pred_video velocity model prediction into a sample (x0) prediction
denoised_sample = latents - noise_pred_video * sigma
# Apply the (packed) conditioning mask to the denoised (x0) sample and clean conditioning. The
# conditioning mask contains conditioning strengths from 0 (always use denoised sample) to 1 (always
# use conditions), with intermediate values specifying how strongly to follow the conditions.
# NOTE: this operation should be applied in sample (x0) space and not velocity space (which is the
# space the denoising model outputs are in)
denoised_sample_cond = (
denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz]
noise_pred_video * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz]
).to(noise_pred_video.dtype)
# Convert the denoised (x0) sample back to a velocity for the scheduler
denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype)
noise_pred_video = self.convert_x0_to_velocity(latents, denoised_sample_cond, i, self.scheduler)
noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler)
# Compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0]
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
# NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in
# the step method (such as _step_index)
@@ -1425,9 +1601,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
@@ -1435,6 +1608,9 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
if output_type == "latent":
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
audio = audio_latents
else:
@@ -1457,6 +1633,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.vae.dtype)
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)

View File

@@ -18,7 +18,7 @@ from typing import Any, Callable
import numpy as np
import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
@@ -32,7 +32,7 @@ from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from .vocoder import LTX2Vocoder
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
if is_torch_xla_available():
@@ -212,7 +212,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
"""
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
_optional_components = []
_optional_components = ["processor"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
def __init__(
@@ -224,7 +224,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
connectors: LTX2TextConnectors,
transformer: LTX2VideoTransformer3DModel,
vocoder: LTX2Vocoder,
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
processor: Gemma3Processor | None = None,
):
super().__init__()
@@ -237,6 +238,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
transformer=transformer,
vocoder=vocoder,
scheduler=scheduler,
processor=processor,
)
self.vae_spatial_compression_ratio = (
@@ -271,74 +273,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
)
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
def _pack_text_embeds(
text_hidden_states: torch.Tensor,
sequence_lengths: torch.Tensor,
device: str | torch.device,
padding_side: str = "left",
scale_factor: int = 8,
eps: float = 1e-6,
) -> torch.Tensor:
"""
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
per-layer in a masked fashion (only over non-padded positions).
Args:
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
sequence_lengths (`torch.Tensor of shape `(batch_size,)`):
The number of valid (non-padded) tokens for each batch instance.
device: (`str` or `torch.device`, *optional*):
torch device to place the resulting embeddings on
padding_side: (`str`, *optional*, defaults to `"left"`):
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
scale_factor (`int`, *optional*, defaults to `8`):
Scaling factor to multiply the normalized hidden states by.
eps (`float`, *optional*, defaults to `1e-6`):
A small positive value for numerical stability when performing normalization.
Returns:
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
Normed and flattened text encoder hidden states.
"""
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
original_dtype = text_hidden_states.dtype
# Create padding mask
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
if padding_side == "right":
# For right padding, valid tokens are from 0 to sequence_length-1
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
elif padding_side == "left":
# For left padding, valid tokens are from (T - sequence_length) to T-1
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
mask = token_indices >= start_indices # [B, T]
else:
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
# Normalization
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
normalized_hidden_states = normalized_hidden_states * scale_factor
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.flatten(2)
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
return normalized_hidden_states
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
@@ -392,16 +326,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
)
text_encoder_hidden_states = text_encoder_outputs.hidden_states
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
sequence_lengths = prompt_attention_mask.sum(dim=-1)
prompt_embeds = self._pack_text_embeds(
text_encoder_hidden_states,
sequence_lengths,
device=device,
padding_side=self.tokenizer.padding_side,
scale_factor=scale_factor,
)
prompt_embeds = prompt_embeds.to(dtype=dtype)
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
# duplicate text embeddings for each generation per prompt, using mps friendly method
_, seq_len, _ = prompt_embeds.shape
@@ -500,6 +425,57 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
@torch.no_grad()
def enhance_prompt(
self,
image: PipelineImageInput,
prompt: str,
system_prompt: str,
max_new_tokens: int = 512,
seed: int = 10,
generator: torch.Generator | None = None,
generation_kwargs: dict[str, Any] | None = None,
device: str | torch.device | None = None,
):
"""
Enhances the supplied `prompt` by generating a new prompt using the current text encoder (default is a
`transformers.Gemma3ForConditionalGeneration` model) from it and a system prompt.
"""
device = device or self._execution_device
if generation_kwargs is None:
# Set to default generation kwargs
generation_kwargs = {"do_sample": True, "temperature": 0.7}
messages = [
{"role": "system", "content": system_prompt},
{
"role": "user",
"content": [
{"type": "image"},
{"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
],
},
]
template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
model_inputs = self.processor(text=template, images=image, return_tensors="pt").to(device)
self.text_encoder.to(device)
# `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
# so manually apply a seed for reproducible generation.
if generator is not None:
# Overwrite seed to generator's initial seed
seed = generator.initial_seed()
torch.manual_seed(seed)
generated_sequences = self.text_encoder.generate(
**model_inputs,
max_new_tokens=max_new_tokens,
**generation_kwargs,
) # tensor of shape [batch_size, seq_len]
generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
return enhanced_prompt
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
def check_inputs(
self,
@@ -511,6 +487,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
spatio_temporal_guidance_blocks=None,
stg_scale=None,
audio_stg_scale=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
@@ -554,6 +533,12 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
f" {negative_prompt_attention_mask.shape}."
)
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
raise ValueError(
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
)
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
@@ -788,7 +773,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
latents = self._create_noised_state(latents, noise_scale, generator)
return latents.to(device=device, dtype=dtype)
# TODO: confirm whether this logic is correct
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
@@ -803,6 +787,24 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
latents = self._pack_audio_latents(latents)
return latents
def convert_velocity_to_x0(
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
) -> torch.Tensor:
if scheduler is None:
scheduler = self.scheduler
sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx]
return sample_x0
def convert_x0_to_velocity(
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
) -> torch.Tensor:
if scheduler is None:
scheduler = self.scheduler
sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx]
return sample_v
@property
def guidance_scale(self):
return self._guidance_scale
@@ -811,9 +813,41 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
def guidance_rescale(self):
return self._guidance_rescale
@property
def stg_scale(self):
return self._stg_scale
@property
def modality_scale(self):
return self._modality_scale
@property
def audio_guidance_scale(self):
return self._audio_guidance_scale
@property
def audio_guidance_rescale(self):
return self._audio_guidance_rescale
@property
def audio_stg_scale(self):
return self._audio_stg_scale
@property
def audio_modality_scale(self):
return self._audio_modality_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
@property
def do_spatio_temporal_guidance(self):
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
@property
def do_modality_isolation_guidance(self):
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
@property
def num_timesteps(self):
@@ -846,7 +880,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
sigmas: list[float] | None = None,
timesteps: list[int] | None = None,
guidance_scale: float = 4.0,
stg_scale: float = 0.0,
modality_scale: float = 1.0,
guidance_rescale: float = 0.0,
audio_guidance_scale: float | None = None,
audio_stg_scale: float | None = None,
audio_modality_scale: float | None = None,
audio_guidance_rescale: float | None = None,
spatio_temporal_guidance_blocks: list[int] | None = None,
noise_scale: float = 0.0,
num_videos_per_prompt: int = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
@@ -858,6 +899,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
negative_prompt_attention_mask: torch.Tensor | None = None,
decode_timestep: float | list[float] = 0.0,
decode_noise_scale: float | list[float] | None = None,
use_cross_timestep: bool = False,
system_prompt: str | None = None,
prompt_max_new_tokens: int = 512,
prompt_enhancement_kwargs: dict[str, Any] | None = None,
prompt_enhancement_seed: int = 10,
output_type: str = "pil",
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
@@ -898,13 +944,47 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
a separate value `audio_guidance_scale` for the audio modality).
stg_scale (`float`, *optional*, defaults to `0.0`):
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
where we move the sample away from a weak sample from a perturbed version of the denoising model.
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
means that STG is disabled.
modality_scale (`float`, *optional*, defaults to `1.0`):
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
weaker sample generated by the denoising model withy cross-modality (audio-to-video and video-to-audio)
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
disabled.
guidance_rescale (`float`, *optional*, defaults to 0.0):
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
[Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
using zero terminal SNR.
using zero terminal SNR. Used for the video modality.
audio_guidance_scale (`float`, *optional* defaults to `None`):
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
`guidance_scale`.
audio_stg_scale (`float`, *optional*, defaults to `None`):
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
video value `stg_scale`.
audio_modality_scale (`float`, *optional*, defaults to `None`):
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
audio. If `None`, defaults to the video value `modality_scale`.
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
`guidance_rescale`.
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
and `[28]` is recommended for LTX-2.3.
noise_scale (`float`, *optional*, defaults to `0.0`):
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
the `latents` and `audio_latents` before continue denoising.
@@ -935,6 +1015,24 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
use_cross_timestep (`bool` *optional*, defaults to `False`):
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
`False` is the legacy LTX-2.0 behavior.
system_prompt (`str`, *optional*, defaults to `None`):
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
performed.
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
The maximum number of new tokens to generate when performing prompt enhancement.
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
`do_sample=True` and `temperature=0.7` will be used. See
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
for more details.
prompt_enhancement_seed (`int`, *optional*, default to `10`):
Random seed for any random operations during prompt enhancement.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -967,6 +1065,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
audio_guidance_scale = audio_guidance_scale or guidance_scale
audio_stg_scale = audio_stg_scale or stg_scale
audio_modality_scale = audio_modality_scale or modality_scale
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt=prompt,
@@ -977,10 +1080,21 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
stg_scale=stg_scale,
audio_stg_scale=audio_stg_scale,
)
# Per-modality guidance scales (video, audio)
self._guidance_scale = guidance_scale
self._stg_scale = stg_scale
self._modality_scale = modality_scale
self._guidance_rescale = guidance_rescale
self._audio_guidance_scale = audio_guidance_scale
self._audio_stg_scale = audio_stg_scale
self._audio_modality_scale = audio_modality_scale
self._audio_guidance_rescale = audio_guidance_rescale
self._attention_kwargs = attention_kwargs
self._interrupt = False
self._current_timestep = None
@@ -996,6 +1110,18 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
device = self._execution_device
# 3. Prepare text embeddings
if system_prompt is not None and prompt is not None:
prompt = self.enhance_prompt(
image=image,
prompt=prompt,
system_prompt=system_prompt,
max_new_tokens=prompt_max_new_tokens,
seed=prompt_enhancement_seed,
generator=generator,
generation_kwargs=prompt_enhancement_kwargs,
device=device,
)
(
prompt_embeds,
prompt_attention_mask,
@@ -1017,9 +1143,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
if getattr(self, "tokenizer", None) is not None:
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
prompt_embeds, additive_attention_mask, additive_mask=True
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
)
# 4. Prepare latent variables
@@ -1041,7 +1169,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
raise ValueError(
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
)
video_sequence_length = latent_num_frames * latent_height * latent_width
# video_sequence_length = latent_num_frames * latent_height * latent_width
if latents is None:
image = self.video_processor.preprocess(image, height=height, width=width)
@@ -1105,7 +1233,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
# 5. Prepare timesteps
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
mu = calculate_shift(
video_sequence_length,
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_image_seq_len", 1024),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.95),
@@ -1134,11 +1262,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
self._num_timesteps = len(timesteps)
# 6. Prepare micro-conditions
rope_interpolation_scale = (
self.vae_temporal_compression_ratio / frame_rate,
self.vae_spatial_compression_ratio,
self.vae_spatial_compression_ratio,
)
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
video_coords = self.transformer.rope.prepare_video_coords(
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
@@ -1177,6 +1300,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
audio_encoder_hidden_states=connector_audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=connector_attention_mask,
audio_encoder_attention_mask=connector_attention_mask,
num_frames=latent_num_frames,
@@ -1186,7 +1310,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
audio_num_frames=audio_num_frames,
video_coords=video_coords,
audio_coords=audio_coords,
# rope_interpolation_scale=rope_interpolation_scale,
isolate_modalities=False,
spatio_temporal_guidance_blocks=None,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
@@ -1194,24 +1321,154 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
noise_pred_audio = noise_pred_audio.float()
if self.do_classifier_free_guidance:
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
noise_pred_video_text - noise_pred_video_uncond
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
noise_pred_video_uncond_text = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_text, i, self.scheduler
)
# Use delta formulation as it works more nicely with multiple guidance terms
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
noise_pred_audio_uncond_text = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler
)
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
noise_pred_audio - noise_pred_audio_uncond_text
)
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
noise_pred_audio_text - noise_pred_audio_uncond
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
if i == 0:
# Only split values that remain constant throughout the loop once
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
video_pos_ids = video_coords.chunk(2, dim=0)[0]
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
# Split values that vary each denoising loop iteration
timestep = timestep.chunk(2, dim=0)[0]
video_timestep = video_timestep.chunk(2, dim=0)[0]
else:
video_cfg_delta = audio_cfg_delta = 0
video_prompt_embeds = connector_prompt_embeds
audio_prompt_embeds = connector_audio_prompt_embeds
prompt_attn_mask = connector_attention_mask
video_pos_ids = video_coords
audio_pos_ids = audio_coords
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
if self.do_spatio_temporal_guidance:
with self.transformer.cache_context("uncond_stg"):
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
hidden_states=latents.to(dtype=prompt_embeds.dtype),
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
encoder_hidden_states=video_prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=prompt_attn_mask,
audio_encoder_attention_mask=prompt_attn_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_pos_ids,
audio_coords=audio_pos_ids,
isolate_modalities=False,
# Use STG at given blocks to perturb model
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
noise_pred_video_uncond_stg = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_stg, i, self.scheduler
)
noise_pred_audio_uncond_stg = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler
)
if self.guidance_rescale > 0:
# Based on 3.4. in https://huggingface.co/papers/2305.08891
noise_pred_video = rescale_noise_cfg(
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
)
noise_pred_audio = rescale_noise_cfg(
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
else:
video_stg_delta = audio_stg_delta = 0
if self.do_modality_isolation_guidance:
with self.transformer.cache_context("uncond_modality"):
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
hidden_states=latents.to(dtype=prompt_embeds.dtype),
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
encoder_hidden_states=video_prompt_embeds,
audio_encoder_hidden_states=audio_prompt_embeds,
timestep=video_timestep,
audio_timestep=timestep,
sigma=timestep, # Used by LTX-2.3
encoder_attention_mask=prompt_attn_mask,
audio_encoder_attention_mask=prompt_attn_mask,
num_frames=latent_num_frames,
height=latent_height,
width=latent_width,
fps=frame_rate,
audio_num_frames=audio_num_frames,
video_coords=video_pos_ids,
audio_coords=audio_pos_ids,
# Turn off A2V and V2A cross attn to isolate video and audio modalities
isolate_modalities=True,
spatio_temporal_guidance_blocks=None,
perturbation_mask=None,
use_cross_timestep=use_cross_timestep,
attention_kwargs=attention_kwargs,
return_dict=False,
)
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
noise_pred_video_uncond_modality = self.convert_velocity_to_x0(
latents, noise_pred_video_uncond_modality, i, self.scheduler
)
noise_pred_audio_uncond_modality = self.convert_velocity_to_x0(
audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler
)
video_modality_delta = (self.modality_scale - 1) * (
noise_pred_video - noise_pred_video_uncond_modality
)
audio_modality_delta = (self.audio_modality_scale - 1) * (
noise_pred_audio - noise_pred_audio_uncond_modality
)
else:
video_modality_delta = audio_modality_delta = 0
# Now apply all guidance terms
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
# Apply LTX-2.X guidance rescaling
if self.guidance_rescale > 0:
noise_pred_video = rescale_noise_cfg(
noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale
)
else:
noise_pred_video = noise_pred_video_g
if self.audio_guidance_rescale > 0:
noise_pred_audio = rescale_noise_cfg(
noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale
)
else:
noise_pred_audio = noise_pred_audio_g
# compute the previous noisy sample x_t -> x_t-1
noise_pred_video = self._unpack_latents(
@@ -1231,6 +1488,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
self.transformer_temporal_patch_size,
)
# Convert back to velocity for scheduler
noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler)
noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler)
noise_pred_video = noise_pred_video[:, :, 1:]
noise_latents = latents[:, :, 1:]
pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0]
@@ -1268,9 +1529,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
self.transformer_spatial_patch_size,
self.transformer_temporal_patch_size,
)
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
audio_latents = self._denormalize_audio_latents(
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
@@ -1278,6 +1536,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
if output_type == "latent":
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
audio = audio_latents
else:
@@ -1300,6 +1561,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.vae.dtype)
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)

View File

@@ -8,6 +8,209 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
"""
Creates a Kaiser sinc kernel for low-pass filtering.
Args:
cutoff (`float`):
Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist
frequency).
half_width (`float`):
Used to determine the Kaiser window's beta parameter.
kernel_size:
Size of the Kaiser window (and ultimately the Kaiser sinc kernel).
Returns:
`torch.Tensor` of shape `(kernel_size,)`:
The Kaiser sinc kernel.
"""
delta_f = 4 * half_width
half_size = kernel_size // 2
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if amplitude > 50.0:
beta = 0.1102 * (amplitude - 8.7)
elif amplitude >= 21.0:
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
even = kernel_size % 2 == 0
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
if cutoff == 0.0:
filter = torch.zeros_like(time)
else:
time = 2 * cutoff * time
sinc = torch.where(
time == 0,
torch.ones_like(time),
torch.sin(math.pi * time) / math.pi / time,
)
filter = 2 * cutoff * window * sinc
filter = filter / filter.sum()
return filter
class DownSample1d(nn.Module):
"""1D low-pass filter for antialias downsampling."""
def __init__(
self,
ratio: int = 2,
kernel_size: int | None = None,
use_padding: bool = True,
padding_mode: str = "replicate",
persistent: bool = True,
):
super().__init__()
self.ratio = ratio
self.kernel_size = kernel_size or int(6 * ratio // 2) * 2
self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1
self.pad_right = self.kernel_size // 2
self.use_padding = use_padding
self.padding_mode = padding_mode
cutoff = 0.5 / ratio
half_width = 0.6 / ratio
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x expected shape: [batch_size, num_channels, hidden_dim]
num_channels = x.shape[1]
if self.use_padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels)
return x_filtered
class UpSample1d(nn.Module):
def __init__(
self,
ratio: int = 2,
kernel_size: int | None = None,
window_type: str = "kaiser",
padding_mode: str = "replicate",
persistent: bool = True,
):
super().__init__()
self.ratio = ratio
self.padding_mode = padding_mode
if window_type == "hann":
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser sinc filter is BigVGAN default
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2
self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2
sinc_filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
kernel_size=self.kernel_size,
)
self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x expected shape: [batch_size, num_channels, hidden_dim]
num_channels = x.shape[1]
x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode)
low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1)
x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels)
return x[..., self.pad_left : -self.pad_right]
class AntiAliasAct1d(nn.Module):
"""
Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then downsamples
to avoid aliasing.
"""
def __init__(
self,
act_fn: str | nn.Module,
ratio: int = 2,
kernel_size: int = 12,
**kwargs,
):
super().__init__()
self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size)
if isinstance(act_fn, str):
if act_fn == "snakebeta":
act_fn = SnakeBeta(**kwargs)
elif act_fn == "snake":
act_fn = SnakeBeta(**kwargs)
else:
act_fn = nn.LeakyReLU(**kwargs)
self.act = act_fn
self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
class SnakeBeta(nn.Module):
"""
Implements the Snake and SnakeBeta activations, which help with learning periodic patterns.
"""
def __init__(
self,
channels: int,
alpha: float = 1.0,
eps: float = 1e-9,
trainable_params: bool = True,
logscale: bool = True,
use_beta: bool = True,
):
super().__init__()
self.eps = eps
self.logscale = logscale
self.use_beta = use_beta
self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
self.alpha.requires_grad = trainable_params
if use_beta:
self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
self.beta.requires_grad = trainable_params
def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
broadcast_shape = [1] * hidden_states.ndim
broadcast_shape[channel_dim] = -1
alpha = self.alpha.view(broadcast_shape)
if self.use_beta:
beta = self.beta.view(broadcast_shape)
if self.logscale:
alpha = torch.exp(alpha)
if self.use_beta:
beta = torch.exp(beta)
amplitude = beta if self.use_beta else alpha
hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2)
return hidden_states
class ResBlock(nn.Module):
def __init__(
self,
@@ -15,12 +218,15 @@ class ResBlock(nn.Module):
kernel_size: int = 3,
stride: int = 1,
dilations: tuple[int, ...] = (1, 3, 5),
act_fn: str = "leaky_relu",
leaky_relu_negative_slope: float = 0.1,
antialias: bool = False,
antialias_ratio: int = 2,
antialias_kernel_size: int = 12,
padding_mode: str = "same",
):
super().__init__()
self.dilations = dilations
self.negative_slope = leaky_relu_negative_slope
self.convs1 = nn.ModuleList(
[
@@ -28,6 +234,18 @@ class ResBlock(nn.Module):
for dilation in dilations
]
)
self.acts1 = nn.ModuleList()
for _ in range(len(self.convs1)):
if act_fn == "snakebeta":
act = SnakeBeta(channels, use_beta=True)
elif act_fn == "snake":
act = SnakeBeta(channels, use_beta=False)
else:
act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
if antialias:
act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
self.acts1.append(act)
self.convs2 = nn.ModuleList(
[
@@ -35,12 +253,24 @@ class ResBlock(nn.Module):
for _ in range(len(dilations))
]
)
self.acts2 = nn.ModuleList()
for _ in range(len(self.convs2)):
if act_fn == "snakebeta":
act = SnakeBeta(channels, use_beta=True)
elif act_fn == "snake":
act = SnakeBeta(channels, use_beta=False)
else:
act_fn = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
if antialias:
act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
self.acts2.append(act)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2):
xt = act1(x)
xt = conv1(xt)
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
xt = act2(xt)
xt = conv2(xt)
x = x + xt
return x
@@ -61,7 +291,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
upsample_factors: list[int] = [6, 5, 2, 2, 2],
resnet_kernel_sizes: list[int] = [3, 7, 11],
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
act_fn: str = "leaky_relu",
leaky_relu_negative_slope: float = 0.1,
antialias: bool = False,
antialias_ratio: int = 2,
antialias_kernel_size: int = 12,
final_act_fn: str | None = "tanh", # tanh, clamp, None
final_bias: bool = True,
output_sampling_rate: int = 24000,
):
super().__init__()
@@ -69,7 +305,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
self.resnets_per_upsample = len(resnet_kernel_sizes)
self.out_channels = out_channels
self.total_upsample_factor = math.prod(upsample_factors)
self.act_fn = act_fn
self.negative_slope = leaky_relu_negative_slope
self.final_act_fn = final_act_fn
if self.num_upsample_layers != len(upsample_factors):
raise ValueError(
@@ -83,6 +321,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
)
supported_act_fns = ["snakebeta", "snake", "leaky_relu"]
if self.act_fn not in supported_act_fns:
raise ValueError(
f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are "
f"{supported_act_fns}."
)
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
self.upsamplers = nn.ModuleList()
@@ -103,15 +348,27 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
self.resnets.append(
ResBlock(
output_channels,
kernel_size,
channels=output_channels,
kernel_size=kernel_size,
dilations=dilations,
act_fn=act_fn,
leaky_relu_negative_slope=leaky_relu_negative_slope,
antialias=antialias,
antialias_ratio=antialias_ratio,
antialias_kernel_size=antialias_kernel_size,
)
)
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
if act_fn == "snakebeta" or act_fn == "snake":
# Always use antialiasing
act_out = SnakeBeta(channels=output_channels, use_beta=True)
self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
elif act_fn == "leaky_relu":
# NOTE: does NOT use self.negative_slope, following the original code
self.act_out = nn.LeakyReLU()
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
@@ -139,7 +396,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
hidden_states = self.conv_in(hidden_states)
for i in range(self.num_upsample_layers):
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
if self.act_fn == "leaky_relu":
# Other activations are inside each upsampling block
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
hidden_states = self.upsamplers[i](hidden_states)
# Run all resnets in parallel on hidden_states
@@ -149,10 +408,190 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
hidden_states = torch.mean(resnet_outputs, dim=0)
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
hidden_states = self.act_out(hidden_states)
hidden_states = self.conv_out(hidden_states)
hidden_states = torch.tanh(hidden_states)
if self.final_act_fn == "tanh":
hidden_states = torch.tanh(hidden_states)
elif self.final_act_fn == "clamp":
hidden_states = torch.clamp(hidden_states, -1, 1)
return hidden_states
class CausalSTFT(nn.Module):
"""
Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases
multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact
buffers should be loaded from the checkpoint in bfloat16.
"""
def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512):
super().__init__()
self.hop_length = hop_length
self.window_length = window_length
n_freqs = filter_length // 2 + 1
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)
def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if waveform.ndim == 2:
waveform = waveform.unsqueeze(1) # [B, num_channels, num_samples]
left_pad = max(0, self.window_length - self.hop_length) # causal: left-only
waveform = F.pad(waveform, (left_pad, 0))
spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real**2 + imag**2)
phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype)
return magnitude, phase
class MelSTFT(nn.Module):
"""
Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be
loaded from the checkpoint in bfloat16.
"""
def __init__(
self,
filter_length: int = 512,
hop_length: int = 80,
window_length: int = 512,
num_mel_channels: int = 64,
):
super().__init__()
self.stft_fn = CausalSTFT(filter_length, hop_length, window_length)
num_freqs = filter_length // 2 + 1
self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True)
def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
magnitude, phase = self.stft_fn(waveform)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy
class LTX2VocoderWithBWE(ModelMixin, ConfigMixin):
"""
LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the
BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same
architecture as the original vocoder.
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
hidden_channels: int = 1536,
out_channels: int = 2,
upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4],
upsample_factors: list[int] = [5, 2, 2, 2, 2, 2],
resnet_kernel_sizes: list[int] = [3, 7, 11],
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
act_fn: str = "snakebeta",
leaky_relu_negative_slope: float = 0.1,
antialias: bool = True,
antialias_ratio: int = 2,
antialias_kernel_size: int = 12,
final_act_fn: str | None = None,
final_bias: bool = False,
bwe_in_channels: int = 128,
bwe_hidden_channels: int = 512,
bwe_out_channels: int = 2,
bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4],
bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2],
bwe_resnet_kernel_sizes: list[int] = [3, 7, 11],
bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
bwe_act_fn: str = "snakebeta",
bwe_leaky_relu_negative_slope: float = 0.1,
bwe_antialias: bool = True,
bwe_antialias_ratio: int = 2,
bwe_antialias_kernel_size: int = 12,
bwe_final_act_fn: str | None = None,
bwe_final_bias: bool = False,
filter_length: int = 512,
hop_length: int = 80,
window_length: int = 512,
num_mel_channels: int = 64,
input_sampling_rate: int = 16000,
output_sampling_rate: int = 48000,
):
super().__init__()
self.vocoder = LTX2Vocoder(
in_channels=in_channels,
hidden_channels=hidden_channels,
out_channels=out_channels,
upsample_kernel_sizes=upsample_kernel_sizes,
upsample_factors=upsample_factors,
resnet_kernel_sizes=resnet_kernel_sizes,
resnet_dilations=resnet_dilations,
act_fn=act_fn,
leaky_relu_negative_slope=leaky_relu_negative_slope,
antialias=antialias,
antialias_ratio=antialias_ratio,
antialias_kernel_size=antialias_kernel_size,
final_act_fn=final_act_fn,
final_bias=final_bias,
output_sampling_rate=input_sampling_rate,
)
self.bwe_generator = LTX2Vocoder(
in_channels=bwe_in_channels,
hidden_channels=bwe_hidden_channels,
out_channels=bwe_out_channels,
upsample_kernel_sizes=bwe_upsample_kernel_sizes,
upsample_factors=bwe_upsample_factors,
resnet_kernel_sizes=bwe_resnet_kernel_sizes,
resnet_dilations=bwe_resnet_dilations,
act_fn=bwe_act_fn,
leaky_relu_negative_slope=bwe_leaky_relu_negative_slope,
antialias=bwe_antialias,
antialias_ratio=bwe_antialias_ratio,
antialias_kernel_size=bwe_antialias_kernel_size,
final_act_fn=bwe_final_act_fn,
final_bias=bwe_final_bias,
output_sampling_rate=output_sampling_rate,
)
self.mel_stft = MelSTFT(
filter_length=filter_length,
hop_length=hop_length,
window_length=window_length,
num_mel_channels=num_mel_channels,
)
self.resampler = UpSample1d(
ratio=output_sampling_rate // input_sampling_rate,
window_type="hann",
persistent=False,
)
def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
# 1. Run stage 1 vocoder to get low sampling rate waveform
x = self.vocoder(mel_spec)
batch_size, num_channels, num_samples = x.shape
# Pad to exact multiple of hop_length for exact mel frame count
remainder = num_samples % self.config.hop_length
if remainder != 0:
x = F.pad(x, (0, self.hop_length - remainder))
# 2. Compute mel spectrogram on vocoder output
mel, _, _, _ = self.mel_stft(x.flatten(0, 1))
mel = mel.unflatten(0, (-1, num_channels))
# 3. Run bandwidth extender (BWE) on new mel spectrogram
mel_for_bwe = mel.transpose(2, 3) # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins]
residual = self.bwe_generator(mel_for_bwe)
# 4. Residual connection with resampler
skip = self.resampler(x)
waveform = torch.clamp(residual + skip, -1, 1)
output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate
waveform = waveform[..., :output_samples]
return waveform

View File

@@ -35,6 +35,8 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
# uncondition for scaling
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
self.post_init()
def forward(self, pixel_values, return_uncond_vector=False):
clip_output = self.model(pixel_values=pixel_values)
latent_states = clip_output.pooler_output

View File

@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import DPMSolverMultistepScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: DPMSolverMultistepScheduler,
):
@@ -223,8 +223,19 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -985,14 +996,21 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import SanaLoraLoaderMixin
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import (
BACKENDS_MAPPING,
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
The tokenizer used to tokenize the prompt.
text_encoder ([`Gemma2PreTrainedModel`]):
Text encoder model to encode the input prompts.
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
transformer ([`SanaVideoTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
self,
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
text_encoder: Gemma2PreTrainedModel,
vae: AutoencoderDC | AutoencoderKLWan,
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
transformer: SanaVideoTransformer3DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
):
@@ -213,8 +213,19 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
)
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
if getattr(self, "vae", None):
if isinstance(self.vae, AutoencoderKLLTX2Video):
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
else:
self.vae_scale_factor_temporal = 4
self.vae_scale_factor_spatial = 8
self.vae_scale_factor = self.vae_scale_factor_spatial
@@ -687,14 +698,18 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, -1, 1, 1, 1)
.to(image_latents.device, image_latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
image_latents.device, image_latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
_latents_mean = self.vae.latents_mean
_latents_std = self.vae.latents_std
elif isinstance(self.vae, AutoencoderKLWan):
_latents_mean = torch.tensor(self.vae.config.latents_mean)
_latents_std = torch.tensor(self.vae.config.latents_std)
else:
_latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
_latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
image_latents = (image_latents - latents_mean) * latents_std
latents[:, :, 0:1] = image_latents.to(dtype)
@@ -1034,14 +1049,21 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if is_torch_version(">=", "2.5.0")
else torch_accelerator_module.OutOfMemoryError
)
latents_mean = (
torch.tensor(self.vae.config.latents_mean)
.view(1, self.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
latents.device, latents.dtype
)
if isinstance(self.vae, AutoencoderKLLTX2Video):
latents_mean = self.vae.latents_mean
latents_std = self.vae.latents_std
z_dim = self.vae.config.latent_channels
elif isinstance(self.vae, AutoencoderKLWan):
latents_mean = torch.tensor(self.vae.config.latents_mean)
latents_std = torch.tensor(self.vae.config.latents_std)
z_dim = self.vae.config.z_dim
else:
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
z_dim = latents.shape[1]
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
try:
video = self.vae.decode(latents, return_dict=False)[0]

View File

@@ -28,7 +28,6 @@ from diffusers.utils.import_utils import is_peft_available
from ..testing_utils import (
floats_tensor,
is_flaky,
require_peft_backend,
require_peft_version_greater,
skip_mps,
@@ -46,7 +45,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
@@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"base_dim": 3,
"z_dim": 4,
"dim_mult": [1, 1, 1, 1],
"latents_mean": torch.randn(4).numpy().tolist(),
"latents_std": torch.randn(4).numpy().tolist(),
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
"num_res_blocks": 1,
"temperal_downsample": [False, True, True],
}

View File

@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))
# Override to set a more lenient max diff threshold.
@unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case")
def test_save_load_float16(self):
super().test_save_load_float16(expected_max_diff=0.03)
pass
@unittest.skip("Test not supported")
def test_attention_slicing_forward_pass(self):

View File

@@ -171,6 +171,7 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
"processor": None,
}
return components

View File

@@ -171,6 +171,7 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
"tokenizer": tokenizer,
"connectors": connectors,
"vocoder": vocoder,
"processor": None,
}
return components