mirror of https://github.com/huggingface/diffusers.git (synced 2026-03-30 12:27:50 +08:00)

Compare commits: 18 Commits, make-tiny-...ltx23-pari

| SHA1 |
|---|
| 7a215eca60 |
| c3c9555db8 |
| 5dde9fc179 |
| b9761ce5a2 |
| 52558b45d8 |
| c02c17c6ee |
| a9855c4204 |
| 0b35834351 |
| 522b523e40 |
| e9b9f25f67 |
| 32b4cfc81c |
| a13e5cf9fc |
| 072d15ee42 |
| 67613369bb |
| 0c01a4b5e2 |
| 8e4b5607ed |
| c6f72ad2f6 |
| 11a3284cee |
@@ -24,54 +24,10 @@ Strive to write code as simple and explicit as possible.
|
||||
|
||||
### Models
|
||||
- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
|
||||
- Avoid introducing graph breaks as much as possible for better compatibility with `torch.compile`. For example, DO NOT arbitrarily insert NumPy operations in the forward implementations.
|
||||
- Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
- Avoid graph breaks for `torch.compile` compatibility — do not insert NumPy operations in forward implementations or use any other patterns that break `torch.compile` with `fullgraph=True`.
|
||||
- See the **model-integration** skill for the attention pattern, pipeline rules, test setup instructions, and other important details.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
## Skills
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
### Pipeline
|
||||
- All pipelines must inherit from `DiffusionPipeline`. Consult implementations in `src/diffusers/pipelines` in case you need references.
|
||||
- DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) that will be part of the core codebase (`src`).
|
||||
|
||||
|
||||
### Tests
|
||||
- Slow tests gated with `@slow` and `RUN_SLOW=1`
|
||||
- All model-level tests must initially be written using the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes. Any additional tests should be added only after discussion with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
|
||||
Task-specific guides live in `.ai/skills/` and are loaded on demand by AI agents.
|
||||
Available skills: **model-integration** (adding/converting pipelines), **parity-testing** (debugging numerical parity).
|
||||
|
||||
167
.ai/skills/model-integration/SKILL.md
Normal file
@@ -0,0 +1,167 @@
|
||||
---
|
||||
name: integrating-models
|
||||
description: >
|
||||
Use when adding a new model or pipeline to diffusers, setting up file
|
||||
structure for a new model, converting a pipeline to modular format, or
|
||||
converting weights for a new version of an already-supported model.
|
||||
---
|
||||
|
||||
## Goal
|
||||
|
||||
Integrate a new model into diffusers end-to-end. The overall flow:
|
||||
|
||||
1. **Gather info** — ask the user for the reference repo, setup guide, a runnable inference script, and other objectives such as standard vs modular.
|
||||
2. **Confirm the plan** — once you have everything, tell the user exactly what you'll do: e.g. "I'll integrate model X with pipeline Y into diffusers based on your script. I'll run parity tests (model-level and pipeline-level) using the `parity-testing` skill to verify numerical correctness against the reference."
|
||||
3. **Implement** — write the diffusers code (model, pipeline, scheduler if needed), convert weights, register in `__init__.py`.
|
||||
4. **Parity test** — use the `parity-testing` skill to verify component and e2e parity against the reference implementation.
|
||||
5. **Deliver a unit test** — provide a self-contained test script that runs the diffusers implementation, checks numerical output with `np.allclose`, and saves an image/video for visual verification (see the sketch below). This is what the user runs to confirm everything works.
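A minimal sketch of such a deliverable, assuming a hypothetical checkpoint name and a reference slice saved earlier from the reference run — adapt names, slice, and tolerance to the actual model:

```python
import numpy as np
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint path; replace with the converted model you produced.
pipe = DiffusionPipeline.from_pretrained("your-org/your-model", torch_dtype=torch.float32)

generator = torch.Generator().manual_seed(42)
image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=4,
    generator=generator,
).images[0]

# Numerical check against a slice captured from the reference implementation.
expected_slice = np.load("reference_slice.npy")  # saved from the reference run
actual_slice = np.asarray(image, dtype=np.float32)[:3, :3].flatten() / 255.0
assert np.allclose(actual_slice, expected_slice, atol=1e-3), "output drifted from reference"

# Visual check.
image.save("parity_check.png")
```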
|
||||
|
||||
Work one workflow at a time — get it to full parity before moving on.
|
||||
|
||||
## Setup — gather before starting
|
||||
|
||||
Before writing any code, gather info in this order:
|
||||
|
||||
1. **Reference repo** — ask for the github link. If they've already set it up locally, ask for the path. Otherwise, ask what setup steps are needed (install deps, download checkpoints, set env vars, etc.) and run through them before proceeding.
|
||||
2. **Inference script** — ask for a runnable end-to-end script for a basic workflow first (e.g. T2V). Then ask what other workflows they want to support (I2V, V2V, etc.) and agree on the full implementation order together.
|
||||
3. **Standard vs modular** — standard pipelines, modular, or both?
|
||||
|
||||
Use `AskUserQuestion` with structured choices for step 3 when the options are known.
|
||||
|
||||
## Standard Pipeline Integration
|
||||
|
||||
### File structure for a new model
|
||||
|
||||
```
|
||||
src/diffusers/
|
||||
models/transformers/transformer_<model>.py # The core model
|
||||
schedulers/scheduling_<model>.py # If model needs a custom scheduler
|
||||
pipelines/<model>/
|
||||
__init__.py
|
||||
pipeline_<model>.py # Main pipeline
|
||||
pipeline_<model>_<variant>.py # Variant pipelines (e.g. pyramid, distilled)
|
||||
pipeline_output.py # Output dataclass
|
||||
loaders/lora_pipeline.py # LoRA mixin (add to existing file)
|
||||
|
||||
tests/
|
||||
models/transformers/test_models_transformer_<model>.py
|
||||
pipelines/<model>/test_<model>.py
|
||||
lora/test_lora_layers_<model>.py
|
||||
|
||||
docs/source/en/api/
|
||||
pipelines/<model>.md
|
||||
models/<model>_transformer3d.md # or appropriate name
|
||||
```
|
||||
|
||||
### Integration checklist
|
||||
|
||||
- [ ] Implement transformer model with `from_pretrained` support
|
||||
- [ ] Implement or reuse scheduler
|
||||
- [ ] Implement pipeline(s) with `__call__` method
|
||||
- [ ] Add LoRA support if applicable
|
||||
- [ ] Register all classes in `__init__.py` files (lazy imports)
|
||||
- [ ] Write unit tests (model, pipeline, LoRA)
|
||||
- [ ] Write docs
|
||||
- [ ] Run `make style` and `make quality`
|
||||
- [ ] Test parity with reference implementation (see `parity-testing` skill)
|
||||
|
||||
### Attention pattern
|
||||
|
||||
Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.
|
||||
|
||||
```python
|
||||
# transformer_mymodel.py
|
||||
|
||||
class MyModelAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __call__(self, attn, hidden_states, attention_mask=None, ...):
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
# reshape, apply rope, etc.
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query, key, value,
|
||||
attn_mask=attention_mask,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
return attn.to_out[0](hidden_states)
|
||||
|
||||
|
||||
class MyModelAttention(nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = MyModelAttnProcessor
|
||||
_available_processors = [MyModelAttnProcessor]
|
||||
|
||||
def __init__(self, query_dim, heads=8, dim_head=64, ...):
|
||||
super().__init__()
|
||||
self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
|
||||
self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
|
||||
self.set_processor(MyModelAttnProcessor())
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, **kwargs):
|
||||
return self.processor(self, hidden_states, attention_mask, **kwargs)
|
||||
```
|
||||
|
||||
Consult the implementations in `src/diffusers/models/transformers/` if you need further references.
|
||||
|
||||
### Implementation rules
|
||||
|
||||
1. **Don't combine structural changes with behavioral changes.** Restructuring code to fit diffusers APIs (ModelMixin, ConfigMixin, etc.) is unavoidable. But don't also "improve" the algorithm, refactor computation order, or rename internal variables for aesthetics. Keep numerical logic as close to the reference as possible, even if it looks unclean. For standard → modular, this is stricter: copy loop logic verbatim and only restructure into blocks. Clean up in a separate commit after parity is confirmed.
|
||||
2. **Pipelines must inherit from `DiffusionPipeline`.** Consult implementations in `src/diffusers/pipelines` in case you need references.
|
||||
3. **Don't subclass an existing pipeline for a variant.** DO NOT use an existing pipeline class (e.g., `FluxPipeline`) to override another pipeline (e.g., `FluxImg2ImgPipeline`) that will be part of the core codebase (`src`).
|
||||
|
||||
### Test setup
|
||||
|
||||
- Slow tests gated with `@slow` and `RUN_SLOW=1`
|
||||
- All model-level tests must initially be written using the `BaseModelTesterConfig`, `ModelTesterMixin`, `MemoryTesterMixin`, `AttentionTesterMixin`, `LoraTesterMixin`, and `TrainingTesterMixin` classes. Any additional tests should be added only after discussion with the maintainers. Use `tests/models/transformers/test_models_transformer_flux.py` as a reference.
|
||||
|
||||
### Common diffusers conventions
|
||||
|
||||
- Pipelines inherit from `DiffusionPipeline`
|
||||
- Models use `ModelMixin` with `register_to_config` for config serialization
|
||||
- Schedulers use `SchedulerMixin` with `ConfigMixin`
|
||||
- Use `@torch.no_grad()` on pipeline `__call__`
|
||||
- Support `output_type="latent"` for skipping VAE decode (see the sketch after this list)
|
||||
- Support `generator` parameter for reproducibility
|
||||
- Use `self.progress_bar(timesteps)` for progress tracking
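A sketch of how the tail of a conventional `__call__` typically handles the last few items above; component and output-class names are illustrative, not a specific pipeline's API:

```python
# Assumes self.vae is a standard AutoencoderKL and self.image_processor a VaeImageProcessor.
if output_type == "latent":
    image = latents
else:
    latents = latents / self.vae.config.scaling_factor
    image = self.vae.decode(latents, return_dict=False)[0]
    image = self.image_processor.postprocess(image, output_type=output_type)

self.maybe_free_model_hooks()

if not return_dict:
    return (image,)
return MyModelPipelineOutput(images=image)  # hypothetical output dataclass
```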
|
||||
|
||||
## Gotchas
|
||||
|
||||
1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes `ImportError` that only shows up when users try `from diffusers import YourNewClass`.
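The registration looks roughly like the sketch below (class and module names are hypothetical — mirror the entries already present in the `__init__.py` you are editing):

```python
# src/diffusers/models/__init__.py (sketch)
_import_structure["transformers.transformer_mymodel"] = ["MyModelTransformer3DModel"]

if TYPE_CHECKING:
    from .transformers.transformer_mymodel import MyModelTransformer3DModel

# Repeat the registration in the top-level src/diffusers/__init__.py and in
# src/diffusers/pipelines/__init__.py so that
# `from diffusers import MyModelTransformer3DModel` resolves lazily.
```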
|
||||
|
||||
2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.
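For example, the common head split/merge rearranges can be written with native ops (a generic sketch, not tied to any particular model):

```python
import torch

def split_heads(x: torch.Tensor, num_heads: int) -> torch.Tensor:
    # einops equivalent: rearrange(x, "b s (h d) -> b h s d", h=num_heads)
    return x.unflatten(-1, (num_heads, -1)).permute(0, 2, 1, 3)

def merge_heads(x: torch.Tensor) -> torch.Tensor:
    # einops equivalent: rearrange(x, "b h s d -> b s (h d)")
    return x.permute(0, 2, 1, 3).flatten(2, 3)

x = torch.randn(2, 16, 8 * 64)
assert torch.equal(merge_heads(split_heads(x, num_heads=8)), x)
```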
|
||||
|
||||
3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.
|
||||
|
||||
4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy from a similar model and verify -- wrong values cause silent correctness bugs or OOM errors.
|
||||
|
||||
5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Forgetting this causes GPU OOM from gradient accumulation during inference.
|
||||
|
||||
6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.
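A sketch of the pattern, using a hypothetical `MyModelTransformer`:

```python
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin

class MyModelTransformer(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, hidden_size: int = 512, num_layers: int = 4, new_param: float = 1.0):
        super().__init__()
        # Every __init__ kwarg above is captured in self.config by @register_to_config.
        # If new_param were introduced without flowing through __init__ and the decorator,
        # from_pretrained would silently fall back to its default value.
        self.blocks = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_layers))
        self.scale = new_param
```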
|
||||
|
||||
7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both -- missing either one causes partial import failures.
|
||||
|
||||
8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
|
||||
|
||||
---
|
||||
|
||||
## Modular Pipeline Conversion
|
||||
|
||||
See [modular-conversion.md](modular-conversion.md) for the full guide on converting standard pipelines to modular format, including block types, build order, guider abstraction, and conversion checklist.
|
||||
|
||||
---
|
||||
|
||||
## Weight Conversion Tips
|
||||
|
||||
<!-- TODO: Add concrete examples as we encounter them. Common patterns to watch for:
|
||||
- Fused QKV weights that need splitting into separate Q, K, V
|
||||
- Scale/shift ordering differences (reference stores [shift, scale], diffusers expects [scale, shift])
|
||||
- Weight transpositions (linear stored as transposed conv, or vice versa)
|
||||
- Interleaved head dimensions that need reshaping
|
||||
- Bias terms absorbed into different layers
|
||||
Add each with a before/after code snippet showing the conversion. -->
|
||||
152
.ai/skills/model-integration/modular-conversion.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# Modular Pipeline Conversion Reference
|
||||
|
||||
## When to use
|
||||
|
||||
Modular pipelines break a monolithic `__call__` into composable blocks. Convert when:
|
||||
- The model supports multiple workflows (T2V, I2V, V2V, etc.)
|
||||
- Users need to swap guidance strategies (CFG, CFG-Zero*, PAG)
|
||||
- You want to share blocks across pipeline variants
|
||||
|
||||
## File structure
|
||||
|
||||
```
|
||||
src/diffusers/modular_pipelines/<model>/
|
||||
__init__.py # Lazy imports
|
||||
modular_pipeline.py # Pipeline class (tiny, mostly config)
|
||||
encoders.py # Text encoder + image/video VAE encoder blocks
|
||||
before_denoise.py # Pre-denoise setup blocks
|
||||
denoise.py # The denoising loop blocks
|
||||
decoders.py # VAE decode block
|
||||
modular_blocks_<model>.py # Block assembly (AutoBlocks)
|
||||
```
|
||||
|
||||
## Block types decision tree
|
||||
|
||||
```
|
||||
Is this a single operation?
|
||||
YES -> ModularPipelineBlocks (leaf block)
|
||||
|
||||
Does it run multiple blocks in sequence?
|
||||
YES -> SequentialPipelineBlocks
|
||||
Does it iterate (e.g. chunk loop)?
|
||||
YES -> LoopSequentialPipelineBlocks
|
||||
|
||||
Does it choose ONE block based on which input is present?
|
||||
Is the selection 1:1 with trigger inputs?
|
||||
YES -> AutoPipelineBlocks (simple trigger mapping)
|
||||
NO -> ConditionalPipelineBlocks (custom select_block method)
|
||||
```
|
||||
|
||||
## Build order (easiest first)
|
||||
|
||||
1. `decoders.py` -- Takes latents, runs VAE decode, returns images/videos
|
||||
2. `encoders.py` -- Takes prompt, returns prompt_embeds. Add image/video VAE encoder if needed
|
||||
3. `before_denoise.py` -- Timesteps, latent prep, noise setup. Each logical operation = one block
|
||||
4. `denoise.py` -- The hardest. Convert guidance to guider abstraction
|
||||
|
||||
## Key pattern: Guider abstraction
|
||||
|
||||
Original pipeline has guidance baked in:
|
||||
```python
|
||||
for i, t in enumerate(timesteps):
|
||||
noise_pred = self.transformer(latents, prompt_embeds, ...)
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(latents, negative_prompt_embeds, ...)
|
||||
noise_pred = noise_uncond + scale * (noise_pred - noise_uncond)
|
||||
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
||||
```
|
||||
|
||||
Modular pipeline separates concerns:
|
||||
```python
|
||||
guider_inputs = {
|
||||
"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds),
|
||||
}
|
||||
|
||||
for i, t in enumerate(timesteps):
|
||||
components.guider.set_state(step=i, num_inference_steps=num_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
for batch in guider_state:
|
||||
components.guider.prepare_models(components.transformer)
|
||||
cond_kwargs = {k: getattr(batch, k) for k in guider_inputs}
|
||||
context_name = getattr(batch, components.guider._identifier_key)
|
||||
with components.transformer.cache_context(context_name):
|
||||
batch.noise_pred = components.transformer(
|
||||
hidden_states=latents, timestep=timestep,
|
||||
return_dict=False, **cond_kwargs, **shared_kwargs,
|
||||
)[0]
|
||||
components.guider.cleanup_models(components.transformer)
|
||||
|
||||
noise_pred = components.guider(guider_state)[0]
|
||||
latents = components.scheduler.step(noise_pred, t, latents, generator=generator)[0]
|
||||
```
|
||||
|
||||
## Key pattern: Chunk loops for video models
|
||||
|
||||
Use `LoopSequentialPipelineBlocks` for outer loop:
|
||||
```python
|
||||
class ChunkDenoiseStep(LoopSequentialPipelineBlocks):
|
||||
block_classes = [PrepareChunkStep, NoiseGenStep, DenoiseInnerStep, UpdateStep]
|
||||
```
|
||||
|
||||
Note: blocks inside `LoopSequentialPipelineBlocks` receive `(components, block_state, k)` where `k` is the loop iteration index.
|
||||
|
||||
## Key pattern: Workflow selection
|
||||
|
||||
```python
|
||||
class AutoDenoise(ConditionalPipelineBlocks):
|
||||
block_classes = [V2VDenoiseStep, I2VDenoiseStep, T2VDenoiseStep]
|
||||
block_trigger_inputs = ["video_latents", "image_latents"]
|
||||
default_block_name = "text2video"
|
||||
```
|
||||
|
||||
## Standard InputParam/OutputParam templates
|
||||
|
||||
```python
|
||||
# Inputs
|
||||
InputParam.template("prompt") # str, required
|
||||
InputParam.template("negative_prompt") # str, optional
|
||||
InputParam.template("image") # PIL.Image, optional
|
||||
InputParam.template("generator") # torch.Generator, optional
|
||||
InputParam.template("num_inference_steps") # int, default=50
|
||||
InputParam.template("latents") # torch.Tensor, optional
|
||||
|
||||
# Outputs
|
||||
OutputParam.template("prompt_embeds")
|
||||
OutputParam.template("negative_prompt_embeds")
|
||||
OutputParam.template("image_latents")
|
||||
OutputParam.template("latents")
|
||||
OutputParam.template("videos")
|
||||
OutputParam.template("images")
|
||||
```
|
||||
|
||||
## ComponentSpec patterns
|
||||
|
||||
```python
|
||||
# Heavy models - loaded from pretrained
|
||||
ComponentSpec("transformer", YourTransformerModel)
|
||||
ComponentSpec("vae", AutoencoderKL)
|
||||
|
||||
# Lightweight objects - created inline from config
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 7.5}),
|
||||
default_creation_method="from_config"
|
||||
)
|
||||
```
|
||||
|
||||
## Conversion checklist
|
||||
|
||||
- [ ] Read original pipeline's `__call__` end-to-end, map stages
|
||||
- [ ] Write test scripts (reference + target) with identical seeds
|
||||
- [ ] Create file structure under `modular_pipelines/<model>/`
|
||||
- [ ] Write decoder block (simplest)
|
||||
- [ ] Write encoder blocks (text, image, video)
|
||||
- [ ] Write before_denoise blocks (timesteps, latent prep, noise)
|
||||
- [ ] Write denoise block with guider abstraction (hardest)
|
||||
- [ ] Create pipeline class with `default_blocks_name`
|
||||
- [ ] Assemble blocks in `modular_blocks_<model>.py`
|
||||
- [ ] Wire up `__init__.py` with lazy imports
|
||||
- [ ] Run `make style` and `make quality`
|
||||
- [ ] Test all workflows for parity with reference
|
||||
216
.ai/skills/parity-testing/SKILL.md
Normal file
@@ -0,0 +1,216 @@
|
||||
---
|
||||
name: testing-parity
|
||||
description: >
|
||||
Use when debugging or verifying numerical parity between pipeline
|
||||
implementations (e.g., research repo vs diffusers, standard vs modular).
|
||||
Also relevant when outputs look wrong — washed out, pixelated, or have
|
||||
visual artifacts — as these are usually parity bugs.
|
||||
---
|
||||
|
||||
## Setup — gather before starting
|
||||
|
||||
Before writing any test code, gather:
|
||||
|
||||
1. **Which two implementations** are being compared (e.g. research repo → diffusers, standard → modular, or research → modular). Use `AskUserQuestion` with structured choices if not already clear.
|
||||
2. **Two equivalent runnable scripts** — one for each implementation, both expected to produce identical output given the same inputs. These scripts define what "parity" means concretely.
|
||||
3. **Test directory**: Ask the user if they have a preferred directory for parity test scripts and artifacts. If not, create `parity-tests/` at the repo root.
|
||||
4. **Lab book**: Ask the user if they want to maintain a `lab_book.md` in the test directory to track findings, fixes, and experiment results across sessions. This is especially useful for multi-session debugging where context gets lost.
|
||||
|
||||
When invoked from the `model-integration` skill, you already have context: the reference script comes from step 2 of setup, and the diffusers script is the one you just wrote. You just need to make sure both scripts are runnable and use the same inputs/seed/params.
|
||||
|
||||
## Phase 1: CPU/float32 parity (always run)
|
||||
|
||||
### Component parity — test as you build
|
||||
|
||||
Test each component before assembling the pipeline. This is the foundation -- if individual pieces are wrong, the pipeline can't be right. Test each component in isolation with a strict max_diff < 1e-3.
|
||||
|
||||
Test both freshly converted checkpoints and saved checkpoints:
|
||||
- **Fresh**: convert from checkpoint weights, compare against reference (catches conversion bugs)
|
||||
- **Saved**: load from saved model on disk, compare against reference (catches stale saves)
|
||||
|
||||
Keep component test scripts around -- you will need to re-run them during pipeline debugging with different inputs or config values.
|
||||
|
||||
**Write a model interface mapping** as you test each component. This documents every input difference between reference and diffusers models — format, dtype, shape, who computes what. Save it in the test directory (e.g., `parity-tests/model_interface_mapping.md`). This is critical: during pipeline testing, you MUST reference this mapping to verify the pipeline passes inputs in the correct format. Without it, you'll waste time rediscovering differences you already found.
|
||||
|
||||
Example mapping (from LTX-2.3):
|
||||
```markdown
|
||||
| Input | Reference | Diffusers | Notes |
|
||||
|---|---|---|---|
|
||||
| timestep | per-token bf16 sigma, scaled by 1000 internally | passed as sigma*1000 | shape (B,S) not (B,) |
|
||||
| sigma (prompt_adaln) | raw f32 sigma, scaled internally | passed as sigma*1000 in f32 | NOT bf16 |
|
||||
| positions/coords | computed inside model preprocessor | passed as kwarg video_coords | cast to model dtype |
|
||||
| cross-attn timestep | always cross_modality.sigma | always audio_sigma | not conditional |
|
||||
| encoder_attention_mask | None (no mask) | None or all-ones | all-ones triggers different SDPA kernel |
|
||||
| RoPE | computed in model dtype (no upcast) | must match — no float32 upcast | cos/sin cast to input dtype |
|
||||
| output format | X0Model returns x0 | transformer returns velocity | v→x0: (sample - vel * sigma) |
|
||||
| audio output | .squeeze(0).float() | must match | (2,N) float32 not (1,2,N) bf16 |
|
||||
```
|
||||
|
||||
Template -- one self-contained script per component, reference and diffusers side-by-side:
|
||||
```python
|
||||
@torch.inference_mode()
|
||||
def test_my_component(mode="fresh", model_path=None):
|
||||
# 1. Deterministic input
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
x = torch.randn(1, 3, 64, 64, generator=gen, dtype=torch.float32)
|
||||
|
||||
# 2. Reference: load from checkpoint, run, free
|
||||
ref_model = ReferenceModel.from_config(config)
|
||||
ref_model.load_state_dict(load_weights("prefix"), strict=True)
|
||||
ref_model = ref_model.float().eval()
|
||||
ref_out = ref_model(x).clone()
|
||||
del ref_model
|
||||
|
||||
# 3. Diffusers: fresh (convert weights) or saved (from_pretrained)
|
||||
if mode == "fresh":
|
||||
diff_model = convert_my_component(load_weights("prefix"))
|
||||
else:
|
||||
diff_model = DiffusersModel.from_pretrained(model_path, torch_dtype=torch.float32)
|
||||
diff_model = diff_model.float().eval()
|
||||
diff_out = diff_model(x)
|
||||
del diff_model
|
||||
|
||||
# 4. Compare in same script -- no saving to disk
|
||||
max_diff = (ref_out - diff_out).abs().max().item()
|
||||
assert max_diff < 1e-3, f"FAIL: max_diff={max_diff:.2e}"
|
||||
```
|
||||
Key points: (a) both reference and diffusers component in one script -- never split into separate scripts that save/load intermediates, (b) deterministic input via seeded generator, (c) load one model at a time to fit in CPU RAM, (d) `.clone()` the reference output before deleting the model.
|
||||
|
||||
### Pipeline stage tests — encode, decode, then denoise
|
||||
|
||||
Use the capture-inject checkpoint method (see [checkpoint-mechanism.md](checkpoint-mechanism.md)) to test each pipeline stage independently. This methodology is the same for both CPU/float32 and GPU/bf16.
|
||||
|
||||
Before writing pipeline tests, **review the model interface mapping** from the component test phase and verify them. The mapping tells you which differences between the two models are expected (e.g., reference expects raw sigma but diffusers expects sigma*1000). Without it, you'll waste time investigating differences that are by design, not bugs.
|
||||
|
||||
First, **match noise generation**: the way initial noise/latents are constructed (seed handling, generator, randn call order) often differs between the two scripts. If the noise doesn't match, nothing downstream will match.
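A tiny self-contained illustration of why the construction order matters even with the same seed:

```python
import torch

# Reference: one randn call for the full latent shape.
gen = torch.Generator().manual_seed(42)
ref = torch.randn(1, 4, 8, 8, generator=gen)

# Other script: same seed, but noise built in two chunks and concatenated.
gen = torch.Generator().manual_seed(42)
diff = torch.cat(
    [torch.randn(1, 4, 8, 4, generator=gen), torch.randn(1, 4, 8, 4, generator=gen)], dim=-1
)

print((ref - diff).abs().max())  # non-zero: the runs diverge before denoising even starts
```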
|
||||
|
||||
**Stage test order:**
|
||||
|
||||
- **`encode`** (test first): Stop both pipelines at `"preloop"`. Compare **every single variable** that will be consumed by the denoising loop -- not just latents and sigmas, but also prompt embeddings, attention masks, positional coordinates, connector outputs, and any conditioning inputs.
|
||||
- **`decode`** (test second): Run the reference pipeline fully -- checkpoint the post-loop latents AND let it finish to get the **final output**. Feed those same post-loop latents through the diffusers decode path. Compare the **final output format** -- not raw tensors, but what the user actually gets:
|
||||
- **Image**: compare PIL.Image pixels
|
||||
- **Video**: compare through the pipeline's export function (e.g. `encode_video`)
|
||||
- **Video+Audio**: compare video frames AND audio waveform through `encode_video`
|
||||
- This catches postprocessing bugs like float→uint8 rounding, audio format, and codec settings.
|
||||
- **`denoise`** (test last): Run both pipelines with realistic `num_steps` (e.g. 30) so the scheduler computes correct sigmas/timesteps. For float32, stop after 2 loop iterations using `after_step_1` (don't set `num_steps=2` -- that produces unrealistic sigma schedules). For bf16, run ALL steps (see Phase 2).
|
||||
|
||||
Start with coarse checkpoints (`after_step_{i}` — just the denoised latents at each step). If a step diverges, place finer checkpoints within that step (e.g. before/after model call, after CFG, after scheduler step). If the divergence is inside the model forward call, use PyTorch forward hooks (`register_forward_hook`) to capture intermediate outputs from sub-modules (e.g., attention output, feed-forward output) and compare them between the two models to find the first diverging operation.
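A sketch of the hook-based capture, assuming illustrative sub-module names (`transformer_blocks[0].attn1`, `.ff`) — use whatever sub-modules exist in the two models being compared:

```python
import torch

captured = {}

def make_hook(name):
    def hook(module, args, output):
        out = output[0] if isinstance(output, tuple) else output
        captured[name] = out.detach().float().cpu()
    return hook

handles = [
    model.transformer_blocks[0].attn1.register_forward_hook(make_hook("block0.attn1")),
    model.transformer_blocks[0].ff.register_forward_hook(make_hook("block0.ff")),
]
with torch.inference_mode():
    model(**model_inputs)
for h in handles:
    h.remove()

# Repeat with the reference model (its own hooks, its own dict), then
# compare_tensors() each captured pair to locate the first diverging operation.
```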
|
||||
|
||||
```python
|
||||
# Encode stage -- stop before the loop, compare ALL inputs:
|
||||
ref_ckpts = {"preloop": Checkpoint(save=True, stop=True)}
|
||||
run_reference_pipeline(ref_ckpts)
|
||||
ref_data = ref_ckpts["preloop"].data
|
||||
|
||||
diff_ckpts = {"preloop": Checkpoint(save=True, stop=True)}
|
||||
run_diffusers_pipeline(diff_ckpts)
|
||||
diff_data = diff_ckpts["preloop"].data
|
||||
|
||||
# Compare EVERY variable consumed by the denoise loop:
|
||||
compare_tensors("latents", ref_data["latents"], diff_data["latents"])
|
||||
compare_tensors("sigmas", ref_data["sigmas"], diff_data["sigmas"])
|
||||
compare_tensors("prompt_embeds", ref_data["prompt_embeds"], diff_data["prompt_embeds"])
|
||||
# ... every single tensor the transformer forward() will receive
|
||||
```
|
||||
|
||||
### E2E visual — once stages pass
|
||||
|
||||
Both pipelines generate independently with identical seeds/params. Save outputs and compare visually. If outputs look identical, Phase 1 is done.
|
||||
|
||||
If CPU/float32 stage tests all pass and E2E outputs are identical → Phase 1 is done, move on.
|
||||
|
||||
If E2E outputs are NOT identical despite stage tests passing, **ask the user**: "CPU/float32 parity passes at the stage level but E2E output differs. The output in bf16/GPU may look slightly different from the reference due to precision casting, but the quality should be the same. Do you want to just vibe-check the output quality, or do you need 1:1 identical output with the reference in bf16?"
|
||||
|
||||
- If the user says quality looks fine → **done**.
|
||||
- If the user needs 1:1 identical output in bf16 → Phase 2.
|
||||
|
||||
## Phase 2: GPU/bf16 parity (optional — only if user needs 1:1 output)
|
||||
|
||||
If CPU/float32 passes, the algorithm is correct. bf16 differences are from precision casting (e.g., float32 vs bf16 in RoPE, CFG arithmetic order, scheduler intermediates), not logic bugs. These can make the output look slightly different from the reference even though the quality is identical. Phase 2 eliminates these casting differences so the diffusers output is **bit-identical** to the reference in bf16.
|
||||
|
||||
Phase 2 uses the **exact same stage test methodology** as Phase 1 (encode → decode → denoise with progressive checkpoint refinement), with two differences:
|
||||
|
||||
1. **dtype=bf16, device=GPU** instead of float32/CPU
|
||||
2. **Run the FULL denoising loop** (all steps, not just 2) — bf16 casting differences accumulate over steps and may only manifest after many iterations
|
||||
|
||||
See [pitfalls.md](pitfalls.md) #19-#27 for the catalog of bf16-specific gotchas.
|
||||
|
||||
## Debugging technique: Injection for root-cause isolation
|
||||
|
||||
When stage tests show divergence, **inject a known-good tensor from one pipeline into the other** to test whether the remaining code is correct.
|
||||
|
||||
The principle: if you suspect input X is the root cause of divergence in stage S:
|
||||
1. Run the reference pipeline and capture X
|
||||
2. Run the diffusers pipeline but **replace** its X with the reference's X (via checkpoint load)
|
||||
3. Compare outputs of stage S
|
||||
|
||||
If outputs now match: X was the root cause. If they still diverge: the bug is in the stage logic itself, not in X.
|
||||
|
||||
| What you're testing | What you inject | Where you inject |
|
||||
|---|---|---|
|
||||
| Is the decode stage correct? | Post-loop latents from reference | Before decode |
|
||||
| Is the denoise loop correct? | Pre-loop latents from reference | Before the loop |
|
||||
| Is step N correct? | Post-step-(N-1) latents from reference | Before step N |
|
||||
|
||||
**Per-step accumulation tracing**: When injection confirms the loop is correct but you want to understand *how* a small initial difference compounds, capture `after_step_{i}` for every step and plot the max_diff curve. A healthy curve stays bounded; an exponential blowup in later steps points to an amplification mechanism (see Pitfall #13 in [pitfalls.md](pitfalls.md)).
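A sketch using the checkpoint names from the capture-inject mechanism described above:

```python
num_steps = 30
ref_ckpts = {f"after_step_{i}": Checkpoint(save=True) for i in range(num_steps)}
diff_ckpts = {f"after_step_{i}": Checkpoint(save=True) for i in range(num_steps)}
run_reference_pipeline(ref_ckpts)
run_diffusers_pipeline(diff_ckpts)

for i in range(num_steps):
    a = ref_ckpts[f"after_step_{i}"].data["latents"].float()
    b = diff_ckpts[f"after_step_{i}"].data["latents"].float()
    print(f"step {i:02d}: max_diff={(a - b).abs().max().item():.3e}")
# Bounded curve -> healthy; exponential blow-up in late steps -> amplification (Pitfall #13).
```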
|
||||
|
||||
## Debugging technique: Visual comparison via frame extraction
|
||||
|
||||
For video pipelines, numerical metrics alone can be misleading. Extract and view individual frames:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
def extract_frames(video_np, frame_indices):
|
||||
"""video_np: (frames, H, W, 3) float array in [0, 1]"""
|
||||
for idx in frame_indices:
|
||||
frame = (video_np[idx] * 255).clip(0, 255).astype(np.uint8)
|
||||
img = Image.fromarray(frame)
|
||||
img.save(f"frame_{idx}.png")
|
||||
|
||||
# Compare specific frames from both pipelines
|
||||
extract_frames(ref_video, [0, 60, 120])
|
||||
extract_frames(diff_video, [0, 60, 120])
|
||||
```
|
||||
|
||||
## Testing rules
|
||||
|
||||
1. **Never use reference code in the diffusers test path.** Each side must use only its own code.
|
||||
2. **Never monkey-patch model internals in tests.** Do not replace `model.forward` or patch internal methods.
|
||||
3. **Debugging instrumentation must be non-destructive.** Checkpoint captures for debugging are fine, but must not alter control flow or outputs.
|
||||
4. **Prefer CPU/float32 for numerical comparison when practical.** Float32 avoids bfloat16 precision noise that obscures real bugs. But for large models (22B+), GPU/bfloat16 with `enable_model_cpu_offload()` is necessary -- use relaxed tolerances and cosine similarity as a secondary metric.
|
||||
5. **Test both fresh conversion AND saved model.** Fresh catches conversion logic bugs; saved catches stale/corrupted weights from previous runs.
|
||||
6. **Diff configs before debugging.** Before investigating any divergence, dump and compare all config values. A 30-second config diff prevents hours of debugging based on wrong assumptions.
|
||||
7. **Never modify cached/downloaded model configs directly.** Don't edit files in `~/.cache/huggingface/`. Instead, save to a local directory or open a PR on the upstream repo.
|
||||
8. **Compare ALL loop inputs in the encode test.** The preloop checkpoint must capture every single tensor the transformer forward() will receive.
|
||||
9. **Don't contaminate test paths.** Each side (reference, diffusers) must use only its own code to generate outputs. For COMPARISON, save both outputs through the SAME function (so codec/format differences don't create false diffs). Example: don't use the reference's `encode_video` for one side and diffusers' for the other.
|
||||
10. **Re-test standalone model through the actual pipeline if divergence points to the model.** If pipeline stage tests show the divergence is at the model output (e.g., `cond_x0` differs despite identical inputs), re-run the model comparison using capture-inject with real pipeline-generated inputs. Standalone model tests use manually constructed kwargs which may have wrong config values, dtypes, or shapes — the pipeline generates the real ones.
|
||||
|
||||
## Comparison utilities
|
||||
|
||||
```python
|
||||
def compare_tensors(name: str, a: torch.Tensor, b: torch.Tensor, tol: float = 1e-3) -> bool:
|
||||
if a.shape != b.shape:
|
||||
print(f" FAIL {name}: shape mismatch {a.shape} vs {b.shape}")
|
||||
return False
|
||||
diff = (a.float() - b.float()).abs()
|
||||
max_diff = diff.max().item()
|
||||
mean_diff = diff.mean().item()
|
||||
cos = torch.nn.functional.cosine_similarity(
|
||||
a.float().flatten().unsqueeze(0), b.float().flatten().unsqueeze(0)
|
||||
).item()
|
||||
passed = max_diff < tol
|
||||
print(f" {'PASS' if passed else 'FAIL'} {name}: max={max_diff:.2e}, mean={mean_diff:.2e}, cos={cos:.5f}")
|
||||
return passed
|
||||
```
|
||||
Cosine similarity is especially useful for GPU/bfloat16 tests where max_diff can be noisy -- `cos > 0.9999` is a strong signal even when max_diff exceeds tolerance.
|
||||
|
||||
## Example scripts
|
||||
|
||||
- [examples/test_component_parity_cpu.py](examples/test_component_parity_cpu.py) — Template for CPU/float32 component parity test
|
||||
- [examples/test_e2e_bf16_parity.py](examples/test_e2e_bf16_parity.py) — Template for GPU/bf16 E2E parity test with capture-inject
|
||||
|
||||
## Gotchas
|
||||
|
||||
See [pitfalls.md](pitfalls.md) for the full list of gotchas to watch for during parity testing.
|
||||
103
.ai/skills/parity-testing/checkpoint-mechanism.md
Normal file
@@ -0,0 +1,103 @@
|
||||
# Checkpoint Mechanism for Stage Testing
|
||||
|
||||
## Overview
|
||||
|
||||
Pipelines are monolithic `__call__` methods -- you can't just call "the encode part". The checkpoint mechanism lets you stop, save, or inject tensors at named locations inside the pipeline.
|
||||
|
||||
## The Checkpoint class
|
||||
|
||||
Add a `_checkpoints` argument to both the diffusers pipeline and the reference implementation.
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
save: bool = False # capture variables into ckpt.data
|
||||
stop: bool = False # halt pipeline after this point
|
||||
load: bool = False # inject ckpt.data into local variables
|
||||
data: dict = field(default_factory=dict)
|
||||
```
|
||||
|
||||
## Pipeline instrumentation
|
||||
|
||||
The pipeline accepts an optional `dict[str, Checkpoint]`. Place checkpoint calls at boundaries between pipeline stages -- after each encoder, before the denoising loop (capture all loop inputs), after each loop iteration, after the loop (capture final latents before decode).
|
||||
|
||||
```python
|
||||
def __call__(self, prompt, ..., _checkpoints=None):
|
||||
# --- text encoding ---
|
||||
prompt_embeds = self.text_encoder(prompt)
|
||||
_maybe_checkpoint(_checkpoints, "text_encoding", {
|
||||
"prompt_embeds": prompt_embeds,
|
||||
})
|
||||
|
||||
# --- prepare latents, sigmas, positions ---
|
||||
latents = self.prepare_latents(...)
|
||||
sigmas = self.scheduler.sigmas
|
||||
# ...
|
||||
|
||||
_maybe_checkpoint(_checkpoints, "preloop", {
|
||||
"latents": latents,
|
||||
"sigmas": sigmas,
|
||||
"prompt_embeds": prompt_embeds,
|
||||
"prompt_attention_mask": prompt_attention_mask,
|
||||
"video_coords": video_coords,
|
||||
# capture EVERYTHING the loop needs -- every tensor the transformer
|
||||
# forward() receives. Missing even one variable here means you can't
|
||||
# tell if it's the source of divergence during denoise debugging.
|
||||
})
|
||||
|
||||
# --- denoising loop ---
|
||||
for i, t in enumerate(timesteps):
|
||||
noise_pred = self.transformer(latents, t, prompt_embeds, ...)
|
||||
latents = self.scheduler.step(noise_pred, t, latents)[0]
|
||||
|
||||
_maybe_checkpoint(_checkpoints, f"after_step_{i}", {
|
||||
"latents": latents,
|
||||
})
|
||||
|
||||
_maybe_checkpoint(_checkpoints, "post_loop", {
|
||||
"latents": latents,
|
||||
})
|
||||
|
||||
# --- decode ---
|
||||
video = self.vae.decode(latents)
|
||||
return video
|
||||
```
|
||||
|
||||
## The helper function
|
||||
|
||||
Each `_maybe_checkpoint` call does three things based on the Checkpoint's flags: `save` captures the local variables into `ckpt.data`, `load` injects pre-populated `ckpt.data` back into local variables, `stop` halts execution (raises an exception caught at the top level).
|
||||
|
||||
```python
|
||||
def _maybe_checkpoint(checkpoints, name, data):
|
||||
if not checkpoints:
|
||||
return
|
||||
ckpt = checkpoints.get(name)
|
||||
if ckpt is None:
|
||||
return
|
||||
if ckpt.save:
|
||||
ckpt.data.update(data)
|
||||
if ckpt.stop:
|
||||
raise PipelineStop # caught at __call__ level, returns None
|
||||
```
|
||||
|
||||
## Injection support
|
||||
|
||||
Add `load` support at each checkpoint where you might want to inject:
|
||||
|
||||
```python
|
||||
_maybe_checkpoint(_checkpoints, "preloop", {"latents": latents, ...})
|
||||
|
||||
# Load support: replace local variables with injected data
|
||||
if _checkpoints:
|
||||
ckpt = _checkpoints.get("preloop")
|
||||
if ckpt is not None and ckpt.load:
|
||||
latents = ckpt.data["latents"].to(device=device, dtype=latents.dtype)
|
||||
```
|
||||
|
||||
## Key insight
|
||||
|
||||
The checkpoint dict is passed into the pipeline and mutated in-place. After the pipeline returns (or stops early), you read back `ckpt.data` to get the captured tensors. Both pipelines save under their own key names, so the test maps between them (e.g. reference `"video_state.latent"` -> diffusers `"latents"`).
|
||||
|
||||
## Memory management for large models
|
||||
|
||||
For large models, free the source pipeline's GPU memory before loading the target pipeline. Clone injected tensors to CPU, delete everything else, then run the target with `enable_model_cpu_offload()`.
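A sketch of that hand-off, using the `Checkpoint` mechanism above and a hypothetical `MyDiffusersPipeline`:

```python
import gc
import torch

# 1. Capture from the source (reference) pipeline, then clone tensors to CPU.
ckpts = {"preloop": Checkpoint(save=True, stop=True)}
run_reference_pipeline(ckpts)
injected = {
    k: (v.detach().to("cpu").clone() if torch.is_tensor(v) else v)
    for k, v in ckpts["preloop"].data.items()
}

# 2. Free everything from the source pipeline before loading the target.
del ckpts
gc.collect()
torch.cuda.empty_cache()

# 3. Run the target pipeline with injection and CPU offload.
pipe = MyDiffusersPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()
pipe(prompt, _checkpoints={"preloop": Checkpoint(load=True, data=injected)})
```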
|
||||
154
.ai/skills/parity-testing/pitfalls.md
Normal file
@@ -0,0 +1,154 @@
|
||||
# Complete Pitfalls Reference
|
||||
|
||||
## 1. Global CPU RNG
|
||||
`MultivariateNormal.sample()` uses the global CPU RNG, not `torch.Generator`. Must call `torch.manual_seed(seed)` before each pipeline run. A `generator=` kwarg won't help.
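A minimal demonstration:

```python
import torch
from torch.distributions import MultivariateNormal

dist = MultivariateNormal(torch.zeros(4), torch.eye(4))

torch.manual_seed(0)
a = dist.sample((2,))
torch.manual_seed(0)
b = dist.sample((2,))

# Reproducible only because the *global* CPU seed was reset before each call;
# there is no generator= argument to thread through .sample().
assert torch.equal(a, b)
```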
|
||||
|
||||
## 2. Timestep dtype
|
||||
Many transformers expect `int64` timesteps. `get_timestep_embedding` casts to float, so `745.3` and `745` produce different embeddings. Match the reference's casting.
|
||||
|
||||
## 3. Guidance parameter mapping
|
||||
Parameter names may differ: reference `zero_steps=1` (meaning `i <= 1`, 2 steps) vs target `zero_init_steps=2` (meaning `step < 2`, same thing). Check exact semantics.
|
||||
|
||||
## 4. `patch_size` in noise generation
|
||||
If noise generation depends on `patch_size` (e.g. `sample_block_noise`), it must be passed through. Missing it changes noise spatial structure.
|
||||
|
||||
## 5. Variable shadowing in nested loops
|
||||
Nested loops (stages -> chunks -> timesteps) can shadow variable names. If outer loop uses `latents` and inner loop also assigns to `latents`, scoping must match the reference.
|
||||
|
||||
## 6. Float precision differences -- don't dismiss them
|
||||
Target may compute in float32 where reference used bfloat16. Small per-element diffs (1e-3 to 1e-2) *look* harmless but can compound catastrophically over iterative processes like denoising loops (see Pitfalls #11 and #13). Before dismissing a precision difference: (a) check whether it feeds into an iterative process, (b) if so, trace the accumulation curve over all iterations to see if it stays bounded or grows exponentially. Only truly non-iterative precision diffs (e.g. in a single-pass encoder) are safe to accept.
|
||||
|
||||
## 7. Scheduler state reset between stages
|
||||
Some schedulers accumulate state (e.g. `model_outputs` in UniPC) that must be cleared between stages.
|
||||
|
||||
## 8. Component access
|
||||
Standard: `self.transformer`. Modular: `components.transformer`. Missing this causes AttributeError.
|
||||
|
||||
## 9. Guider state across stages
|
||||
In multi-stage denoising, the guider's internal state (e.g. `zero_init_steps`) may need save/restore between stages.
|
||||
|
||||
## 10. Model storage location
|
||||
NEVER store converted models in `/tmp/` -- temporary directories get wiped on restart. Always save converted checkpoints under a persistent path in the project repo (e.g. `models/ltx23-diffusers/`).
|
||||
|
||||
## 11. Noise dtype mismatch (causes washed-out output)
|
||||
|
||||
Reference code often generates noise in float32 then casts to model dtype (bfloat16) before storing:
|
||||
|
||||
```python
|
||||
noise = torch.randn(..., dtype=torch.float32, generator=gen)
|
||||
noise = noise.to(dtype=model_dtype) # bfloat16 -- values get quantized
|
||||
```
|
||||
|
||||
Diffusers pipelines may keep latents in float32 throughout the loop. The per-element difference is only ~1.5e-02, but this compounds over 30 denoising steps via 1/sigma amplification (Pitfall #13) and produces completely washed-out output.
|
||||
|
||||
**Fix**: Match the reference -- generate noise in the model's working dtype:
|
||||
```python
|
||||
latent_dtype = self.transformer.dtype # e.g. bfloat16
|
||||
latents = self.prepare_latents(..., dtype=latent_dtype, ...)
|
||||
```
|
||||
|
||||
**Detection**: Encode stage test shows initial latent max_diff of exactly ~1.5e-02. This specific magnitude is the signature of float32->bfloat16 quantization error.
|
||||
|
||||
## 12. RoPE position dtype
|
||||
|
||||
RoPE cosine/sine values are sensitive to position coordinate dtype. If reference uses bfloat16 positions but diffusers uses float32, the RoPE output diverges significantly (max_diff up to 2.0). Different modalities may use different position dtypes (e.g. video bfloat16, audio float32) -- check the reference carefully.
|
||||
|
||||
## 13. 1/sigma error amplification in Euler denoising
|
||||
|
||||
In Euler/flow-matching, the velocity formula divides by sigma: `v = (latents - pred_x0) / sigma`. As sigma shrinks from ~1.0 (step 0) to ~0.001 (step 29), errors are amplified up to 1000x. A 1.5e-02 init difference grows linearly through mid-steps, then exponentially in final steps, reaching max_diff ~6.0. This is why dtype mismatches (Pitfalls #11, #12) that seem tiny at init produce visually broken output. Use per-step accumulation tracing to diagnose.
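A back-of-the-envelope illustration of the amplification:

```python
# v = (latents - pred_x0) / sigma, so an error in the x0 prediction is scaled by 1/sigma.
init_err = 1.5e-2  # typical float32 -> bfloat16 quantization error (Pitfall #11)
for sigma in (1.0, 0.5, 0.1, 0.01, 0.001):
    print(f"sigma={sigma:<6} error contribution to v ~ {init_err / sigma:.1e}")
# 1.5e-02 at sigma=1.0 becomes ~1.5e+01 at sigma=0.001 -- enough to break the output.
```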
|
||||
|
||||
## 14. Config value assumptions -- always diff, never assume
|
||||
|
||||
When debugging parity, don't assume config values match code defaults. The published model checkpoint may override defaults with different values. A wrong assumption about a single config field can send you down hours of debugging in the wrong direction.
|
||||
|
||||
**The pattern that goes wrong:**
|
||||
1. You see `param_x` has default `1` in the code
|
||||
2. The reference code also uses `param_x` with a default of `1`
|
||||
3. You assume both sides use `1` and apply a "fix" based on that
|
||||
4. But the actual checkpoint config has `param_x: 1000`, and so does the published diffusers config
|
||||
5. Your "fix" now *creates* divergence instead of fixing it
|
||||
|
||||
**Prevention -- config diff first:**
|
||||
```python
|
||||
# Reference: read from checkpoint metadata (no model loading needed)
|
||||
from safetensors import safe_open
|
||||
import json
|
||||
ref_config = json.loads(safe_open(checkpoint_path, framework="pt").metadata()["config"])
|
||||
|
||||
# Diffusers: read from model config
|
||||
from diffusers import MyModel
|
||||
diff_model = MyModel.from_pretrained(model_path, subfolder="transformer")
|
||||
diff_config = dict(diff_model.config)
|
||||
|
||||
# Compare all values
|
||||
for key in sorted(set(list(ref_config.get("transformer", {}).keys()) + list(diff_config.keys()))):
|
||||
ref_val = ref_config.get("transformer", {}).get(key, "MISSING")
|
||||
diff_val = diff_config.get(key, "MISSING")
|
||||
if ref_val != diff_val:
|
||||
print(f" DIFF {key}: ref={ref_val}, diff={diff_val}")
|
||||
```
|
||||
|
||||
Run this **before** writing any hooks, analysis code, or fixes. It takes 30 seconds and catches wrong assumptions immediately.
|
||||
|
||||
**When debugging divergence -- trace values, don't reason about them:**
|
||||
If two implementations diverge, hook the actual intermediate values at the point of divergence rather than reading code to figure out what the values "should" be. Code analysis builds on assumptions; value tracing reveals facts.
|
||||
|
||||
## 15. Decoder config mismatch (causes pixelated artifacts)
|
||||
|
||||
The upstream model config may have wrong values for decoder-specific parameters (e.g. `upsample_residual`, `upsample_type`). These control whether the decoder uses skip connections in upsampling -- getting them wrong produces severe pixelation or blocky artifacts.
|
||||
|
||||
**Detection**: Feed identical post-loop latents through both decoders. If max pixel diff is large (PSNR < 40 dB) on CPU/float32, it's a real bug, not precision noise. Trace through decoder blocks (conv_in -> mid_block -> up_blocks) to find where divergence starts.
|
||||
|
||||
**Fix**: Correct the config value. Don't edit cached files in `~/.cache/huggingface/` -- either save to a local model directory or open a PR on the upstream repo (see Testing Rule #7).
|
||||
|
||||
## 16. Incomplete injection tests -- inject ALL variables or the test is invalid
|
||||
|
||||
When doing injection tests (feeding reference tensors into the diffusers pipeline), you must inject **every** divergent input, including sigmas/timesteps. A common mistake: the preloop checkpoint saves sigmas but the injection code only loads latents and embeddings. The test then runs with different sigma schedules, making it impossible to isolate the real cause.
|
||||
|
||||
**Prevention**: After writing injection code, verify by listing every variable the injected stage consumes and checking each one is either (a) injected from reference, or (b) confirmed identical between pipelines.
|
||||
|
||||
## 17. bf16 connector/encoder divergence -- don't chase it
|
||||
|
||||
When running on GPU/bfloat16, multi-layer encoders (e.g. 8-layer connector transformers) accumulate bf16 rounding noise that looks alarming (max_diff 0.3-2.7). Before investigating, re-run the component test on CPU/float32. If it passes (max_diff < 1e-4), the divergence is pure precision noise, not a code bug. Don't spend hours tracing through layers -- confirm on CPU/float32 and move on.
|
||||
|
||||
## 18. Stale test fixtures
|
||||
|
||||
When using saved tensors for cross-pipeline comparison, always ensure both sets of tensors were captured from the same run configuration (same seed, same config, same code version). Mixing fixtures from different runs (e.g. reference tensors from yesterday, diffusers tensors from today after a code change) creates phantom divergence that wastes debugging time. Regenerate both sides in a single test script execution.
|
||||
|
||||
## 19. RoPE float32 upcast changes bf16 output
|
||||
|
||||
If `apply_rotary_emb` upcasts input to float32 for the rotation computation (`x.float() * cos + x_rotated.float() * sin`), but the reference stays in bf16, the results differ after casting back. The float32 intermediate produces different rounding than native bf16 computation.
|
||||
|
||||
**Fix**: Remove the float32 upcast. Cast cos/sin to the input dtype instead: `cos, sin = cos.to(x.dtype), sin.to(x.dtype)`, then compute `x * cos + x_rotated * sin` in the model's native dtype.
|
||||
|
||||
## 20. CFG formula arithmetic order
|
||||
|
||||
`cond + (scale-1) * (cond - uncond)` and `uncond + scale * (cond - uncond)` are mathematically identical but produce different bf16 results because the multiplication factor (3 vs 4 for scale=4) and the base (cond vs uncond) differ. Match the reference's exact formula.
|
||||
|
||||
## 21. Scheduler float64 intermediates from numpy
|
||||
|
||||
`math.exp(mu) / (math.exp(mu) + (1/t - 1))` where `t` is a numpy float32 array promotes to float64 (because `math.exp` returns Python float64 and numpy promotes). The reference uses torch float32. Fix: compute in `torch.float32` using `torch.as_tensor(t, dtype=torch.float32)`. Same for `np.linspace` vs `torch.linspace` — use `torch.linspace` for float32-native computation.
|
||||
|
||||
## 22. Zero-dim tensor type promotion in Euler step
|
||||
|
||||
`dt * model_output` where `dt` is a 0-dim float32 tensor and `model_output` is bf16: PyTorch treats the 0-dim tensor as a "scalar" that adapts to the tensor's dtype. Result is **bf16**, not float32. The reference does `velocity.to(float32) * dt` which is float32. Fix: explicitly upcast `model_output.to(sample.dtype) * dt`.
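A quick check of the promotion behaviour:

```python
import torch

dt = torch.tensor(0.05, dtype=torch.float32)        # 0-dim float32 "scalar" tensor
model_output = torch.randn(8, dtype=torch.bfloat16)

print((dt * model_output).dtype)                     # torch.bfloat16 -- the 0-dim operand adapts
print((model_output.to(torch.float32) * dt).dtype)   # torch.float32 -- explicit upcast, matches reference
```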
|
||||
|
||||
## 23. Per-token vs per-batch timestep shape
|
||||
|
||||
Passing timestep as `(B,)` produces temb shape `(B, 1, D)` via the adaln. Passing `(B, S)` produces `(B, S, D)`. For T2V where all tokens share the same sigma, these are mathematically equivalent but use different CUDA kernels with different bf16 rounding. Match the reference's shape — typically per-token `(B, S)`.
|
||||
|
||||
## 24. Model config missing fields
|
||||
|
||||
The diffusers checkpoint config may be missing fields that the reference model has (e.g. `use_cross_timestep`, `prompt_modulation`). The code falls back to a default that may be wrong. Always check the ACTUAL runtime value, not the code default. Run `getattr(model.config, "field_name", "MISSING")` and compare against the reference model's config.
|
||||
|
||||
## 25. Cross-attention timestep conditional
|
||||
|
||||
The reference may always use `cross_modality.sigma` for cross-attention timestep (e.g., video cross-attn uses audio sigma), but the diffusers model may conditionally use the main timestep based on `use_cross_timestep`. If the conditional is wrong or the config field is missing, the cross-attention receives a completely different timestep — different shape `(S,)` vs `(1,)`, different value, and different sinusoidal embedding. This is a model-level bug that standalone tests miss because they pass `use_cross_timestep` manually.
|
||||
|
||||
## 26. Audio/video output format mismatch
|
||||
|
||||
The reference may return audio as `(2, N)` float32 (after `.squeeze(0).float()`), while the diffusers pipeline returns `(1, 2, N)` bf16 from the vocoder. The `_write_audio` function in `encode_video` doesn't handle 3D tensors correctly. Fix: add `.squeeze(0).float()` after the vocoder call in the audio decoder step.
|
||||
|
||||
## 27. encode_video float-to-uint8 rounding
|
||||
|
||||
The reference converts float video to uint8 via `.to(torch.uint8)` (truncation), but diffusers' `encode_video` may use `(video * 255).round().astype("uint8")` (rounding). This causes off-by-one values in roughly 50% of pixels per channel. Fix: drop the `.round()` so the `.astype("uint8")` cast truncates, matching the reference.
|
||||
8
.github/workflows/release_tests_fast.yml
vendored
@@ -4,6 +4,7 @@
|
||||
name: (Release) Fast GPU Tests on main
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
push:
|
||||
branches:
|
||||
- "v*.*.*-release"
|
||||
@@ -33,6 +34,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -74,6 +76,7 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -125,6 +128,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -175,6 +179,7 @@ jobs:
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -232,6 +237,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -274,6 +280,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -316,6 +323,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
|
||||
.gitignore (4 changes, vendored)
@@ -182,4 +182,6 @@ wandb
|
||||
|
||||
# AI agent generated symlinks
|
||||
/AGENTS.md
|
||||
/CLAUDE.md
|
||||
/CLAUDE.md
|
||||
/.agents/skills
|
||||
/.claude/skills
|
||||
Makefile (7 changes)
@@ -103,9 +103,16 @@ post-patch:
|
||||
|
||||
codex:
|
||||
ln -snf .ai/AGENTS.md AGENTS.md
|
||||
mkdir -p .agents
|
||||
rm -rf .agents/skills
|
||||
ln -snf ../.ai/skills .agents/skills
|
||||
|
||||
claude:
|
||||
ln -snf .ai/AGENTS.md CLAUDE.md
|
||||
mkdir -p .claude
|
||||
rm -rf .claude/skills
|
||||
ln -snf ../.ai/skills .claude/skills
|
||||
|
||||
clean-ai:
|
||||
rm -f AGENTS.md CLAUDE.md
|
||||
rm -rf .agents/skills .claude/skills
|
||||
|
||||
@@ -572,9 +572,9 @@ For documentation strings, 🧨 Diffusers follows the [Google style](https://goo
|
||||
|
||||
The repository keeps AI-agent configuration in `.ai/` and exposes local agent files via symlinks.
|
||||
|
||||
- **Source of truth** — edit `.ai/AGENTS.md` (and any future `.ai/skills/`)
|
||||
- **Don't edit** generated root-level `AGENTS.md` or `CLAUDE.md` — they are symlinks
|
||||
- **Source of truth** — edit files under `.ai/` (`AGENTS.md` for coding guidelines, `skills/` for on-demand task knowledge)
|
||||
- **Don't edit** generated root-level `AGENTS.md`, `CLAUDE.md`, or `.agents/skills`/`.claude/skills` — they are symlinks
|
||||
- Setup commands:
|
||||
- `make codex` — symlink for OpenAI Codex
|
||||
- `make claude` — symlink for Claude Code
|
||||
- `make clean-ai` — remove generated symlinks
|
||||
- `make codex` — symlink guidelines + skills for OpenAI Codex
|
||||
- `make claude` — symlink guidelines + skills for Claude Code
|
||||
- `make clean-ai` — remove all generated symlinks
|
||||
@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
|
||||
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
|
||||
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
|
||||
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
|
||||
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
|
||||
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
|
||||
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
|
||||
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |
|
||||
|
||||
@@ -7,7 +7,7 @@ import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration, Gemma3Processor
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
@@ -17,7 +17,7 @@ from diffusers import (
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder, LTX2VocoderWithBWE
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
@@ -44,6 +44,12 @@ LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
**LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT,
|
||||
"audio_prompt_adaln_single": "audio_prompt_adaln",
|
||||
"prompt_adaln_single": "prompt_adaln",
|
||||
}
|
||||
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
@@ -72,6 +78,13 @@ LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_3_VIDEO_VAE_RENAME_DICT = {
|
||||
**LTX_2_0_VIDEO_VAE_RENAME_DICT,
|
||||
# Decoder extra blocks
|
||||
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
||||
"up_blocks.8": "up_blocks.3",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
@@ -84,10 +97,34 @@ LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
LTX_2_3_VOCODER_RENAME_DICT = {
|
||||
# Handle upsamplers ("ups" --> "upsamplers") due to name clash
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
"act_post": "act_out",
|
||||
"downsample.lowpass": "downsample",
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_3_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# LTX-2.3 uses per-modality embedding projections
|
||||
"text_embedding_projection.audio_aggregate_embed": "audio_text_proj_in",
|
||||
"text_embedding_projection.video_aggregate_embed": "video_text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
@@ -129,23 +166,24 @@ def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: dict[str
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_3_vocoder_upsamplers(key: str, state_dict: dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if ".ups." in key:
|
||||
new_key = key.replace(".ups.", ".upsamplers.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
return
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
@@ -155,13 +193,19 @@ LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP = {
|
||||
".ups.": convert_ltx2_3_vocoder_upsamplers,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"text_embedding_projection",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
@@ -225,7 +269,7 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
@@ -238,6 +282,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"gated_attn": False,
|
||||
"cross_attn_mod": False,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
@@ -249,6 +295,8 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"audio_gated_attn": False,
|
||||
"audio_cross_attn_mod": False,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
@@ -263,10 +311,62 @@ def get_ltx2_transformer_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
"use_prompt_embeddings": True,
|
||||
"perturbed_attn": False,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"gated_attn": True,
|
||||
"cross_attn_mod": True,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"audio_gated_attn": True,
|
||||
"audio_cross_attn_mod": True,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
"use_prompt_embeddings": False,
|
||||
"perturbed_attn": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -293,7 +393,7 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
@@ -301,20 +401,52 @@ def get_ltx2_connectors_config(version: str) -> tuple[dict[str, Any], dict[str,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"video_gated_attn": False,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"audio_gated_attn": False,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
"per_modality_projections": False,
|
||||
"proj_bias": False,
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 32,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 8,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"video_gated_attn": True,
|
||||
"audio_connector_num_attention_heads": 32,
|
||||
"audio_connector_attention_head_dim": 64,
|
||||
"audio_connector_num_layers": 8,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"audio_gated_attn": True,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
"per_modality_projections": True,
|
||||
"video_hidden_dim": 4096,
|
||||
"audio_hidden_dim": 2048,
|
||||
"proj_bias": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_CONNECTORS_SPECIAL_KEYS_REMAP
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
@@ -416,7 +548,7 @@ def get_ltx2_video_vae_config(
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
@@ -435,6 +567,7 @@ def get_ltx2_video_vae_config(
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_type": ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
@@ -451,6 +584,44 @@ def get_ltx2_video_vae_config(
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 1024),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 512, 1024),
|
||||
"layers_per_block": (4, 6, 4, 2, 2),
|
||||
"decoder_layers_per_block": (4, 6, 4, 2, 2),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_type": ("spatiotemporal", "spatiotemporal", "temporal", "spatial"),
|
||||
"upsample_residual": (False, False, False, False),
|
||||
"upsample_factor": (2, 2, 1, 2),
|
||||
"timestep_conditioning": timestep_conditioning,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "zeros",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -485,7 +656,7 @@ def convert_ltx2_video_vae(
|
||||
def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
@@ -508,6 +679,31 @@ def get_ltx2_audio_vae_config(version: str) -> tuple[dict[str, Any], dict[str, A
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
}, # Same config as LTX-2.0
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
@@ -540,7 +736,7 @@ def convert_ltx2_audio_vae(original_state_dict: dict[str, Any], version: str) ->
|
||||
def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"model_id": "Lightricks/LTX-2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
@@ -549,21 +745,71 @@ def get_ltx2_vocoder_config(version: str) -> tuple[dict[str, Any], dict[str, Any
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"act_fn": "leaky_relu",
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"antialias": False,
|
||||
"final_act_fn": "tanh",
|
||||
"final_bias": True,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"model_id": "Lightricks/LTX-2.3",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1536,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [11, 4, 4, 4, 4, 4],
|
||||
"upsample_factors": [5, 2, 2, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"act_fn": "snakebeta",
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"antialias": True,
|
||||
"antialias_ratio": 2,
|
||||
"antialias_kernel_size": 12,
|
||||
"final_act_fn": None,
|
||||
"final_bias": False,
|
||||
"bwe_in_channels": 128,
|
||||
"bwe_hidden_channels": 512,
|
||||
"bwe_out_channels": 2,
|
||||
"bwe_upsample_kernel_sizes": [12, 11, 4, 4, 4],
|
||||
"bwe_upsample_factors": [6, 5, 2, 2, 2],
|
||||
"bwe_resnet_kernel_sizes": [3, 7, 11],
|
||||
"bwe_resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"bwe_act_fn": "snakebeta",
|
||||
"bwe_leaky_relu_negative_slope": 0.1,
|
||||
"bwe_antialias": True,
|
||||
"bwe_antialias_ratio": 2,
|
||||
"bwe_antialias_kernel_size": 12,
|
||||
"bwe_final_act_fn": None,
|
||||
"bwe_final_bias": False,
|
||||
"filter_length": 512,
|
||||
"hop_length": 80,
|
||||
"window_length": 512,
|
||||
"num_mel_channels": 64,
|
||||
"input_sampling_rate": 16000,
|
||||
"output_sampling_rate": 48000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_3_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_3_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: dict[str, Any], version: str) -> dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
if version == "2.3":
|
||||
vocoder_cls = LTX2VocoderWithBWE
|
||||
else:
|
||||
vocoder_cls = LTX2Vocoder
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
vocoder = vocoder_cls.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
@@ -594,6 +840,18 @@ def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
"use_rational_resampler": True,
|
||||
}
|
||||
elif version == "2.3":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
"use_rational_resampler": False,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
@@ -651,13 +909,17 @@ def get_model_state_dict_from_combined_ckpt(combined_ckpt: dict[str, Any], prefi
|
||||
model_state_dict = {}
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
model_state_dict[param_name.removeprefix(prefix)] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
connector_prefixes = ["text_embedding_projection"]
|
||||
for param_name, param in combined_ckpt.items():
|
||||
for prefix in connector_prefixes:
|
||||
if param_name.startswith(prefix):
|
||||
# Check to make sure we're not overwriting an existing key
|
||||
if param_name not in model_state_dict:
|
||||
model_state_dict[param_name] = combined_ckpt[param_name]
|
||||
|
||||
return model_state_dict
|
||||
|
||||
@@ -686,7 +948,7 @@ def get_args():
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0"],
|
||||
choices=["test", "2.0", "2.3"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
@@ -748,6 +1010,11 @@ def get_args():
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--add_processor",
|
||||
action="store_true",
|
||||
help="Whether to add a Gemma3Processor to the pipeline for prompt enhancement.",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
@@ -756,6 +1023,12 @@ def get_args():
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument(
|
||||
"--upsample_output_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path where converted upsampling pipeline should be saved",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
@@ -787,7 +1060,7 @@ def main(args):
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.text_encoder,
|
||||
args.connectors,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
@@ -852,7 +1125,12 @@ def main(args):
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
if args.add_processor:
|
||||
processor = Gemma3Processor.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
processor.save_pretrained(os.path.join(args.output_path, "processor"))
|
||||
|
||||
if args.latent_upsampler or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
@@ -866,14 +1144,26 @@ def main(args):
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
is_distilled_ckpt = "distilled" in args.combined_filename
|
||||
if is_distilled_ckpt:
|
||||
# Disable dynamic shifting and terminal shift so that distilled sigmas are used as-is
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=False,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=None,
|
||||
)
|
||||
else:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
@@ -891,10 +1181,12 @@ def main(args):
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
# As two diffusers pipelines cannot be in the same directory, save the upsampling pipeline to its own directory
|
||||
if args.upsample_output_path:
|
||||
upsample_output_path = args.upsample_output_path
|
||||
else:
|
||||
upsample_output_path = args.output_path
|
||||
pipe.save_pretrained(upsample_output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -12,6 +12,7 @@ from termcolor import colored
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLWan,
|
||||
DPMSolverMultistepScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
@@ -24,7 +25,10 @@ from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available else nullcontext
|
||||
|
||||
ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
|
||||
ckpt_ids = [
|
||||
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
|
||||
"Efficient-Large-Model/SANA-Video_2B_720p/checkpoints/SANA_Video_2B_720p_LTXVAE.pth",
|
||||
]
|
||||
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py
|
||||
|
||||
|
||||
@@ -92,12 +96,22 @@ def main(args):
|
||||
if args.video_size == 480:
|
||||
sample_size = 30 # Wan-VAE: 8xp2 downsample factor
|
||||
patch_size = (1, 2, 2)
|
||||
in_channels = 16
|
||||
out_channels = 16
|
||||
elif args.video_size == 720:
|
||||
sample_size = 22 # Wan-VAE: 32xp1 downsample factor
|
||||
sample_size = 22 # DC-AE-V: 32xp1 downsample factor
|
||||
patch_size = (1, 1, 1)
|
||||
in_channels = 32
|
||||
out_channels = 32
|
||||
else:
|
||||
raise ValueError(f"Video size {args.video_size} is not supported.")
|
||||
|
||||
if args.vae_type == "ltx2":
|
||||
sample_size = 22
|
||||
patch_size = (1, 1, 1)
|
||||
in_channels = 128
|
||||
out_channels = 128
|
||||
|
||||
for depth in range(layer_num):
|
||||
# Transformer blocks.
|
||||
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
|
||||
@@ -182,8 +196,8 @@ def main(args):
|
||||
# Transformer
|
||||
with CTX():
|
||||
transformer_kwargs = {
|
||||
"in_channels": 16,
|
||||
"out_channels": 16,
|
||||
"in_channels": in_channels,
|
||||
"out_channels": out_channels,
|
||||
"num_attention_heads": 20,
|
||||
"attention_head_dim": 112,
|
||||
"num_layers": 20,
|
||||
@@ -235,9 +249,12 @@ def main(args):
|
||||
else:
|
||||
print(colored(f"Saving the whole Pipeline containing {args.model_type}", "green", attrs=["bold"]))
|
||||
# VAE
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
if args.vae_type == "ltx2":
|
||||
vae_path = args.vae_path or "Lightricks/LTX-2"
|
||||
vae = AutoencoderKLLTX2Video.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
|
||||
else:
|
||||
vae_path = args.vae_path or "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
|
||||
vae = AutoencoderKLWan.from_pretrained(vae_path, subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
# Text Encoder
|
||||
text_encoder_model_path = "Efficient-Large-Model/gemma-2-2b-it"
|
||||
@@ -314,7 +331,23 @@ if __name__ == "__main__":
|
||||
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
|
||||
help="Scheduler type to use.",
|
||||
)
|
||||
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
|
||||
parser.add_argument(
|
||||
"--vae_type",
|
||||
default="wan",
|
||||
type=str,
|
||||
choices=["wan", "ltx2"],
|
||||
help="VAE type to use for saving full pipeline (ltx2 uses patchify 1x1x1).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=False,
|
||||
help="Optional VAE path or repo id. If not set, a default is used per VAE type.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v."
|
||||
)
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
|
||||
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
|
||||
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
|
||||
|
||||
@@ -434,6 +434,9 @@ else:
|
||||
"FluxKontextAutoBlocks",
|
||||
"FluxKontextModularPipeline",
|
||||
"FluxModularPipeline",
|
||||
"LTX2AutoBlocks",
|
||||
"LTX2Blocks",
|
||||
"LTX2ModularPipeline",
|
||||
"HeliosAutoBlocks",
|
||||
"HeliosModularPipeline",
|
||||
"HeliosPyramidAutoBlocks",
|
||||
@@ -1195,6 +1198,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxKontextAutoBlocks,
|
||||
FluxKontextModularPipeline,
|
||||
FluxModularPipeline,
|
||||
LTX2AutoBlocks,
|
||||
LTX2Blocks,
|
||||
LTX2ModularPipeline,
|
||||
HeliosAutoBlocks,
|
||||
HeliosModularPipeline,
|
||||
HeliosPyramidAutoBlocks,
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import functools
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Type
|
||||
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
|
||||
gather_size_by_comm,
|
||||
)
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
|
||||
return tensor
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=64)
|
||||
@lru_cache_unless_export(maxsize=64)
|
||||
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
|
||||
gather_shapes = []
|
||||
for i in range(world_size):
|
||||
|
||||
@@ -2156,6 +2156,9 @@ def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
# LTX-2.3
|
||||
"audio_prompt_adaln_single": "audio_prompt_adaln",
|
||||
"prompt_adaln_single": "prompt_adaln",
|
||||
}
|
||||
else:
|
||||
rename_dict = {"aggregate_embed": "text_proj_in"}
|
||||
|
||||
@@ -49,7 +49,7 @@ from ..utils import (
|
||||
is_xformers_version,
|
||||
)
|
||||
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
|
||||
from ..utils.torch_utils import maybe_allow_in_graph
|
||||
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from ._modeling_parallel import gather_size_by_comm
|
||||
|
||||
|
||||
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
|
||||
FLASH_HUB = "flash_hub"
|
||||
FLASH_VARLEN = "flash_varlen"
|
||||
FLASH_VARLEN_HUB = "flash_varlen_hub"
|
||||
FLASH_4_HUB = "flash_4_hub"
|
||||
_FLASH_3 = "_flash_3"
|
||||
_FLASH_VARLEN_3 = "_flash_varlen_3"
|
||||
_FLASH_3_HUB = "_flash_3_hub"
|
||||
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
function_attr="sageattn",
|
||||
version=1,
|
||||
),
|
||||
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-staging/flash-attn4",
|
||||
function_attr="flash_attn_func",
|
||||
version=0,
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB,
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
]:
|
||||
if not is_kernels_available():
|
||||
raise RuntimeError(
|
||||
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
|
||||
raise RuntimeError(
|
||||
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
|
||||
)
|
||||
|
||||
elif backend == AttentionBackendName.AITER:
|
||||
if not _CAN_USE_AITER_ATTN:
|
||||
raise RuntimeError(
|
||||
@@ -575,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
|
||||
batch_size: int,
|
||||
seq_len_q: int,
|
||||
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_4_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
)
|
||||
def _flash_attention_4_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
scale: float | None = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 4.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
)
|
||||
if isinstance(out, tuple):
|
||||
return (out[0], out[1]) if return_lse else out[0]
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_VARLEN_3,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
|
||||
@@ -237,7 +237,7 @@ class LTX2VideoResnetBlock3d(nn.Module):
|
||||
|
||||
|
||||
# Like LTX 1.0 LTXVideoDownsampler3d, but uses new causal Conv3d
|
||||
class LTXVideoDownsampler3d(nn.Module):
|
||||
class LTX2VideoDownsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
@@ -285,10 +285,11 @@ class LTXVideoDownsampler3d(nn.Module):
|
||||
|
||||
|
||||
# Like LTX 1.0 LTXVideoUpsampler3d, but uses new causal Conv3d
|
||||
class LTXVideoUpsampler3d(nn.Module):
|
||||
class LTX2VideoUpsampler3d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int | None = None,
|
||||
stride: int | tuple[int, int, int] = 1,
|
||||
residual: bool = False,
|
||||
upscale_factor: int = 1,
|
||||
@@ -300,7 +301,8 @@ class LTXVideoUpsampler3d(nn.Module):
|
||||
self.residual = residual
|
||||
self.upscale_factor = upscale_factor
|
||||
|
||||
out_channels = (in_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
out_channels = out_channels or in_channels
|
||||
out_channels = (out_channels * stride[0] * stride[1] * stride[2]) // upscale_factor
|
||||
|
||||
self.conv = LTX2VideoCausalConv3d(
|
||||
in_channels=in_channels,
|
||||
@@ -408,7 +410,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "spatial":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
LTX2VideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(1, 2, 2),
|
||||
@@ -417,7 +419,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "temporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
LTX2VideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 1, 1),
|
||||
@@ -426,7 +428,7 @@ class LTX2VideoDownBlock3D(nn.Module):
|
||||
)
|
||||
elif downsample_type == "spatiotemporal":
|
||||
self.downsamplers.append(
|
||||
LTXVideoDownsampler3d(
|
||||
LTX2VideoDownsampler3d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
stride=(2, 2, 2),
|
||||
@@ -580,6 +582,7 @@ class LTX2VideoUpBlock3d(nn.Module):
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_act_fn: str = "swish",
|
||||
spatio_temporal_scale: bool = True,
|
||||
upsample_type: str = "spatiotemporal",
|
||||
inject_noise: bool = False,
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: bool = False,
|
||||
@@ -609,16 +612,23 @@ class LTX2VideoUpBlock3d(nn.Module):
|
||||
|
||||
self.upsamplers = None
|
||||
if spatio_temporal_scale:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
LTXVideoUpsampler3d(
|
||||
out_channels * upscale_factor,
|
||||
stride=(2, 2, 2),
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
]
|
||||
self.upsamplers = nn.ModuleList()
|
||||
|
||||
if upsample_type == "spatial":
|
||||
upsample_stride = (1, 2, 2)
|
||||
elif upsample_type == "temporal":
|
||||
upsample_stride = (2, 1, 1)
|
||||
elif upsample_type == "spatiotemporal":
|
||||
upsample_stride = (2, 2, 2)
|
||||
|
||||
self.upsamplers.append(
|
||||
LTX2VideoUpsampler3d(
|
||||
in_channels=out_channels * upscale_factor,
|
||||
stride=upsample_stride,
|
||||
residual=upsample_residual,
|
||||
upscale_factor=upscale_factor,
|
||||
spatial_padding_mode=spatial_padding_mode,
|
||||
)
|
||||
)
|
||||
|
||||
resnets = []
|
||||
@@ -716,7 +726,7 @@ class LTX2VideoEncoder3d(nn.Module):
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
|
||||
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
|
||||
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
patch_size: int = 4,
|
||||
@@ -726,6 +736,9 @@ class LTX2VideoEncoder3d(nn.Module):
|
||||
spatial_padding_mode: str = "zeros",
|
||||
):
|
||||
super().__init__()
|
||||
num_encoder_blocks = len(layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
@@ -860,19 +873,27 @@ class LTX2VideoDecoder3d(nn.Module):
|
||||
in_channels: int = 128,
|
||||
out_channels: int = 3,
|
||||
block_out_channels: tuple[int, ...] = (256, 512, 1024),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
|
||||
layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
|
||||
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
patch_size: int = 4,
|
||||
patch_size_t: int = 1,
|
||||
resnet_norm_eps: float = 1e-6,
|
||||
is_causal: bool = False,
|
||||
inject_noise: tuple[bool, ...] = (False, False, False),
|
||||
inject_noise: bool | tuple[bool, ...] = (False, False, False),
|
||||
timestep_conditioning: bool = False,
|
||||
upsample_residual: tuple[bool, ...] = (True, True, True),
|
||||
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
|
||||
upsample_factor: tuple[bool, ...] = (2, 2, 2),
|
||||
spatial_padding_mode: str = "reflect",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
num_decoder_blocks = len(layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_decoder_blocks - 1)
|
||||
if isinstance(inject_noise, bool):
|
||||
inject_noise = (inject_noise,) * num_decoder_blocks
|
||||
if isinstance(upsample_residual, bool):
|
||||
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.patch_size_t = patch_size_t
|
||||
@@ -917,6 +938,7 @@ class LTX2VideoDecoder3d(nn.Module):
|
||||
num_layers=layers_per_block[i + 1],
|
||||
resnet_eps=resnet_norm_eps,
|
||||
spatio_temporal_scale=spatio_temporal_scaling[i],
|
||||
upsample_type=upsample_type[i],
|
||||
inject_noise=inject_noise[i + 1],
|
||||
timestep_conditioning=timestep_conditioning,
|
||||
upsample_residual=upsample_residual[i],
|
||||
@@ -1058,11 +1080,12 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
decoder_block_out_channels: tuple[int, ...] = (256, 512, 1024),
|
||||
layers_per_block: tuple[int, ...] = (4, 6, 6, 2, 2),
|
||||
decoder_layers_per_block: tuple[int, ...] = (5, 5, 5, 5),
|
||||
spatio_temporal_scaling: tuple[bool, ...] = (True, True, True, True),
|
||||
decoder_spatio_temporal_scaling: tuple[bool, ...] = (True, True, True),
|
||||
decoder_inject_noise: tuple[bool, ...] = (False, False, False, False),
|
||||
spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True, True),
|
||||
decoder_spatio_temporal_scaling: bool | tuple[bool, ...] = (True, True, True),
|
||||
decoder_inject_noise: bool | tuple[bool, ...] = (False, False, False, False),
|
||||
downsample_type: tuple[str, ...] = ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
upsample_residual: tuple[bool, ...] = (True, True, True),
|
||||
upsample_type: tuple[str, ...] = ("spatiotemporal", "spatiotemporal", "spatiotemporal"),
|
||||
upsample_residual: bool | tuple[bool, ...] = (True, True, True),
|
||||
upsample_factor: tuple[int, ...] = (2, 2, 2),
|
||||
timestep_conditioning: bool = False,
|
||||
patch_size: int = 4,
|
||||
@@ -1077,6 +1100,16 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
temporal_compression_ratio: int = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
num_encoder_blocks = len(layers_per_block)
|
||||
num_decoder_blocks = len(decoder_layers_per_block)
|
||||
if isinstance(spatio_temporal_scaling, bool):
|
||||
spatio_temporal_scaling = (spatio_temporal_scaling,) * (num_encoder_blocks - 1)
|
||||
if isinstance(decoder_spatio_temporal_scaling, bool):
|
||||
decoder_spatio_temporal_scaling = (decoder_spatio_temporal_scaling,) * (num_decoder_blocks - 1)
|
||||
if isinstance(decoder_inject_noise, bool):
|
||||
decoder_inject_noise = (decoder_inject_noise,) * num_decoder_blocks
|
||||
if isinstance(upsample_residual, bool):
|
||||
upsample_residual = (upsample_residual,) * (num_decoder_blocks - 1)
|
||||
|
||||
self.encoder = LTX2VideoEncoder3d(
|
||||
in_channels=in_channels,
|
||||
@@ -1098,6 +1131,7 @@ class AutoencoderKLLTX2Video(ModelMixin, AutoencoderMixin, ConfigMixin, FromOrig
|
||||
block_out_channels=decoder_block_out_channels,
|
||||
spatio_temporal_scaling=decoder_spatio_temporal_scaling,
|
||||
layers_per_block=decoder_layers_per_block,
|
||||
upsample_type=upsample_type,
|
||||
patch_size=patch_size,
|
||||
patch_size_t=patch_size_t,
|
||||
resnet_norm_eps=resnet_norm_eps,
|
||||
|
||||
@@ -550,19 +550,9 @@ class RMSNorm(nn.Module):
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
else:
|
||||
input_dtype = hidden_states.dtype
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
|
||||
|
||||
if self.weight is not None:
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
hidden_states = hidden_states * self.weight
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
else:
|
||||
hidden_states = hidden_states.to(input_dtype)
|
||||
hidden_states = torch.nn.functional.rms_norm(hidden_states, self.dim, self.weight, self.eps)
|
||||
if self.bias is not None:
|
||||
hidden_states = hidden_states + self.bias
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
@@ -343,7 +342,6 @@ class HeliosRotaryPosEmbed(nn.Module):
|
||||
return freqs.cos(), freqs.sin()
|
||||
|
||||
@torch.no_grad()
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_spatial_meshgrid(self, height, width, device_str):
|
||||
device = torch.device(device_str)
|
||||
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)
|
||||
|
||||
@@ -37,16 +37,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
cos, sin = freqs
|
||||
cos, sin = cos.to(x.dtype), sin.to(x.dtype)
|
||||
x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1) # [B, S, C // 2]
|
||||
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
|
||||
out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
|
||||
out = x * cos + x_rotated * sin
|
||||
return out
|
||||
|
||||
|
||||
def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
|
||||
cos, sin = freqs
|
||||
|
||||
x_dtype = x.dtype
|
||||
needs_reshape = False
|
||||
if x.ndim != 4 and cos.ndim == 4:
|
||||
# cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
|
||||
@@ -61,12 +61,12 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Ten
|
||||
r = last // 2
|
||||
|
||||
# (..., 2, r)
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r).float() # Explicitly upcast to float
|
||||
split_x = x.reshape(*x.shape[:-1], 2, r)
|
||||
first_x = split_x[..., :1, :] # (..., 1, r)
|
||||
second_x = split_x[..., 1:, :] # (..., 1, r)
|
||||
|
||||
cos_u = cos.unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r)
|
||||
sin_u = sin.unsqueeze(-2)
|
||||
cos_u = cos.to(x.dtype).unsqueeze(-2) # broadcast to (..., 1, r) against (..., 2, r)
|
||||
sin_u = sin.to(x.dtype).unsqueeze(-2)
|
||||
|
||||
out = split_x * cos_u
|
||||
first_out = out[..., :1, :]
|
||||
@@ -80,7 +80,6 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Ten
|
||||
if needs_reshape:
|
||||
out = out.swapaxes(1, 2).reshape(b, t, -1)
|
||||
|
||||
out = out.to(dtype=x_dtype)
|
||||
return out
|
||||
|
||||
|
||||
@@ -178,6 +177,10 @@ class LTX2AudioVideoAttnProcessor:
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
@@ -212,6 +215,112 @@ class LTX2AudioVideoAttnProcessor:
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2PerturbedAttnProcessor:
|
||||
r"""
|
||||
Processor which implements attention with perturbation masking and per-head gating for LTX-2.X models.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if is_torch_version("<", "2.0"):
|
||||
raise ValueError(
|
||||
"LTX attention processors require a minimum PyTorch version of 2.0. Please upgrade your PyTorch installation."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "LTX2Attention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
query_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
key_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
# Calculate gate logits on original hidden_states
|
||||
gate_logits = attn.to_gate_logits(hidden_states)
|
||||
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
if all_perturbed is None:
|
||||
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
|
||||
|
||||
if all_perturbed:
|
||||
# Skip attention, use the value projection value
|
||||
hidden_states = value
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
if query_rotary_emb is not None:
|
||||
if attn.rope_type == "interleaved":
|
||||
query = apply_interleaved_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_interleaved_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
elif attn.rope_type == "split":
|
||||
query = apply_split_rotary_emb(query, query_rotary_emb)
|
||||
key = apply_split_rotary_emb(
|
||||
key, key_rotary_emb if key_rotary_emb is not None else query_rotary_emb
|
||||
)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
if perturbation_mask is not None:
|
||||
value = value.flatten(2, 3)
|
||||
hidden_states = torch.lerp(value, hidden_states, perturbation_mask)
|
||||
|
||||
if attn.to_gate_logits is not None:
|
||||
hidden_states = hidden_states.unflatten(2, (attn.heads, -1)) # [B, T, H, D]
|
||||
# The factor of 2.0 is so that if the gates logits are zero-initialized the initial gates are all 1
|
||||
gates = 2.0 * torch.sigmoid(gate_logits) # [B, T, H]
|
||||
hidden_states = hidden_states * gates.unsqueeze(-1)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
@@ -224,7 +333,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
"""
|
||||
|
||||
_default_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor]
|
||||
_available_processors = [LTX2AudioVideoAttnProcessor, LTX2PerturbedAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -240,6 +349,7 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
norm_eps: float = 1e-6,
|
||||
norm_elementwise_affine: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
apply_gated_attention: bool = False,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -266,6 +376,12 @@ class LTX2Attention(torch.nn.Module, AttentionModuleMixin):
|
||||
self.to_out.append(torch.nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
|
||||
self.to_out.append(torch.nn.Dropout(dropout))
|
||||
|
||||
if apply_gated_attention:
|
||||
# Per head gate values
|
||||
self.to_gate_logits = torch.nn.Linear(query_dim, heads, bias=True)
|
||||
else:
|
||||
self.to_gate_logits = None
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
@@ -321,6 +437,10 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
audio_num_attention_heads: int,
|
||||
audio_attention_head_dim,
|
||||
audio_cross_attention_dim: int,
|
||||
video_gated_attn: bool = False,
|
||||
video_cross_attn_adaln: bool = False,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_adaln: bool = False,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
activation_fn: str = "gelu-approximate",
|
||||
attention_bias: bool = True,
|
||||
@@ -328,9 +448,16 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
eps: float = 1e-6,
|
||||
elementwise_affine: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
perturbed_attn: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.perturbed_attn = perturbed_attn
|
||||
if perturbed_attn:
|
||||
attn_processor_cls = LTX2PerturbedAttnProcessor
|
||||
else:
|
||||
attn_processor_cls = LTX2AudioVideoAttnProcessor
|
||||
|
||||
# 1. Self-Attention (video and audio)
|
||||
self.norm1 = RMSNorm(dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
self.attn1 = LTX2Attention(
|
||||
@@ -343,6 +470,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
self.audio_norm1 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -356,6 +485,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 2. Prompt Cross-Attention
|
||||
@@ -370,6 +501,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
self.audio_norm2 = RMSNorm(audio_dim, eps=eps, elementwise_affine=elementwise_affine)
|
||||
@@ -383,6 +516,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
@@ -398,6 +533,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=video_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# Video-to-Audio (v2a) Attention --> Q: Audio; K,V: Video
|
||||
@@ -412,6 +549,8 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
out_bias=attention_out_bias,
|
||||
qk_norm=qk_norm,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=audio_gated_attn,
|
||||
processor=attn_processor_cls(),
|
||||
)
|
||||
|
||||
# 4. Feedforward layers
|
||||
@@ -422,14 +561,36 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
self.audio_ff = FeedForward(audio_dim, activation_fn=activation_fn)
|
||||
|
||||
# 5. Per-Layer Modulation Parameters
|
||||
# Self-Attention / Feedforward AdaLayerNorm-Zero mod params
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(6, audio_dim) / audio_dim**0.5)
|
||||
# Self-Attention (attn1) / Feedforward AdaLayerNorm-Zero mod params
|
||||
# 6 base mod params for text cross-attn K,V; if cross_attn_adaln, also has mod params for Q
|
||||
self.video_cross_attn_adaln = video_cross_attn_adaln
|
||||
self.audio_cross_attn_adaln = audio_cross_attn_adaln
|
||||
video_mod_param_num = 9 if self.video_cross_attn_adaln else 6
|
||||
audio_mod_param_num = 9 if self.audio_cross_attn_adaln else 6
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(video_mod_param_num, dim) / dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(audio_mod_param_num, audio_dim) / audio_dim**0.5)
|
||||
|
||||
# Prompt cross-attn (attn2) additional modulation params
|
||||
self.cross_attn_adaln = video_cross_attn_adaln or audio_cross_attn_adaln
|
||||
if self.cross_attn_adaln:
|
||||
self.prompt_scale_shift_table = nn.Parameter(torch.randn(2, dim))
|
||||
self.audio_prompt_scale_shift_table = nn.Parameter(torch.randn(2, audio_dim))
|
||||
|
||||
# Per-layer a2v, v2a Cross-Attention mod params
|
||||
self.video_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, dim))
|
||||
self.audio_a2v_cross_attn_scale_shift_table = nn.Parameter(torch.randn(5, audio_dim))
|
||||
|
||||
@staticmethod
|
||||
def get_mod_params(
|
||||
scale_shift_table: torch.Tensor, temb: torch.Tensor, batch_size: int
|
||||
) -> tuple[torch.Tensor, ...]:
|
||||
num_ada_params = scale_shift_table.shape[0]
|
||||
ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.shape[1], num_ada_params, -1
|
||||
)
|
||||
ada_params = ada_values.unbind(dim=2)
|
||||
return ada_params
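`get_mod_params` adds a learned per-layer table to the (reshaped) timestep embedding and unbinds along the parameter axis, yielding one `(batch, tokens, dim)` tensor per modulation parameter. A small shape sketch with made-up sizes:

```python
import torch

batch_size, num_tokens, dim, num_params = 2, 8, 32, 6
scale_shift_table = torch.randn(num_params, dim)              # learned in the real model
temb = torch.randn(batch_size, num_tokens, num_params * dim)  # timestep embedding

ada_values = scale_shift_table[None, None] + temb.reshape(batch_size, num_tokens, num_params, -1)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
print(shift_msa.shape)  # torch.Size([2, 8, 32])
```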
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -442,143 +603,181 @@ class LTX2VideoTransformerBlock(nn.Module):
|
||||
temb_ca_audio_scale_shift: torch.Tensor,
|
||||
temb_ca_gate: torch.Tensor,
|
||||
temb_ca_audio_gate: torch.Tensor,
|
||||
temb_prompt: torch.Tensor | None = None,
|
||||
temb_prompt_audio: torch.Tensor | None = None,
|
||||
video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
ca_video_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
ca_audio_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
self_attention_mask: torch.Tensor | None = None,
|
||||
audio_self_attention_mask: torch.Tensor | None = None,
|
||||
a2v_cross_attention_mask: torch.Tensor | None = None,
|
||||
v2a_cross_attention_mask: torch.Tensor | None = None,
|
||||
use_a2v_cross_attention: bool = True,
|
||||
use_v2a_cross_attention: bool = True,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
all_perturbed: bool | None = None,
|
||||
) -> torch.Tensor:
|
||||
batch_size = hidden_states.size(0)
|
||||
|
||||
# 1. Video and Audio Self-Attention
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
# 1.1. Video Self-Attention
|
||||
video_ada_params = self.get_mod_params(self.scale_shift_table, temb, batch_size)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = video_ada_params[:6]
|
||||
if self.video_cross_attn_adaln:
|
||||
shift_text_q, scale_text_q, gate_text_q = video_ada_params[6:9]
|
||||
|
||||
num_ada_params = self.scale_shift_table.shape[0]
|
||||
ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
|
||||
batch_size, temb.size(1), num_ada_params, -1
|
||||
)
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
|
||||
|
||||
attn_hidden_states = self.attn1(
|
||||
hidden_states=norm_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
query_rotary_emb=video_rotary_emb,
|
||||
)
|
||||
video_self_attn_args = {
|
||||
"hidden_states": norm_hidden_states,
|
||||
"encoder_hidden_states": None,
|
||||
"query_rotary_emb": video_rotary_emb,
|
||||
"attention_mask": self_attention_mask,
|
||||
}
|
||||
if self.perturbed_attn:
|
||||
video_self_attn_args["perturbation_mask"] = perturbation_mask
|
||||
video_self_attn_args["all_perturbed"] = all_perturbed
|
||||
|
||||
attn_hidden_states = self.attn1(**video_self_attn_args)
|
||||
hidden_states = hidden_states + attn_hidden_states * gate_msa
|
||||
|
||||
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
|
||||
|
||||
num_audio_ada_params = self.audio_scale_shift_table.shape[0]
|
||||
audio_ada_values = self.audio_scale_shift_table[None, None].to(temb_audio.device) + temb_audio.reshape(
|
||||
batch_size, temb_audio.size(1), num_audio_ada_params, -1
|
||||
)
|
||||
# 1.2. Audio Self-Attention
|
||||
audio_ada_params = self.get_mod_params(self.audio_scale_shift_table, temb_audio, batch_size)
|
||||
audio_shift_msa, audio_scale_msa, audio_gate_msa, audio_shift_mlp, audio_scale_mlp, audio_gate_mlp = (
|
||||
audio_ada_values.unbind(dim=2)
|
||||
audio_ada_params[:6]
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
audio_shift_text_q, audio_scale_text_q, audio_gate_text_q = audio_ada_params[6:9]
|
||||
|
||||
norm_audio_hidden_states = self.audio_norm1(audio_hidden_states)
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_msa) + audio_shift_msa
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn1(
|
||||
hidden_states=norm_audio_hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
query_rotary_emb=audio_rotary_emb,
|
||||
)
|
||||
audio_self_attn_args = {
|
||||
"hidden_states": norm_audio_hidden_states,
|
||||
"encoder_hidden_states": None,
|
||||
"query_rotary_emb": audio_rotary_emb,
|
||||
"attention_mask": audio_self_attention_mask,
|
||||
}
|
||||
if self.perturbed_attn:
|
||||
audio_self_attn_args["perturbation_mask"] = perturbation_mask
|
||||
audio_self_attn_args["all_perturbed"] = all_perturbed
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn1(**audio_self_attn_args)
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states * audio_gate_msa
|
||||
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings
|
||||
# 2. Video and Audio Cross-Attention with the text embeddings (Q: Video or Audio; K,V: Text)
|
||||
if self.cross_attn_adaln:
|
||||
video_prompt_ada_params = self.get_mod_params(self.prompt_scale_shift_table, temb_prompt, batch_size)
|
||||
shift_text_kv, scale_text_kv = video_prompt_ada_params
|
||||
|
||||
audio_prompt_ada_params = self.get_mod_params(
|
||||
self.audio_prompt_scale_shift_table, temb_prompt_audio, batch_size
|
||||
)
|
||||
audio_shift_text_kv, audio_scale_text_kv = audio_prompt_ada_params
|
||||
|
||||
# 2.1. Video-Text Cross-Attention (Q: Video; K,V: Text)
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
if self.video_cross_attn_adaln:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_text_q) + shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
encoder_hidden_states = encoder_hidden_states * (1 + scale_text_kv) + shift_text_kv
|
||||
|
||||
attn_hidden_states = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=encoder_attention_mask,
|
||||
)
|
||||
if self.video_cross_attn_adaln:
|
||||
attn_hidden_states = attn_hidden_states * gate_text_q
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
# 2.2. Audio-Text Cross-Attention
|
||||
norm_audio_hidden_states = self.audio_norm2(audio_hidden_states)
|
||||
if self.audio_cross_attn_adaln:
|
||||
norm_audio_hidden_states = norm_audio_hidden_states * (1 + audio_scale_text_q) + audio_shift_text_q
|
||||
if self.cross_attn_adaln:
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states * (1 + audio_scale_text_kv) + audio_shift_text_kv
|
||||
|
||||
attn_audio_hidden_states = self.audio_attn2(
|
||||
norm_audio_hidden_states,
|
||||
encoder_hidden_states=audio_encoder_hidden_states,
|
||||
query_rotary_emb=None,
|
||||
attention_mask=audio_encoder_attention_mask,
|
||||
)
|
||||
if self.audio_cross_attn_adaln:
|
||||
attn_audio_hidden_states = attn_audio_hidden_states * audio_gate_text_q
|
||||
audio_hidden_states = audio_hidden_states + attn_audio_hidden_states
|
||||
|
||||
# 3. Audio-to-Video (a2v) and Video-to-Audio (v2a) Cross-Attention
|
||||
norm_hidden_states = self.audio_to_video_norm(hidden_states)
|
||||
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
|
||||
if use_a2v_cross_attention or use_v2a_cross_attention:
|
||||
norm_hidden_states = self.audio_to_video_norm(hidden_states)
|
||||
norm_audio_hidden_states = self.video_to_audio_norm(audio_hidden_states)
|
||||
|
||||
# Combine global and per-layer cross attention modulation parameters
|
||||
# Video
|
||||
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
# 3.1. Combine global and per-layer cross attention modulation parameters
|
||||
# Video
|
||||
video_per_layer_ca_scale_shift = self.video_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
video_per_layer_ca_gate = self.video_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
video_ca_scale_shift_table = (
|
||||
video_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_scale_shift.dtype)
|
||||
+ temb_ca_scale_shift.reshape(batch_size, temb_ca_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
video_ca_gate = (
|
||||
video_per_layer_ca_gate[:, :, ...].to(temb_ca_gate.dtype)
|
||||
+ temb_ca_gate.reshape(batch_size, temb_ca_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
video_ca_ada_params = self.get_mod_params(video_per_layer_ca_scale_shift, temb_ca_scale_shift, batch_size)
|
||||
video_ca_gate_param = self.get_mod_params(video_per_layer_ca_gate, temb_ca_gate, batch_size)
|
||||
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_scale_shift_table
|
||||
a2v_gate = video_ca_gate[0].squeeze(2)
|
||||
video_a2v_ca_scale, video_a2v_ca_shift, video_v2a_ca_scale, video_v2a_ca_shift = video_ca_ada_params
|
||||
a2v_gate = video_ca_gate_param[0].squeeze(2)
|
||||
|
||||
# Audio
|
||||
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
# Audio
|
||||
audio_per_layer_ca_scale_shift = self.audio_a2v_cross_attn_scale_shift_table[:4, :]
|
||||
audio_per_layer_ca_gate = self.audio_a2v_cross_attn_scale_shift_table[4:, :]
|
||||
|
||||
audio_ca_scale_shift_table = (
|
||||
audio_per_layer_ca_scale_shift[:, :, ...].to(temb_ca_audio_scale_shift.dtype)
|
||||
+ temb_ca_audio_scale_shift.reshape(batch_size, temb_ca_audio_scale_shift.shape[1], 4, -1)
|
||||
).unbind(dim=2)
|
||||
audio_ca_gate = (
|
||||
audio_per_layer_ca_gate[:, :, ...].to(temb_ca_audio_gate.dtype)
|
||||
+ temb_ca_audio_gate.reshape(batch_size, temb_ca_audio_gate.shape[1], 1, -1)
|
||||
).unbind(dim=2)
|
||||
audio_ca_ada_params = self.get_mod_params(
|
||||
audio_per_layer_ca_scale_shift, temb_ca_audio_scale_shift, batch_size
|
||||
)
|
||||
audio_ca_gate_param = self.get_mod_params(audio_per_layer_ca_gate, temb_ca_audio_gate, batch_size)
|
||||
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_scale_shift_table
|
||||
v2a_gate = audio_ca_gate[0].squeeze(2)
|
||||
audio_a2v_ca_scale, audio_a2v_ca_shift, audio_v2a_ca_scale, audio_v2a_ca_shift = audio_ca_ada_params
|
||||
v2a_gate = audio_ca_gate_param[0].squeeze(2)
|
||||
|
||||
# Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_a2v_ca_scale.squeeze(2)) + video_a2v_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_a2v_ca_scale.squeeze(2)
|
||||
) + audio_a2v_ca_shift.squeeze(2)
|
||||
# 3.2. Audio-to-Video Cross Attention: Q: Video; K,V: Audio
|
||||
if use_a2v_cross_attention:
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_a2v_ca_scale.squeeze(2)
|
||||
) + video_a2v_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_a2v_ca_scale.squeeze(2)
|
||||
) + audio_a2v_ca_shift.squeeze(2)
|
||||
|
||||
a2v_attn_hidden_states = self.audio_to_video_attn(
|
||||
mod_norm_hidden_states,
|
||||
encoder_hidden_states=mod_norm_audio_hidden_states,
|
||||
query_rotary_emb=ca_video_rotary_emb,
|
||||
key_rotary_emb=ca_audio_rotary_emb,
|
||||
attention_mask=a2v_cross_attention_mask,
|
||||
)
|
||||
a2v_attn_hidden_states = self.audio_to_video_attn(
|
||||
mod_norm_hidden_states,
|
||||
encoder_hidden_states=mod_norm_audio_hidden_states,
|
||||
query_rotary_emb=ca_video_rotary_emb,
|
||||
key_rotary_emb=ca_audio_rotary_emb,
|
||||
attention_mask=a2v_cross_attention_mask,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
|
||||
hidden_states = hidden_states + a2v_gate * a2v_attn_hidden_states
|
||||
|
||||
# Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
mod_norm_hidden_states = norm_hidden_states * (1 + video_v2a_ca_scale.squeeze(2)) + video_v2a_ca_shift.squeeze(
|
||||
2
|
||||
)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_v2a_ca_scale.squeeze(2)
|
||||
) + audio_v2a_ca_shift.squeeze(2)
|
||||
# 3.3. Video-to-Audio Cross Attention: Q: Audio; K,V: Video
|
||||
if use_v2a_cross_attention:
|
||||
mod_norm_hidden_states = norm_hidden_states * (
|
||||
1 + video_v2a_ca_scale.squeeze(2)
|
||||
) + video_v2a_ca_shift.squeeze(2)
|
||||
mod_norm_audio_hidden_states = norm_audio_hidden_states * (
|
||||
1 + audio_v2a_ca_scale.squeeze(2)
|
||||
) + audio_v2a_ca_shift.squeeze(2)
|
||||
|
||||
v2a_attn_hidden_states = self.video_to_audio_attn(
|
||||
mod_norm_audio_hidden_states,
|
||||
encoder_hidden_states=mod_norm_hidden_states,
|
||||
query_rotary_emb=ca_audio_rotary_emb,
|
||||
key_rotary_emb=ca_video_rotary_emb,
|
||||
attention_mask=v2a_cross_attention_mask,
|
||||
)
|
||||
v2a_attn_hidden_states = self.video_to_audio_attn(
|
||||
mod_norm_audio_hidden_states,
|
||||
encoder_hidden_states=mod_norm_hidden_states,
|
||||
query_rotary_emb=ca_audio_rotary_emb,
|
||||
key_rotary_emb=ca_video_rotary_emb,
|
||||
attention_mask=v2a_cross_attention_mask,
|
||||
)
|
||||
|
||||
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
|
||||
audio_hidden_states = audio_hidden_states + v2a_gate * v2a_attn_hidden_states
|
||||
|
||||
# 4. Feedforward
|
||||
norm_hidden_states = self.norm3(hidden_states) * (1 + scale_mlp) + shift_mlp
|
||||
@@ -918,6 +1117,8 @@ class LTX2VideoTransformer3DModel(
|
||||
pos_embed_max_pos: int = 20,
|
||||
base_height: int = 2048,
|
||||
base_width: int = 2048,
|
||||
gated_attn: bool = False,
|
||||
cross_attn_mod: bool = False,
|
||||
audio_in_channels: int = 128, # Audio Arguments
|
||||
audio_out_channels: int | None = 128,
|
||||
audio_patch_size: int = 1,
|
||||
@@ -929,6 +1130,8 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_pos_embed_max_pos: int = 20,
|
||||
audio_sampling_rate: int = 16000,
|
||||
audio_hop_length: int = 160,
|
||||
audio_gated_attn: bool = False,
|
||||
audio_cross_attn_mod: bool = False,
|
||||
num_layers: int = 48, # Shared arguments
|
||||
activation_fn: str = "gelu-approximate",
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
@@ -943,6 +1146,8 @@ class LTX2VideoTransformer3DModel(
|
||||
timestep_scale_multiplier: int = 1000,
|
||||
cross_attn_timestep_scale_multiplier: int = 1000,
|
||||
rope_type: str = "interleaved",
|
||||
use_prompt_embeddings=True,
|
||||
perturbed_attn: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -956,17 +1161,25 @@ class LTX2VideoTransformer3DModel(
|
||||
self.audio_proj_in = nn.Linear(audio_in_channels, audio_inner_dim)
|
||||
|
||||
# 2. Prompt embeddings
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=audio_inner_dim
|
||||
)
|
||||
if use_prompt_embeddings:
|
||||
# LTX-2.0; LTX-2.3 uses per-modality feature projections in the connector instead
|
||||
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
|
||||
self.audio_caption_projection = PixArtAlphaTextProjection(
|
||||
in_features=caption_channels, hidden_size=audio_inner_dim
|
||||
)
|
||||
|
||||
# 3. Timestep Modulation Params and Embedding
|
||||
self.prompt_modulation = cross_attn_mod or audio_cross_attn_mod # used by LTX-2.3
|
||||
|
||||
# 3.1. Global Timestep Modulation Parameters (except for cross-attention) and timestep + size embedding
|
||||
# time_embed and audio_time_embed calculate both the timestep embedding and (global) modulation parameters
|
||||
self.time_embed = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=6, use_additional_conditions=False)
|
||||
video_time_emb_mod_params = 9 if cross_attn_mod else 6
|
||||
audio_time_emb_mod_params = 9 if audio_cross_attn_mod else 6
|
||||
self.time_embed = LTX2AdaLayerNormSingle(
|
||||
inner_dim, num_mod_params=video_time_emb_mod_params, use_additional_conditions=False
|
||||
)
|
||||
self.audio_time_embed = LTX2AdaLayerNormSingle(
|
||||
audio_inner_dim, num_mod_params=6, use_additional_conditions=False
|
||||
audio_inner_dim, num_mod_params=audio_time_emb_mod_params, use_additional_conditions=False
|
||||
)
|
||||
|
||||
# 3.2. Global Cross Attention Modulation Parameters
|
||||
@@ -995,6 +1208,13 @@ class LTX2VideoTransformer3DModel(
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
|
||||
self.audio_scale_shift_table = nn.Parameter(torch.randn(2, audio_inner_dim) / audio_inner_dim**0.5)
|
||||
|
||||
# 3.4. Prompt Scale/Shift Modulation parameters (LTX-2.3)
|
||||
if self.prompt_modulation:
|
||||
self.prompt_adaln = LTX2AdaLayerNormSingle(inner_dim, num_mod_params=2, use_additional_conditions=False)
|
||||
self.audio_prompt_adaln = LTX2AdaLayerNormSingle(
|
||||
audio_inner_dim, num_mod_params=2, use_additional_conditions=False
|
||||
)
|
||||
|
||||
# 4. Rotary Positional Embeddings (RoPE)
|
||||
# Self-Attention
|
||||
self.rope = LTX2AudioVideoRotaryPosEmbed(
|
||||
@@ -1071,6 +1291,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_attention_heads=audio_num_attention_heads,
|
||||
audio_attention_head_dim=audio_attention_head_dim,
|
||||
audio_cross_attention_dim=audio_cross_attention_dim,
|
||||
video_gated_attn=gated_attn,
|
||||
video_cross_attn_adaln=cross_attn_mod,
|
||||
audio_gated_attn=audio_gated_attn,
|
||||
audio_cross_attn_adaln=audio_cross_attn_mod,
|
||||
qk_norm=qk_norm,
|
||||
activation_fn=activation_fn,
|
||||
attention_bias=attention_bias,
|
||||
@@ -1078,6 +1302,7 @@ class LTX2VideoTransformer3DModel(
|
||||
eps=norm_eps,
|
||||
elementwise_affine=norm_elementwise_affine,
|
||||
rope_type=rope_type,
|
||||
perturbed_attn=perturbed_attn,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -1101,6 +1326,8 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_encoder_hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
audio_timestep: torch.LongTensor | None = None,
|
||||
sigma: torch.Tensor | None = None,
|
||||
audio_sigma: torch.Tensor | None = None,
|
||||
encoder_attention_mask: torch.Tensor | None = None,
|
||||
audio_encoder_attention_mask: torch.Tensor | None = None,
|
||||
num_frames: int | None = None,
|
||||
@@ -1110,6 +1337,10 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_num_frames: int | None = None,
|
||||
video_coords: torch.Tensor | None = None,
|
||||
audio_coords: torch.Tensor | None = None,
|
||||
isolate_modalities: bool = False,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
perturbation_mask: torch.Tensor | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> torch.Tensor:
|
||||
@@ -1131,6 +1362,13 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_timestep (`torch.Tensor`, *optional*):
|
||||
Input timestep of shape `(batch_size,)` or `(batch_size, num_audio_tokens)` for audio modulation
|
||||
params. This is only used by certain pipelines such as the I2V pipeline.
|
||||
sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape (batch_size,). Used for video prompt cross attention modulation in
|
||||
models such as LTX-2.3.
|
||||
audio_sigma (`torch.Tensor`, *optional*):
|
||||
Input scaled timestep of shape (batch_size,). Used for audio prompt cross attention modulation in
|
||||
models such as LTX-2.3. If `sigma` is supplied but `audio_sigma` is not, `audio_sigma` will be set to
|
||||
the provided `sigma` value.
|
||||
encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
Optional multiplicative text attention mask of shape `(batch_size, text_seq_len)`.
|
||||
audio_encoder_attention_mask (`torch.Tensor`, *optional*):
|
||||
@@ -1152,6 +1390,21 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_coords (`torch.Tensor`, *optional*):
|
||||
The audio coordinates to be used when calculating the rotary positional embeddings (RoPE) of shape
|
||||
`(batch_size, 1, num_audio_tokens, 2)`. If not supplied, this will be calculated inside `forward`.
|
||||
isolate_modalities (`bool`, *optional*, defaults to `False`):
|
||||
Whether to isolate each modality by turning off cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention (for all blocks). Use for modality guidance in LTX-2.3.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The transformer block indices at which to apply spatio-temporal guidance (STG), which shortcuts the
|
||||
self-attention operations by simply using the values rather than the full scaled dot-product attention
|
||||
(SDPA) operation. If `None` or empty, STG will not be applied to any block.
|
||||
perturbation_mask (`torch.Tensor`, *optional*):
|
||||
Perturbation mask for STG of shape `(batch_size,)` or `(batch_size, 1, 1)`. Should be 0 at batch
|
||||
elements where STG should be applied and 1 elsewhere. If STG is being used but `perturbation_mask` is
not supplied, it will default to applying STG to (perturbing) all batch elements.
use_cross_timestep (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
attention_kwargs (`dict[str, Any]`, *optional*):
|
||||
Optional dict of keyword args to be passed to the attention processor.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
@@ -1165,6 +1418,7 @@ class LTX2VideoTransformer3DModel(
|
||||
"""
|
||||
# Determine timestep for audio.
|
||||
audio_timestep = audio_timestep if audio_timestep is not None else timestep
|
||||
audio_sigma = audio_sigma if audio_sigma is not None else sigma
|
||||
|
||||
# convert encoder_attention_mask to a bias the same way we do for attention_mask
|
||||
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
|
||||
@@ -1223,14 +1477,30 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_audio = temb_audio.view(batch_size, -1, temb_audio.size(-1))
|
||||
audio_embedded_timestep = audio_embedded_timestep.view(batch_size, -1, audio_embedded_timestep.size(-1))
|
||||
|
||||
if self.prompt_modulation:
|
||||
# LTX-2.3
|
||||
temb_prompt, _ = self.prompt_adaln(
|
||||
sigma.flatten(), batch_size=batch_size, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
temb_prompt_audio, _ = self.audio_prompt_adaln(
|
||||
audio_sigma.flatten(), batch_size=batch_size, hidden_dtype=audio_hidden_states.dtype
|
||||
)
|
||||
temb_prompt = temb_prompt.view(batch_size, -1, temb_prompt.size(-1))
|
||||
temb_prompt_audio = temb_prompt_audio.view(batch_size, -1, temb_prompt_audio.size(-1))
|
||||
else:
|
||||
temb_prompt = temb_prompt_audio = None
|
||||
|
||||
# 3.2. Prepare global modality cross attention modulation parameters
|
||||
# Reference always uses cross-modality sigma for cross-attention timestep:
|
||||
# video cross-attn uses audio_sigma, audio cross-attn uses sigma (video sigma).
|
||||
video_ca_timestep = audio_sigma.flatten()
|
||||
video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
|
||||
timestep.flatten(),
|
||||
video_ca_timestep,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
video_cross_attn_a2v_gate, _ = self.av_cross_attn_video_a2v_gate(
|
||||
timestep.flatten() * timestep_cross_attn_gate_scale_factor,
|
||||
video_ca_timestep * timestep_cross_attn_gate_scale_factor,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=hidden_states.dtype,
|
||||
)
|
||||
@@ -1239,13 +1509,14 @@ class LTX2VideoTransformer3DModel(
|
||||
)
|
||||
video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])
|
||||
|
||||
audio_ca_timestep = sigma.flatten()
|
||||
audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
|
||||
audio_timestep.flatten(),
|
||||
audio_ca_timestep,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
audio_cross_attn_v2a_gate, _ = self.av_cross_attn_audio_v2a_gate(
|
||||
audio_timestep.flatten() * timestep_cross_attn_gate_scale_factor,
|
||||
audio_ca_timestep * timestep_cross_attn_gate_scale_factor,
|
||||
batch_size=batch_size,
|
||||
hidden_dtype=audio_hidden_states.dtype,
|
||||
)
|
||||
@@ -1254,15 +1525,30 @@ class LTX2VideoTransformer3DModel(
|
||||
)
|
||||
audio_cross_attn_v2a_gate = audio_cross_attn_v2a_gate.view(batch_size, -1, audio_cross_attn_v2a_gate.shape[-1])
|
||||
|
||||
# 4. Prepare prompt embeddings
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
# 4. Prepare prompt embeddings (LTX-2.0)
|
||||
if self.config.use_prompt_embeddings:
|
||||
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
|
||||
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.size(-1))
|
||||
|
||||
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states.view(batch_size, -1, audio_hidden_states.size(-1))
|
||||
audio_encoder_hidden_states = self.audio_caption_projection(audio_encoder_hidden_states)
|
||||
audio_encoder_hidden_states = audio_encoder_hidden_states.view(
|
||||
batch_size, -1, audio_hidden_states.size(-1)
|
||||
)
|
||||
|
||||
# 5. Run transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
spatio_temporal_guidance_blocks = spatio_temporal_guidance_blocks or []
|
||||
if len(spatio_temporal_guidance_blocks) > 0 and perturbation_mask is None:
|
||||
# If STG is being used and perturbation_mask is not supplied, default to perturbing all batch elements.
|
||||
perturbation_mask = torch.zeros((batch_size,), device=hidden_states.device, dtype=hidden_states.dtype)
|
||||
if perturbation_mask is not None and perturbation_mask.ndim == 1:
|
||||
perturbation_mask = perturbation_mask[:, None, None] # unsqueeze to 3D to broadcast with hidden_states
|
||||
all_perturbed = torch.all(perturbation_mask == 0) if perturbation_mask is not None else False
|
||||
stg_blocks = set(spatio_temporal_guidance_blocks)
|
||||
|
||||
for block_idx, block in enumerate(self.transformer_blocks):
|
||||
block_perturbation_mask = perturbation_mask if block_idx in stg_blocks else None
|
||||
block_all_perturbed = all_perturbed if block_idx in stg_blocks else False
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states, audio_hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
@@ -1276,12 +1562,22 @@ class LTX2VideoTransformer3DModel(
|
||||
audio_cross_attn_scale_shift,
|
||||
video_cross_attn_a2v_gate,
|
||||
audio_cross_attn_v2a_gate,
|
||||
temb_prompt,
|
||||
temb_prompt_audio,
|
||||
video_rotary_emb,
|
||||
audio_rotary_emb,
|
||||
video_cross_attn_rotary_emb,
|
||||
audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask,
|
||||
audio_encoder_attention_mask,
|
||||
None, # self_attention_mask
|
||||
None, # audio_self_attention_mask
|
||||
None, # a2v_cross_attention_mask
|
||||
None, # v2a_cross_attention_mask
|
||||
not isolate_modalities, # use_a2v_cross_attention
|
||||
not isolate_modalities, # use_v2a_cross_attention
|
||||
block_perturbation_mask,
|
||||
block_all_perturbed,
|
||||
)
|
||||
else:
|
||||
hidden_states, audio_hidden_states = block(
|
||||
@@ -1295,12 +1591,22 @@ class LTX2VideoTransformer3DModel(
|
||||
temb_ca_audio_scale_shift=audio_cross_attn_scale_shift,
|
||||
temb_ca_gate=video_cross_attn_a2v_gate,
|
||||
temb_ca_audio_gate=audio_cross_attn_v2a_gate,
|
||||
temb_prompt=temb_prompt,
|
||||
temb_prompt_audio=temb_prompt_audio,
|
||||
video_rotary_emb=video_rotary_emb,
|
||||
audio_rotary_emb=audio_rotary_emb,
|
||||
ca_video_rotary_emb=video_cross_attn_rotary_emb,
|
||||
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
audio_encoder_attention_mask=audio_encoder_attention_mask,
|
||||
self_attention_mask=None,
|
||||
audio_self_attention_mask=None,
|
||||
a2v_cross_attention_mask=None,
|
||||
v2a_cross_attention_mask=None,
|
||||
use_a2v_cross_attention=not isolate_modalities,
|
||||
use_v2a_cross_attention=not isolate_modalities,
|
||||
perturbation_mask=block_perturbation_mask,
|
||||
all_perturbed=block_all_perturbed,
|
||||
)
|
||||
|
||||
# 6. Output layers (including unpatchification)
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
import math
|
||||
from math import prod
|
||||
from typing import Any
|
||||
@@ -25,7 +24,7 @@ import torch.nn.functional as F
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, deprecate, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
@lru_cache_unless_export(maxsize=128)
|
||||
def _compute_video_freqs(
|
||||
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
|
||||
) -> torch.Tensor:
|
||||
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
|
||||
return vid_freqs, txt_freqs
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
|
||||
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
||||
return freqs.clone().contiguous()
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
@lru_cache_unless_export(maxsize=None)
|
||||
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
|
||||
seq_lens = frame * height * width
|
||||
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
|
||||
|
||||
@@ -70,6 +70,11 @@ else:
|
||||
"FluxKontextAutoBlocks",
|
||||
"FluxKontextModularPipeline",
|
||||
]
|
||||
_import_structure["ltx2"] = [
|
||||
"LTX2AutoBlocks",
|
||||
"LTX2Blocks",
|
||||
"LTX2ModularPipeline",
|
||||
]
|
||||
_import_structure["flux2"] = [
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2KleinAutoBlocks",
|
||||
@@ -103,6 +108,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .components_manager import ComponentsManager
|
||||
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
|
||||
from .ltx2 import LTX2AutoBlocks, LTX2Blocks, LTX2ModularPipeline
|
||||
from .flux2 import (
|
||||
Flux2AutoBlocks,
|
||||
Flux2KleinAutoBlocks,
|
||||
|
||||
src/diffusers/modular_pipelines/ltx2/__init__.py (new file, 52 lines)
@@ -0,0 +1,52 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modular_blocks_ltx2"] = ["LTX2Blocks", "LTX2AutoBlocks", "LTX2Stage1Blocks", "LTX2Stage2Blocks", "LTX2FullPipelineBlocks"]
|
||||
_import_structure["modular_blocks_ltx2_upsample"] = ["LTX2UpsampleBlocks", "LTX2UpsampleCoreBlocks"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"LTX2ModularPipeline",
|
||||
"LTX2UpsampleModularPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .modular_blocks_ltx2 import LTX2AutoBlocks, LTX2Blocks, LTX2FullPipelineBlocks, LTX2Stage1Blocks, LTX2Stage2Blocks
|
||||
from .modular_blocks_ltx2_upsample import LTX2UpsampleBlocks, LTX2UpsampleCoreBlocks
|
||||
from .modular_pipeline import LTX2ModularPipeline, LTX2UpsampleModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
src/diffusers/modular_pipelines/ltx2/_checkpoint_utils.py (new file, 27 lines)
@@ -0,0 +1,27 @@
|
||||
"""Checkpoint utilities for parity debugging. No effect when _checkpoints is None."""
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class Checkpoint:
|
||||
save: bool = False
|
||||
stop: bool = False
|
||||
load: bool = False
|
||||
data: dict = field(default_factory=dict)
|
||||
|
||||
|
||||
def _maybe_checkpoint(checkpoints, name, data):
|
||||
if not checkpoints:
|
||||
return
|
||||
ckpt = checkpoints.get(name)
|
||||
if ckpt is None:
|
||||
return
|
||||
if ckpt.save:
|
||||
ckpt.data.update({
|
||||
k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in data.items()
|
||||
})
|
||||
if ckpt.stop:
|
||||
raise StopIteration(name)
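A hypothetical usage of these parity-debugging helpers: register a `Checkpoint` with `save=True` under a name, call `_maybe_checkpoint` at the corresponding point in the pipeline, then inspect the CPU copies afterwards (the checkpoint name and tensor below are made up):

```python
import torch

checkpoints = {"after_text_encoder": Checkpoint(save=True)}

prompt_embeds = torch.randn(1, 77, 4096)  # made-up intermediate tensor
_maybe_checkpoint(checkpoints, "after_text_encoder", {"prompt_embeds": prompt_embeds})

saved = checkpoints["after_text_encoder"].data["prompt_embeds"]
assert saved.device.type == "cpu" and torch.equal(saved, prompt_embeds.cpu())
```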
|
||||
src/diffusers/modular_pipelines/ltx2/before_denoise.py (new file, 657 lines)
@@ -0,0 +1,657 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
|
||||
from ...models.transformers import LTX2VideoTransformer3DModel
|
||||
from ...pipelines.ltx2.connectors import LTX2TextConnectors
|
||||
from ...pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
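`calculate_shift` linearly interpolates the scheduler shift `mu` between `base_shift` and `max_shift` as the sequence length moves from `base_seq_len` to `max_seq_len`. For example, with the defaults above:

```python
mu_short = calculate_shift(256)   # 0.5  (base_shift at base_seq_len)
mu_long = calculate_shift(4096)   # 1.15 (max_shift at max_seq_len)
mu_mid = calculate_shift(2176)    # 0.825, the midpoint of the two
```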
|
||||
|
||||
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
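`retrieve_timesteps` forwards either explicit `timesteps`, explicit `sigmas`, or just a step count to the scheduler and returns the resolved schedule. A minimal sketch with the `FlowMatchEulerDiscreteScheduler` imported at the top of this file (default config assumed; in the pipeline the computed `mu` is also forwarded via kwargs):

```python
scheduler = FlowMatchEulerDiscreteScheduler()
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=8, device="cpu")
print(num_steps, timesteps.shape)  # 8 torch.Size([8])
```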
|
||||
|
||||
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size, -1, post_patch_num_frames, patch_size_t, post_patch_height, patch_size, post_patch_width, patch_size
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
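`_pack_latents` folds the spatial/temporal patches into the channel dimension and flattens frames × height × width into a token sequence. A shape walkthrough with illustrative sizes:

```python
import torch

latents = torch.randn(1, 128, 8, 16, 24)                       # (B, C, F, H, W)
packed = _pack_latents(latents, patch_size=1, patch_size_t=1)  # default patch sizes
print(packed.shape)                                            # torch.Size([1, 3072, 128]) == (B, F*H*W, C)
```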
|
||||
|
||||
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
|
||||
def _pack_audio_latents(
|
||||
latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None
|
||||
) -> torch.Tensor:
|
||||
if patch_size is not None and patch_size_t is not None:
|
||||
batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
|
||||
post_patch_latent_length = latent_length // patch_size_t
post_patch_mel_bins = latent_mel_bins // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
|
||||
else:
|
||||
latents = latents.transpose(1, 2).flatten(2, 3)
|
||||
return latents
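In the unpatched case used here (both patch sizes `None`), `_pack_audio_latents` just transposes and folds the mel-bin axis into the feature dimension. Illustrative shapes:

```python
import torch

audio_latents = torch.randn(1, 8, 32, 16)    # (B, C, latent_length, latent_mel_bins)
packed = _pack_audio_latents(audio_latents)  # patch_size / patch_size_t left as None
print(packed.shape)                          # torch.Size([1, 32, 128]) == (B, latent_length, C * mel_bins)
```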
|
||||
|
||||
|
||||
def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
|
||||
latents_mean = latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.to(latents.device, latents.dtype)
|
||||
return (latents - latents_mean) / latents_std
|
||||
|
||||
|
||||
class LTX2InputStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Input processing step that determines batch_size and dtype, "
|
||||
"and expands embeddings for num_videos_per_prompt"
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", LTX2VideoTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("num_videos_per_prompt", default=1),
|
||||
InputParam("connector_prompt_embeds", required=True, type_hint=torch.Tensor),
|
||||
InputParam("connector_audio_prompt_embeds", required=True, type_hint=torch.Tensor),
|
||||
InputParam("connector_attention_mask", required=True, type_hint=torch.Tensor),
|
||||
InputParam("connector_negative_prompt_embeds", type_hint=torch.Tensor),
|
||||
InputParam("connector_audio_negative_prompt_embeds", type_hint=torch.Tensor),
|
||||
InputParam("connector_negative_attention_mask", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("batch_size", type_hint=int),
|
||||
OutputParam("dtype", type_hint=torch.dtype),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
block_state.batch_size = block_state.connector_prompt_embeds.shape[0]
|
||||
block_state.dtype = components.transformer.dtype
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2SetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that sets up the scheduler timesteps for both video and audio denoising"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=40),
|
||||
InputParam("timesteps_input"),
|
||||
InputParam("sigmas"),
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("timesteps", type_hint=torch.Tensor),
|
||||
OutputParam("num_inference_steps", type_hint=int),
|
||||
OutputParam("audio_scheduler"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
|
||||
num_inference_steps = block_state.num_inference_steps
|
||||
sigmas = block_state.sigmas
|
||||
timesteps_input = block_state.timesteps_input
|
||||
|
||||
vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
|
||||
vae_temporal_compression_ratio = components.vae.temporal_compression_ratio
|
||||
height = block_state.height
|
||||
width = block_state.width
|
||||
num_frames = block_state.num_frames
|
||||
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
|
||||
latent_height = height // vae_spatial_compression_ratio
|
||||
latent_width = width // vae_spatial_compression_ratio
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
if sigmas is None:
|
||||
# Use torch.linspace (float32) to match reference scheduler precision.
|
||||
# np.linspace computes in float64 then casts to float32, which produces
|
||||
# values that differ by 1 ULP from torch's native float32 computation.
|
||||
sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1)[:-1].numpy()
|
||||
|
||||
mu = calculate_shift(
|
||||
components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
components.scheduler.config.get("base_image_seq_len", 1024),
|
||||
components.scheduler.config.get("max_image_seq_len", 4096),
|
||||
components.scheduler.config.get("base_shift", 0.95),
|
||||
components.scheduler.config.get("max_shift", 2.05),
|
||||
)
|
||||
|
||||
audio_scheduler = copy.deepcopy(components.scheduler)
|
||||
_, _ = retrieve_timesteps(
|
||||
audio_scheduler, num_inference_steps, device, timesteps_input, sigmas=sigmas, mu=mu
|
||||
)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
components.scheduler, num_inference_steps, device, timesteps_input, sigmas=sigmas, mu=mu
|
||||
)
|
||||
|
||||
block_state.timesteps = timesteps
|
||||
block_state.num_inference_steps = num_inference_steps
|
||||
block_state.audio_scheduler = audio_scheduler
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2PrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare video latents, optionally applying conditioning mask for I2V/conditional generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
ComponentSpec("transformer", LTX2VideoTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("noise_scale", default=1.0, type_hint=float),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
InputParam("generator"),
|
||||
InputParam("batch_size", required=True, type_hint=int),
|
||||
InputParam("num_videos_per_prompt", default=1, type_hint=int),
|
||||
InputParam("condition_latents", type_hint=list),
|
||||
InputParam("condition_strengths", type_hint=list),
|
||||
InputParam("condition_indices", type_hint=list),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("latents", type_hint=torch.Tensor),
|
||||
OutputParam("conditioning_mask", type_hint=torch.Tensor),
|
||||
OutputParam("clean_latents", type_hint=torch.Tensor),
|
||||
OutputParam("latent_num_frames", type_hint=int),
|
||||
OutputParam("latent_height", type_hint=int),
|
||||
OutputParam("latent_width", type_hint=int),
|
||||
OutputParam("video_sequence_length", type_hint=int),
|
||||
OutputParam("transformer_spatial_patch_size", type_hint=int),
|
||||
OutputParam("transformer_temporal_patch_size", type_hint=int),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
|
||||
height = block_state.height
|
||||
width = block_state.width
|
||||
num_frames = block_state.num_frames
|
||||
noise_scale = block_state.noise_scale
|
||||
generator = block_state.generator
|
||||
batch_size = block_state.batch_size * block_state.num_videos_per_prompt
|
||||
|
||||
vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
|
||||
vae_temporal_compression_ratio = components.vae.temporal_compression_ratio
|
||||
transformer_spatial_patch_size = components.transformer.config.patch_size
|
||||
transformer_temporal_patch_size = components.transformer.config.patch_size_t
|
||||
        num_channels_latents = components.transformer.config.in_channels

        latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
        latent_height = height // vae_spatial_compression_ratio
        latent_width = width // vae_spatial_compression_ratio

        condition_latents = getattr(block_state, "condition_latents", None) or []
        condition_strengths = getattr(block_state, "condition_strengths", None) or []
        condition_indices = getattr(block_state, "condition_indices", None) or []
        has_conditions = len(condition_latents) > 0

        if block_state.latents is not None:
            latents = block_state.latents
            if latents.ndim == 5:
                latents = _normalize_latents(
                    latents, components.vae.latents_mean, components.vae.latents_std, components.vae.config.scaling_factor
                )
                _, _, latent_num_frames, latent_height, latent_width = latents.shape
                latents = _pack_latents(latents, transformer_spatial_patch_size, transformer_temporal_patch_size)
        else:
            # Reference: create zeros in [B,C,F,H,W] in model dtype, pack to [B,S,D],
            # then generate noise in packed shape with same dtype
            latent_dtype = components.transformer.dtype
            shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width)
            latents = torch.zeros(shape, device=device, dtype=latent_dtype)
            latents = _pack_latents(latents, transformer_spatial_patch_size, transformer_temporal_patch_size)

        conditioning_mask = None
        clean_latents = None

        if has_conditions:
            mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
            conditioning_mask = torch.zeros(mask_shape, device=device, dtype=torch.float32)
            conditioning_mask = _pack_latents(
                conditioning_mask, transformer_spatial_patch_size, transformer_temporal_patch_size
            )

            clean_latents = torch.zeros_like(latents)
            for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices):
                num_cond_tokens = cond.size(1)
                start_token_idx = latent_idx * latent_height * latent_width
                end_token_idx = start_token_idx + num_cond_tokens

                latents[:, start_token_idx:end_token_idx] = cond
                conditioning_mask[:, start_token_idx:end_token_idx] = strength
                clean_latents[:, start_token_idx:end_token_idx] = cond

            if isinstance(generator, list):
                generator = generator[0]

            # Noise in packed [B,S,D] shape and same dtype as latent (matches reference GaussianNoiser)
            noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
            scaled_mask = (1.0 - conditioning_mask) * noise_scale
            latents = noise * scaled_mask + latents * (1 - scaled_mask)
        else:
            # T2V: noise in packed shape, same dtype as latent
            if isinstance(generator, list):
                generator = generator[0]
            noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
            scaled_mask = noise_scale
            latents = noise * scaled_mask + latents * (1 - scaled_mask)

        block_state.latents = latents
        block_state.conditioning_mask = conditioning_mask
        block_state.clean_latents = clean_latents
        block_state.latent_num_frames = latent_num_frames
        block_state.latent_height = latent_height
        block_state.latent_width = latent_width
        block_state.video_sequence_length = latent_num_frames * latent_height * latent_width
        block_state.transformer_spatial_patch_size = transformer_spatial_patch_size
        block_state.transformer_temporal_patch_size = transformer_temporal_patch_size

        self.set_block_state(state, block_state)
        return components, state

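# Illustrative sketch (not part of the pipeline code): how the conditioning loop above maps a
# condition latent into the packed [B, S, D] sequence. With patch sizes of 1, one latent frame
# spans latent_height * latent_width tokens. The numbers below are hypothetical defaults
# (512x768 at a spatial compression of 32).
def _example_condition_token_indexing(latent_idx: int = 0, latent_height: int = 16, latent_width: int = 24):
    tokens_per_frame = latent_height * latent_width  # 384 tokens per latent frame
    start_token_idx = latent_idx * tokens_per_frame  # 0 for a first-frame (I2V) condition
    end_token_idx = start_token_idx + tokens_per_frame  # slice [0:384] is overwritten with the condition
    return start_token_idx, end_token_idx
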
class LTX2PrepareAudioLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare audio latents for the denoising process"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("audio_vae", AutoencoderKLLTX2Audio),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("frame_rate", default=24.0, type_hint=float),
|
||||
InputParam("noise_scale", default=1.0, type_hint=float),
|
||||
InputParam("audio_latents", type_hint=torch.Tensor),
|
||||
InputParam("generator"),
|
||||
InputParam("batch_size", required=True, type_hint=int),
|
||||
InputParam("num_videos_per_prompt", default=1, type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("audio_latents", type_hint=torch.Tensor),
|
||||
OutputParam("audio_num_frames", type_hint=int),
|
||||
OutputParam("latent_mel_bins", type_hint=int),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
|
||||
num_frames = block_state.num_frames
|
||||
frame_rate = block_state.frame_rate
|
||||
noise_scale = block_state.noise_scale
|
||||
generator = block_state.generator
|
||||
batch_size = block_state.batch_size * block_state.num_videos_per_prompt
|
||||
|
||||
audio_sampling_rate = components.audio_vae.config.sample_rate
|
||||
audio_hop_length = components.audio_vae.config.mel_hop_length
|
||||
audio_vae_temporal_compression_ratio = components.audio_vae.temporal_compression_ratio
|
||||
audio_vae_mel_compression_ratio = components.audio_vae.mel_compression_ratio
|
||||
|
||||
duration_s = num_frames / frame_rate
|
||||
audio_latents_per_second = audio_sampling_rate / audio_hop_length / float(audio_vae_temporal_compression_ratio)
|
||||
audio_num_frames = round(duration_s * audio_latents_per_second)
|
||||
|
||||
num_mel_bins = components.audio_vae.config.mel_bins
|
||||
latent_mel_bins = num_mel_bins // audio_vae_mel_compression_ratio
|
||||
num_channels_latents_audio = components.audio_vae.config.latent_channels
|
||||
|
||||
if block_state.audio_latents is not None:
|
||||
audio_latents = block_state.audio_latents
|
||||
if audio_latents.ndim == 4:
|
||||
_, _, audio_num_frames, _ = audio_latents.shape
|
||||
audio_latents = _pack_audio_latents(audio_latents)
|
||||
audio_latents = _normalize_audio_latents(
|
||||
audio_latents, components.audio_vae.latents_mean, components.audio_vae.latents_std
|
||||
)
|
||||
if noise_scale > 0.0:
|
||||
noise = randn_tensor(
|
||||
audio_latents.shape, generator=generator, device=audio_latents.device, dtype=audio_latents.dtype
|
||||
)
|
||||
audio_latents = noise_scale * noise + (1 - noise_scale) * audio_latents
|
||||
elif audio_latents.ndim == 3 and noise_scale > 0.0:
|
||||
noise = randn_tensor(
|
||||
audio_latents.shape, generator=generator, device=audio_latents.device, dtype=audio_latents.dtype
|
||||
)
|
||||
audio_latents = noise_scale * noise + (1 - noise_scale) * audio_latents
|
||||
audio_latents = audio_latents.to(device=device, dtype=torch.float32)
|
||||
else:
|
||||
# Reference: create zeros in [B,C,T,M] in model dtype, pack, then noise in packed shape
|
||||
latent_dtype = components.audio_vae.dtype
|
||||
shape = (batch_size, num_channels_latents_audio, audio_num_frames, latent_mel_bins)
|
||||
audio_latents = torch.zeros(shape, device=device, dtype=latent_dtype)
|
||||
audio_latents = _pack_audio_latents(audio_latents)
|
||||
if isinstance(generator, list):
|
||||
generator = generator[0]
|
||||
noise = randn_tensor(audio_latents.shape, generator=generator, device=device, dtype=latent_dtype)
|
||||
audio_latents = noise * noise_scale + audio_latents * (1 - noise_scale)
|
||||
|
||||
block_state.audio_latents = audio_latents
|
||||
block_state.audio_num_frames = audio_num_frames
|
||||
block_state.latent_mel_bins = latent_mel_bins
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
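# Illustrative sketch (not part of the pipeline code): the audio latent length computed in
# LTX2PrepareAudioLatentsStep above. The sample rate, hop length, and compression ratio below
# are hypothetical placeholders; the real values come from components.audio_vae.config.
def _example_audio_num_frames(
    num_frames: int = 121,
    frame_rate: float = 24.0,
    sample_rate: int = 16000,
    mel_hop_length: int = 160,
    temporal_compression_ratio: int = 4,
) -> int:
    duration_s = num_frames / frame_rate  # ~5.04 s of video for the defaults
    audio_latents_per_second = sample_rate / mel_hop_length / float(temporal_compression_ratio)
    return round(duration_s * audio_latents_per_second)  # ~126 audio latent frames for these values
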
class LTX2PrepareCoordinatesStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare video and audio RoPE coordinates for positional encoding"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", LTX2VideoTransformer3DModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("audio_latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("latent_num_frames", required=True, type_hint=int),
|
||||
InputParam("latent_height", required=True, type_hint=int),
|
||||
InputParam("latent_width", required=True, type_hint=int),
|
||||
InputParam("audio_num_frames", required=True, type_hint=int),
|
||||
InputParam("frame_rate", default=24.0, type_hint=float),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("video_coords", type_hint=torch.Tensor),
|
||||
OutputParam("audio_coords", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
audio_latents = block_state.audio_latents
|
||||
frame_rate = block_state.frame_rate
|
||||
|
||||
video_coords = components.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0],
|
||||
block_state.latent_num_frames,
|
||||
block_state.latent_height,
|
||||
block_state.latent_width,
|
||||
latents.device,
|
||||
fps=frame_rate,
|
||||
)
|
||||
# Cast to latent dtype to match reference (positions stored in model dtype)
|
||||
video_coords = video_coords.to(latents.dtype)
|
||||
audio_coords = components.transformer.audio_rope.prepare_audio_coords(
|
||||
audio_latents.shape[0], block_state.audio_num_frames, audio_latents.device
|
||||
)
|
||||
# Note: audio_coords already match reference dtype, no cast needed
|
||||
|
||||
block_state.video_coords = video_coords
|
||||
block_state.audio_coords = audio_coords
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2Stage2SetTimestepsStep(LTX2SetTimestepsStep):
|
||||
"""SetTimesteps for Stage 2: fixed distilled sigmas, no dynamic shifting."""
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Stage 2 timestep setup: uses fixed distilled sigmas with dynamic shifting disabled"
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=3),
|
||||
InputParam("timesteps_input"),
|
||||
InputParam("sigmas", default=list(STAGE_2_DISTILLED_SIGMA_VALUES)),
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
components.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
|
||||
components.scheduler.config,
|
||||
use_dynamic_shifting=False,
|
||||
shift_terminal=None,
|
||||
)
|
||||
return super().__call__(components, state)
|
||||
|
||||
|
||||
class LTX2Stage2PrepareLatentsStep(LTX2PrepareLatentsStep):
|
||||
"""PrepareLatents for Stage 2: noise_scale defaults to first distilled sigma value."""
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("noise_scale", default=STAGE_2_DISTILLED_SIGMA_VALUES[0], type_hint=float),
|
||||
InputParam("latents", type_hint=torch.Tensor),
|
||||
InputParam("generator"),
|
||||
InputParam("batch_size", required=True, type_hint=int),
|
||||
InputParam("num_videos_per_prompt", default=1, type_hint=int),
|
||||
InputParam("condition_latents", type_hint=list),
|
||||
InputParam("condition_strengths", type_hint=list),
|
||||
InputParam("condition_indices", type_hint=list),
|
||||
]
|
||||
|
||||
|
||||
class LTX2DisableAdapterStep(ModularPipelineBlocks):
|
||||
"""Disables LoRA adapters on transformer and connectors. No-op if no adapters are loaded."""
|
||||
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Disable LoRA adapters before stage 1 (no-op if no adapters loaded)"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", LTX2VideoTransformer3DModel),
|
||||
ComponentSpec("connectors", LTX2TextConnectors),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return []
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
for model in [components.transformer, components.connectors]:
|
||||
if getattr(model, "_hf_peft_config_loaded", False):
|
||||
model.disable_adapters()
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2EnableAdapterStep(ModularPipelineBlocks):
|
||||
"""Enables LoRA adapters by name before stage 2. No-op if stage_2_adapter is not provided."""
|
||||
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Enable LoRA adapters before stage 2 (no-op if stage_2_adapter not provided)"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("transformer", LTX2VideoTransformer3DModel),
|
||||
ComponentSpec("connectors", LTX2TextConnectors),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("stage_2_adapter", type_hint=str, description="Name of the LoRA adapter to enable for stage 2"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
adapter_name = block_state.stage_2_adapter
|
||||
if adapter_name is not None:
|
||||
for model in [components.transformer, components.connectors]:
|
||||
if getattr(model, "_hf_peft_config_loaded", False):
|
||||
model.enable_adapters()
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
src/diffusers/modular_pipelines/ltx2/decoders.py (new file, 228 lines)
@@ -0,0 +1,228 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ...configuration_utils import FrozenDict
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...pipelines.ltx2.vocoder import LTX2Vocoder
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


def _unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return latents


def _denormalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / scaling_factor + latents_mean
    return latents


def _unpack_audio_latents(
    latents: torch.Tensor,
    latent_length: int,
    num_mel_bins: int,
    patch_size: int | None = None,
    patch_size_t: int | None = None,
) -> torch.Tensor:
    if patch_size is not None and patch_size_t is not None:
        batch_size = latents.size(0)
        latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
        latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
    else:
        latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
    return latents


def _denormalize_audio_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    latents_mean = latents_mean.to(latents.device, latents.dtype)
    latents_std = latents_std.to(latents.device, latents.dtype)
    return (latents * latents_std) + latents_mean

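# Illustrative sketch (not part of the pipeline code): the packed -> unpacked shape round trip
# handled by _unpack_latents above. With patch_size = patch_size_t = 1 the sequence length is
# simply F * H * W. All shapes below are hypothetical.
def _example_unpack_shapes():
    batch, channels, frames, height, width = 1, 128, 4, 16, 24
    packed = torch.zeros(batch, frames * height * width, channels)  # [B, S, D] as used by the transformer
    unpacked = _unpack_latents(packed, frames, height, width, patch_size=1, patch_size_t=1)
    return unpacked.shape  # torch.Size([1, 128, 4, 16, 24]), ready for _denormalize_latents / VAE decode
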
class LTX2VideoDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised video latents into video frames"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor, description="Denoised video latents"),
|
||||
InputParam("output_type", default="np", type_hint=str, description="Output format: pil, np, pt, latent"),
|
||||
InputParam("decode_timestep", default=0.0, description="Timestep for VAE decode conditioning"),
|
||||
InputParam("decode_noise_scale", default=None, description="Noise scale for decode conditioning"),
|
||||
InputParam("generator", description="Random generator for reproducibility"),
|
||||
InputParam("latent_num_frames", required=True, type_hint=int),
|
||||
InputParam("latent_height", required=True, type_hint=int),
|
||||
InputParam("latent_width", required=True, type_hint=int),
|
||||
InputParam("batch_size", required=True, type_hint=int),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype),
|
||||
InputParam("transformer_spatial_patch_size", default=1, type_hint=int),
|
||||
InputParam("transformer_temporal_patch_size", default=1, type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("videos", description="The decoded video frames"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
|
||||
# Unpack latents from [B, S, D] -> [B, C, F, H, W]
|
||||
# Uses the transformer's patchify sizes (not the VAE's internal patch_size)
|
||||
latents = _unpack_latents(
|
||||
latents,
|
||||
block_state.latent_num_frames,
|
||||
block_state.latent_height,
|
||||
block_state.latent_width,
|
||||
block_state.transformer_spatial_patch_size,
|
||||
block_state.transformer_temporal_patch_size,
|
||||
)
|
||||
# Denormalize
|
||||
latents = _denormalize_latents(
|
||||
latents, components.vae.latents_mean, components.vae.latents_std, components.vae.config.scaling_factor
|
||||
)
|
||||
|
||||
if block_state.output_type == "latent":
|
||||
block_state.videos = latents
|
||||
else:
|
||||
latents = latents.to(block_state.dtype)
|
||||
device = latents.device
|
||||
|
||||
if not components.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(
|
||||
latents.shape, generator=block_state.generator, device=device, dtype=latents.dtype
|
||||
)
|
||||
decode_timestep = block_state.decode_timestep
|
||||
decode_noise_scale = block_state.decode_noise_scale
|
||||
batch_size = block_state.batch_size
|
||||
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
latents = latents.to(components.vae.dtype)
|
||||
video = components.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
block_state.videos = components.video_processor.postprocess_video(
|
||||
video, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2AudioDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised audio latents into audio waveforms"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("audio_vae", AutoencoderKLLTX2Audio),
|
||||
ComponentSpec("vocoder", LTX2Vocoder),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("audio_latents", required=True, type_hint=torch.Tensor, description="Denoised audio latents"),
|
||||
InputParam("output_type", default="np", type_hint=str),
|
||||
InputParam("audio_num_frames", required=True, type_hint=int),
|
||||
InputParam("latent_mel_bins", required=True, type_hint=int),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("audio", description="The decoded audio waveforms"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
audio_latents = block_state.audio_latents
|
||||
|
||||
# Denormalize audio latents
|
||||
audio_latents = _denormalize_audio_latents(
|
||||
audio_latents, components.audio_vae.latents_mean, components.audio_vae.latents_std
|
||||
)
|
||||
# Unpack audio latents
|
||||
audio_latents = _unpack_audio_latents(
|
||||
audio_latents, block_state.audio_num_frames, num_mel_bins=block_state.latent_mel_bins
|
||||
)
|
||||
|
||||
if block_state.output_type == "latent":
|
||||
block_state.audio = audio_latents
|
||||
else:
|
||||
audio_latents = audio_latents.to(components.audio_vae.dtype)
|
||||
generated_mel_spectrograms = components.audio_vae.decode(audio_latents, return_dict=False)[0]
|
||||
# Squeeze batch dim and cast to float32 to match reference's decode_audio output format
|
||||
block_state.audio = components.vocoder(generated_mel_spectrograms).squeeze(0).float()
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
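# Illustrative sketch (not part of the pipeline code): the decode-time noise conditioning used by
# LTX2VideoDecoderStep when the VAE has timestep_conditioning. decode_noise_scale blends fresh
# noise into the denoised latents before decoding; 0.0 leaves them untouched. Shapes are hypothetical.
def _example_decode_noise_blend(decode_noise_scale: float = 0.025):
    latents = torch.zeros(1, 128, 4, 16, 24)  # hypothetical denoised [B, C, F, H, W] latents
    noise = torch.randn_like(latents)
    blended = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
    return blended.shape
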
src/diffusers/modular_pipelines/ltx2/denoise.py (new file, 490 lines)
@@ -0,0 +1,490 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models.transformers import LTX2VideoTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import (
|
||||
BlockState,
|
||||
LoopSequentialPipelineBlocks,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LTX2LoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that prepares the latent inputs for the denoiser, "
|
||||
"including timestep masking for conditioned frames."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("audio_latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype),
|
||||
InputParam("conditioning_mask", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
block_state.latent_model_input = block_state.latents.to(block_state.dtype)
|
||||
block_state.audio_latent_model_input = block_state.audio_latents.to(block_state.dtype)
|
||||
|
||||
batch_size = block_state.latent_model_input.shape[0]
|
||||
num_video_tokens = block_state.latent_model_input.shape[1]
|
||||
num_audio_tokens = block_state.audio_latent_model_input.shape[1]
|
||||
|
||||
video_timestep = t.expand(batch_size, num_video_tokens)
|
||||
|
||||
if block_state.conditioning_mask is not None:
|
||||
block_state.video_timestep = video_timestep * (
|
||||
1 - block_state.conditioning_mask.squeeze(-1)
|
||||
)
|
||||
else:
|
||||
block_state.video_timestep = video_timestep
|
||||
|
||||
block_state.audio_timestep = t.expand(batch_size, num_audio_tokens)
|
||||
# Sigma for prompt_adaln: f32 to match reference's f32(sigma * scale_multiplier)
|
||||
block_state.sigma = torch.tensor([t.item()], dtype=torch.float32)
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
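# Illustrative sketch (not part of the pipeline code): the per-token timestep masking applied by
# LTX2LoopBeforeDenoiser above. Conditioned tokens (mask = 1) get timestep 0, so the transformer
# treats them as already clean; unconditioned tokens keep the current timestep. Values are hypothetical.
def _example_timestep_masking():
    t = torch.tensor(600.0)
    conditioning_mask = torch.tensor([[[1.0], [0.3], [0.0]]])  # hypothetical packed mask, shape [1, 3, 1]
    video_timestep = t.expand(1, 3) * (1 - conditioning_mask.squeeze(-1))
    return video_timestep  # tensor([[0., 420., 600.]])
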
class LTX2LoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
guider_input_fields: dict[str, Any] = None,
|
||||
guider_name: str = "guider",
|
||||
guider_config: FrozenDict = None,
|
||||
):
|
||||
"""Initialize a denoiser block for LTX2 that handles dual video+audio outputs.
|
||||
|
||||
Args:
|
||||
guider_input_fields: Dictionary mapping transformer argument names to block_state field names.
|
||||
Values can be tuples (conditional, unconditional) or strings (same for both).
|
||||
guider_name: Name of the guider component to use (default: "guider").
|
||||
guider_config: Config for the guider component (default: guidance_scale=4.0).
|
||||
"""
|
||||
self._guider_name = guider_name
|
||||
if guider_config is None:
|
||||
guider_config = FrozenDict({"guidance_scale": 4.0})
|
||||
self._guider_config = guider_config
|
||||
if guider_input_fields is None:
|
||||
guider_input_fields = {
|
||||
"encoder_hidden_states": ("connector_prompt_embeds", "connector_negative_prompt_embeds"),
|
||||
"audio_encoder_hidden_states": (
|
||||
"connector_audio_prompt_embeds",
|
||||
"connector_audio_negative_prompt_embeds",
|
||||
),
|
||||
"encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
|
||||
"audio_encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
|
||||
}
|
||||
if not isinstance(guider_input_fields, dict):
|
||||
raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
|
||||
self._guider_input_fields = guider_input_fields
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec(
|
||||
self._guider_name,
|
||||
ClassifierFreeGuidance,
|
||||
config=self._guider_config,
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
ComponentSpec("transformer", LTX2VideoTransformer3DModel),
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that runs the transformer with guidance "
|
||||
"and handles dual video+audio output splitting. CFG is applied in x0 space "
|
||||
"to match the reference implementation."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
inputs = [
|
||||
InputParam("attention_kwargs"),
|
||||
InputParam("num_inference_steps", required=True, type_hint=int),
|
||||
InputParam("latent_num_frames", required=True, type_hint=int),
|
||||
InputParam("latent_height", required=True, type_hint=int),
|
||||
InputParam("latent_width", required=True, type_hint=int),
|
||||
InputParam("audio_num_frames", required=True, type_hint=int),
|
||||
InputParam("frame_rate", default=24.0, type_hint=float),
|
||||
InputParam("video_coords", required=True, type_hint=torch.Tensor),
|
||||
InputParam("audio_coords", required=True, type_hint=torch.Tensor),
|
||||
InputParam("guidance_rescale", default=0.0, type_hint=float),
|
||||
InputParam("sigma", type_hint=torch.Tensor),
|
||||
]
|
||||
guider_input_names = []
|
||||
for value in self._guider_input_fields.values():
|
||||
if isinstance(value, tuple):
|
||||
guider_input_names.extend(value)
|
||||
else:
|
||||
guider_input_names.append(value)
|
||||
|
||||
for name in set(guider_input_names):
|
||||
inputs.append(InputParam(name=name, type_hint=torch.Tensor))
|
||||
return inputs
|
||||
|
||||
    @staticmethod
    def _convert_velocity_to_x0(sample, velocity, sigma):
        return sample - velocity * sigma

    @staticmethod
    def _convert_x0_to_velocity(sample, x0, sigma):
        return (sample - x0) / sigma

@torch.no_grad()
|
||||
def __call__(self, components, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
guider = getattr(components, self._guider_name)
|
||||
guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
|
||||
guider_state = guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
|
||||
|
||||
use_cross_timestep = getattr(components.transformer.config, "use_cross_timestep", False)
|
||||
sigma_val = components.scheduler.sigmas[i]
|
||||
|
||||
# Pass raw sigma to wrapper if available (avoids timestep/1000 round-trip precision loss)
|
||||
if hasattr(components.transformer, "_raw_sigma"):
|
||||
components.transformer._raw_sigma = sigma_val
|
||||
|
||||
for guider_state_batch in guider_state:
|
||||
guider.prepare_models(components.transformer)
|
||||
cond_kwargs = guider_state_batch.as_dict()
|
||||
cond_kwargs = {
|
||||
k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in cond_kwargs.items()
|
||||
if k in self._guider_input_fields.keys()
|
||||
}
|
||||
# Drop all-ones attention masks — they're functionally no-op but trigger
|
||||
# a different SDPA kernel path (masked vs unmasked) with different bf16 rounding.
|
||||
# Reference passes context_mask=None for unmasked attention.
|
||||
for mask_key in ["encoder_attention_mask", "audio_encoder_attention_mask"]:
|
||||
mask = cond_kwargs.get(mask_key)
|
||||
if mask is not None and mask.ndim <= 2 and (mask == 1).all():
|
||||
cond_kwargs[mask_key] = None
|
||||
|
||||
video_timestep = block_state.video_timestep
|
||||
audio_timestep = block_state.audio_timestep
|
||||
|
||||
with components.transformer.cache_context("cond_uncond"):
|
||||
noise_pred_video, noise_pred_audio = components.transformer(
|
||||
hidden_states=block_state.latent_model_input.to(block_state.dtype),
|
||||
audio_hidden_states=block_state.audio_latent_model_input.to(block_state.dtype),
|
||||
timestep=video_timestep,
|
||||
audio_timestep=audio_timestep,
|
||||
sigma=block_state.sigma,
|
||||
num_frames=block_state.latent_num_frames,
|
||||
height=block_state.latent_height,
|
||||
width=block_state.latent_width,
|
||||
fps=block_state.frame_rate,
|
||||
audio_num_frames=block_state.audio_num_frames,
|
||||
video_coords=block_state.video_coords,
|
||||
audio_coords=block_state.audio_coords,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
)
|
||||
|
||||
# Convert to x0 for guidance.
|
||||
prediction_type = getattr(components.transformer, "prediction_type", "velocity")
|
||||
if prediction_type == "x0":
|
||||
# Model already outputs x0 — no conversion needed
|
||||
x0_video = noise_pred_video
|
||||
x0_audio = noise_pred_audio
|
||||
else:
|
||||
# Model outputs velocity — convert to x0 matching reference's to_denoised:
|
||||
# (sample.f32 - velocity.f32 * sigma_f32).to(sample.dtype)
|
||||
# Reference uses f32 sigma (from denoise_mask * sigma, both f32).
|
||||
x0_video = self._convert_velocity_to_x0(
|
||||
block_state.latents.float(), noise_pred_video.float(), sigma_val
|
||||
).to(block_state.latents.dtype)
|
||||
x0_audio = self._convert_velocity_to_x0(
|
||||
block_state.audio_latents.float(), noise_pred_audio.float(), sigma_val
|
||||
).to(block_state.audio_latents.dtype)
|
||||
|
||||
guider_state_batch.noise_pred = x0_video
|
||||
guider_state_batch.noise_pred_audio = x0_audio
|
||||
|
||||
# Sub-step checkpoint: save/load x0 per condition
|
||||
_ckpts = getattr(block_state, "_checkpoints", None)
|
||||
if _ckpts:
|
||||
from diffusers.modular_pipelines.ltx2._checkpoint_utils import _maybe_checkpoint
|
||||
cond_label = "cond" if guider_state_batch is guider_state[0] else "uncond"
|
||||
_maybe_checkpoint(_ckpts, f"step_{i}_{cond_label}_x0", {
|
||||
"video": x0_video, "audio": x0_audio,
|
||||
})
|
||||
# Load support: inject reference x0 for this condition
|
||||
ckpt = _ckpts.get(f"step_{i}_{cond_label}_x0")
|
||||
if ckpt is not None and ckpt.load:
|
||||
x0_video = ckpt.data["video"].to(x0_video)
|
||||
x0_audio = ckpt.data["audio"].to(x0_audio)
|
||||
guider_state_batch.noise_pred = x0_video
|
||||
guider_state_batch.noise_pred_audio = x0_audio
|
||||
|
||||
guider.cleanup_models(components.transformer)
|
||||
|
||||
# Apply guidance in x0 space using reference formula:
|
||||
# cond + (scale - 1) * (cond - uncond)
|
||||
# This is mathematically equivalent to uncond + scale * (cond - uncond)
|
||||
# but produces different bf16 rounding.
|
||||
if len(guider_state) == 2:
|
||||
guidance_scale = guider.guidance_scale
|
||||
x0_video_cond = guider_state[0].noise_pred
|
||||
x0_video_uncond = guider_state[1].noise_pred
|
||||
guided_x0_video = x0_video_cond + (guidance_scale - 1) * (x0_video_cond - x0_video_uncond)
|
||||
|
||||
x0_audio_cond = guider_state[0].noise_pred_audio
|
||||
x0_audio_uncond = guider_state[1].noise_pred_audio
|
||||
guided_x0_audio = x0_audio_cond + (guidance_scale - 1) * (x0_audio_cond - x0_audio_uncond)
|
||||
|
||||
if block_state.guidance_rescale > 0:
|
||||
guided_x0_video = self._rescale_noise_cfg(
|
||||
guided_x0_video,
|
||||
guider_state[0].noise_pred,
|
||||
block_state.guidance_rescale,
|
||||
)
|
||||
guided_x0_audio = self._rescale_noise_cfg(
|
||||
guided_x0_audio,
|
||||
x0_audio_cond,
|
||||
block_state.guidance_rescale,
|
||||
)
|
||||
else:
|
||||
guided_x0_video = guider_state[0].noise_pred
|
||||
guided_x0_audio = guider_state[0].noise_pred_audio
|
||||
|
||||
# Sub-step checkpoint: save/load guided x0
|
||||
_ckpts = getattr(block_state, "_checkpoints", None)
|
||||
if _ckpts:
|
||||
from diffusers.modular_pipelines.ltx2._checkpoint_utils import _maybe_checkpoint
|
||||
_maybe_checkpoint(_ckpts, f"step_{i}_guided_x0", {
|
||||
"video": guided_x0_video, "audio": guided_x0_audio,
|
||||
})
|
||||
# Load support: inject reference guided x0
|
||||
ckpt = _ckpts.get(f"step_{i}_guided_x0")
|
||||
if ckpt is not None and ckpt.load:
|
||||
guided_x0_video = ckpt.data["video"].to(guided_x0_video)
|
||||
guided_x0_audio = ckpt.data["audio"].to(guided_x0_audio)
|
||||
|
||||
# Convert guided x0 back to velocity for the scheduler.
|
||||
# Use sigma_val.item() (Python float) to match reference's to_velocity which
|
||||
# does sigma.to(float32).item() — dividing by Python float vs 0-dim tensor
|
||||
# uses different CUDA kernels and can produce different results at specific values.
|
||||
sigma_scalar = sigma_val.item()
|
||||
block_state.noise_pred_video = self._convert_x0_to_velocity(
|
||||
block_state.latents.float(), guided_x0_video, sigma_scalar
|
||||
).to(block_state.latents.dtype)
|
||||
block_state.noise_pred_audio = self._convert_x0_to_velocity(
|
||||
block_state.audio_latents.float(), guided_x0_audio, sigma_scalar
|
||||
).to(block_state.audio_latents.dtype)
|
||||
|
||||
return components, block_state
|
||||
|
||||
@staticmethod
|
||||
def _rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
|
||||
|
||||
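# Illustrative sketch (not part of the pipeline code): the x0-space guidance formula used above,
# cond + (scale - 1) * (cond - uncond), is algebraically identical to the more common
# uncond + scale * (cond - uncond); only the floating-point rounding differs, which is why the
# reference form is kept for bit-level parity.
def _example_x0_cfg(scale: float = 4.0):
    cond = torch.randn(2, 8, dtype=torch.float64)
    uncond = torch.randn(2, 8, dtype=torch.float64)
    reference_form = cond + (scale - 1) * (cond - uncond)
    common_form = uncond + scale * (cond - uncond)
    return torch.allclose(reference_form, common_form)  # True in float64
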
class LTX2LoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop that updates latents via scheduler step, "
|
||||
"with optional x0-space conditioning blending."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("conditioning_mask", type_hint=torch.Tensor),
|
||||
InputParam("clean_latents", type_hint=torch.Tensor),
|
||||
InputParam("audio_scheduler", required=True),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, block_state: BlockState, i: int, t: torch.Tensor):
|
||||
noise_pred_video = block_state.noise_pred_video
|
||||
noise_pred_audio = block_state.noise_pred_audio
|
||||
|
||||
if block_state.conditioning_mask is not None:
|
||||
# x0 blending: convert velocity to x0, blend with clean latents, convert back
|
||||
sigma = components.scheduler.sigmas[i]
|
||||
denoised_sample = block_state.latents - noise_pred_video * sigma
|
||||
|
||||
bsz = noise_pred_video.size(0)
|
||||
conditioning_mask = block_state.conditioning_mask[:bsz]
|
||||
clean_latents = block_state.clean_latents
|
||||
|
||||
denoised_sample_cond = (
|
||||
denoised_sample * (1 - conditioning_mask) + clean_latents.float() * conditioning_mask
|
||||
).to(noise_pred_video.dtype)
|
||||
|
||||
denoised_latents_cond = ((block_state.latents - denoised_sample_cond) / sigma).to(
|
||||
noise_pred_video.dtype
|
||||
)
|
||||
block_state.latents = components.scheduler.step(
|
||||
denoised_latents_cond, t, block_state.latents, return_dict=False
|
||||
)[0]
|
||||
else:
|
||||
block_state.latents = components.scheduler.step(
|
||||
noise_pred_video, t, block_state.latents, return_dict=False
|
||||
)[0]
|
||||
|
||||
block_state.audio_latents = block_state.audio_scheduler.step(
|
||||
noise_pred_audio, t, block_state.audio_latents, return_dict=False
|
||||
)[0]
|
||||
|
||||
return components, block_state
|
||||
|
||||
|
||||
class LTX2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Pipeline block that iteratively denoises the latents over timesteps for LTX2"
|
||||
|
||||
@property
|
||||
def loop_expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def loop_inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("timesteps", required=True, type_hint=torch.Tensor),
|
||||
InputParam("num_inference_steps", required=True, type_hint=int),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
_checkpoints = state.get("_checkpoints")
|
||||
|
||||
block_state.num_warmup_steps = max(
|
||||
len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
|
||||
)
|
||||
|
||||
# Checkpoint: save/load preloop state
|
||||
if _checkpoints:
|
||||
from diffusers.modular_pipelines.ltx2._checkpoint_utils import _maybe_checkpoint
|
||||
_maybe_checkpoint(_checkpoints, "preloop", {
|
||||
"video_latent": block_state.latents, "audio_latent": block_state.audio_latents,
|
||||
})
|
||||
if "preloop" in _checkpoints and _checkpoints["preloop"].load:
|
||||
d = _checkpoints["preloop"].data
|
||||
block_state.latents = d["video_latent"].to(block_state.latents)
|
||||
block_state.audio_latents = d["audio_latent"].to(block_state.audio_latents)
|
||||
|
||||
# Pass _checkpoints to sub-blocks via block_state
|
||||
if _checkpoints:
|
||||
block_state._checkpoints = _checkpoints
|
||||
|
||||
try:
|
||||
with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(block_state.timesteps):
|
||||
components, block_state = self.loop_step(components, block_state, i=i, t=t)
|
||||
|
||||
# Checkpoint: save velocity (= guided prediction) after denoiser, before scheduler
|
||||
if _checkpoints:
|
||||
_maybe_checkpoint(_checkpoints, f"step_{i}_velocity", {
|
||||
"video": block_state.noise_pred_video, "audio": block_state.noise_pred_audio,
|
||||
})
|
||||
|
||||
# Checkpoint: save/load after each step
|
||||
if _checkpoints:
|
||||
_maybe_checkpoint(_checkpoints, f"after_step_{i}", {
|
||||
"video_latent": block_state.latents, "audio_latent": block_state.audio_latents,
|
||||
})
|
||||
if f"after_step_{i}" in _checkpoints and _checkpoints[f"after_step_{i}"].load:
|
||||
d = _checkpoints[f"after_step_{i}"].data
|
||||
block_state.latents = d["video_latent"].to(block_state.latents)
|
||||
block_state.audio_latents = d["audio_latent"].to(block_state.audio_latents)
|
||||
|
||||
if i == len(block_state.timesteps) - 1 or (
|
||||
(i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
|
||||
):
|
||||
progress_bar.update()
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2DenoiseStep(LTX2DenoiseLoopWrapper):
|
||||
block_classes = [
|
||||
LTX2LoopBeforeDenoiser,
|
||||
LTX2LoopDenoiser(
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("connector_prompt_embeds", "connector_negative_prompt_embeds"),
|
||||
"audio_encoder_hidden_states": (
|
||||
"connector_audio_prompt_embeds",
|
||||
"connector_audio_negative_prompt_embeds",
|
||||
),
|
||||
"encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
|
||||
"audio_encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
|
||||
}
|
||||
),
|
||||
LTX2LoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoises video and audio latents.\n"
|
||||
"At each iteration, it runs:\n"
|
||||
" - LTX2LoopBeforeDenoiser (prepare inputs, timestep masking)\n"
|
||||
" - LTX2LoopDenoiser (transformer forward + guidance)\n"
|
||||
" - LTX2LoopAfterDenoiser (scheduler step + x0 blending)\n"
|
||||
"Supports T2V, I2V, and conditional generation."
|
||||
)
|
||||
src/diffusers/modular_pipelines/ltx2/encoders.py (new file, 541 lines)
@@ -0,0 +1,541 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...models.autoencoders import AutoencoderKLLTX2Video
|
||||
from ...pipelines.ltx2.connectors import LTX2TextConnectors
|
||||
from ...utils import logging
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LTX2VideoCondition:
|
||||
"""
|
||||
Defines a single frame-conditioning item for LTX-2 Video.
|
||||
|
||||
Attributes:
|
||||
frames: The image (or video) to condition on.
|
||||
index: The latent index at which to insert the condition.
|
||||
strength: The strength of the conditioning effect (0-1).
|
||||
"""
|
||||
|
||||
frames: PIL.Image.Image | list[PIL.Image.Image] | np.ndarray | torch.Tensor
|
||||
index: int = 0
|
||||
strength: float = 1.0
|
||||
|
||||
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
mask = token_indices < sequence_lengths[:, None]
|
||||
elif padding_side == "left":
|
||||
start_indices = seq_len - sequence_lengths[:, None]
|
||||
mask = token_indices >= start_indices
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None]
|
||||
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size, -1, post_patch_num_frames, patch_size_t, post_patch_height, patch_size, post_patch_width, patch_size
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
|
||||
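# Illustrative sketch (not part of the pipeline code): the [B, C, F, H, W] -> [B, S, D] packing done
# by _pack_latents above. With patch_size = patch_size_t = 1 the sequence length is F * H * W and the
# token dimension equals the channel count. Shapes below are hypothetical.
def _example_pack_shapes():
    latents = torch.zeros(1, 128, 4, 16, 24)  # [B, C, F, H, W]
    packed = _pack_latents(latents, patch_size=1, patch_size_t=1)
    return packed.shape  # torch.Size([1, 1536, 128]) == [B, F*H*W, C]
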
class LTX2TextEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Text encoder step that encodes prompts using Gemma3 for LTX2 video generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("text_encoder", Gemma3ForConditionalGeneration),
|
||||
ComponentSpec("tokenizer", GemmaTokenizerFast),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("prompt"),
|
||||
InputParam("negative_prompt"),
|
||||
InputParam("max_sequence_length", default=1024),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Packed text embeddings from Gemma3",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Packed negative text embeddings from Gemma3",
|
||||
),
|
||||
OutputParam(
|
||||
"prompt_attention_mask",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Attention mask for prompt embeddings",
|
||||
),
|
||||
OutputParam(
|
||||
"negative_prompt_attention_mask",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Attention mask for negative prompt embeddings",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(block_state):
|
||||
if block_state.prompt is not None and (
|
||||
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
|
||||
):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
|
||||
|
||||
@staticmethod
|
||||
def _get_gemma_prompt_embeds(
|
||||
text_encoder,
|
||||
tokenizer,
|
||||
prompt: str | list[str],
|
||||
max_sequence_length: int = 1024,
|
||||
scale_factor: int = 8,
|
||||
device: torch.device | None = None,
|
||||
dtype: torch.dtype | None = None,
|
||||
):
|
||||
dtype = dtype or text_encoder.dtype
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
tokenizer.padding_side = "left"
|
||||
if tokenizer.pad_token is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
prompt = [p.strip() for p in prompt]
|
||||
text_inputs = tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_attention_mask = text_inputs.attention_mask.to(device)
|
||||
|
||||
text_encoder_outputs = text_encoder(
|
||||
input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
# Return raw stacked hidden states [B, T, D, L] — the connector handles normalization
|
||||
# (per_token_rms_norm + rescaling for LTX-2.3, or _pack_text_embeds for LTX-2.0)
|
||||
prompt_embeds = text_encoder_hidden_states.to(dtype=dtype)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
@staticmethod
|
||||
def encode_prompt(
|
||||
components,
|
||||
prompt: str | list[str],
|
||||
device: torch.device | None = None,
|
||||
prepare_unconditional_embeds: bool = True,
|
||||
negative_prompt: str | list[str] | None = None,
|
||||
max_sequence_length: int = 1024,
|
||||
):
|
||||
device = device or components._execution_device
|
||||
|
||||
if not isinstance(prompt, list):
|
||||
prompt = [prompt]
|
||||
batch_size = len(prompt)
|
||||
|
||||
prompt_embeds, prompt_attention_mask = LTX2TextEncoderStep._get_gemma_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
negative_prompt_embeds = None
|
||||
negative_prompt_attention_mask = None
|
||||
|
||||
if prepare_unconditional_embeds:
|
||||
negative_prompt = negative_prompt or ""
|
||||
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
|
||||
if prompt is not None and type(prompt) is not type(negative_prompt):
|
||||
raise TypeError(
|
||||
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
|
||||
f" {type(prompt)}."
|
||||
)
|
||||
elif batch_size != len(negative_prompt):
|
||||
raise ValueError(
|
||||
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
|
||||
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
|
||||
" the batch size of `prompt`."
|
||||
)
|
||||
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = LTX2TextEncoderStep._get_gemma_prompt_embeds(
|
||||
text_encoder=components.text_encoder,
|
||||
tokenizer=components.tokenizer,
|
||||
prompt=negative_prompt,
|
||||
max_sequence_length=max_sequence_length,
|
||||
device=device,
|
||||
)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
self.check_inputs(block_state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
(
|
||||
block_state.prompt_embeds,
|
||||
block_state.prompt_attention_mask,
|
||||
block_state.negative_prompt_embeds,
|
||||
block_state.negative_prompt_attention_mask,
|
||||
) = self.encode_prompt(
|
||||
components=components,
|
||||
prompt=block_state.prompt,
|
||||
device=device,
|
||||
prepare_unconditional_embeds=components.requires_unconditional_embeds,
|
||||
negative_prompt=block_state.negative_prompt,
|
||||
max_sequence_length=block_state.max_sequence_length,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2ConnectorStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Connector step that transforms text embeddings into video and audio conditioning embeddings"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("connectors", LTX2TextConnectors),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("prompt_embeds", required=True, type_hint=torch.Tensor),
|
||||
InputParam("prompt_attention_mask", required=True, type_hint=torch.Tensor),
|
||||
InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
|
||||
InputParam("negative_prompt_attention_mask", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"connector_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Video text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_audio_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Audio text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_attention_mask",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Attention mask from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Negative video text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_audio_negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Negative audio text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_negative_attention_mask",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Negative attention mask from connector",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt_embeds = block_state.prompt_embeds
|
||||
prompt_attention_mask = block_state.prompt_attention_mask
|
||||
negative_prompt_embeds = block_state.negative_prompt_embeds
|
||||
negative_prompt_attention_mask = block_state.negative_prompt_attention_mask
|
||||
|
||||
do_cfg = negative_prompt_embeds is not None
|
||||
|
||||
if do_cfg:
|
||||
combined_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
combined_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
else:
|
||||
combined_embeds = prompt_embeds
|
||||
combined_mask = prompt_attention_mask
|
||||
|
||||
connector_embeds, connector_audio_embeds, connector_mask = components.connectors(
|
||||
combined_embeds, combined_mask
|
||||
)
|
||||
|
||||
if do_cfg:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
block_state.connector_negative_prompt_embeds = connector_embeds[:batch_size]
|
||||
block_state.connector_prompt_embeds = connector_embeds[batch_size:]
|
||||
block_state.connector_audio_negative_prompt_embeds = connector_audio_embeds[:batch_size]
|
||||
block_state.connector_audio_prompt_embeds = connector_audio_embeds[batch_size:]
|
||||
block_state.connector_negative_attention_mask = connector_mask[:batch_size]
|
||||
block_state.connector_attention_mask = connector_mask[batch_size:]
|
||||
else:
|
||||
block_state.connector_prompt_embeds = connector_embeds
|
||||
block_state.connector_audio_prompt_embeds = connector_audio_embeds
|
||||
block_state.connector_attention_mask = connector_mask
|
||||
block_state.connector_negative_prompt_embeds = None
|
||||
block_state.connector_audio_negative_prompt_embeds = None
|
||||
block_state.connector_negative_attention_mask = None
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
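# Hedged sketch, not part of the diff: a minimal check of the CFG batching used in
# LTX2ConnectorStep above. Negative and positive embeddings share one connector call by
# concatenation along the batch dim and are split back afterwards; the tensor sizes here
# are hypothetical placeholders, not the real embedding dimensions.
import torch

_neg = torch.randn(1, 8, 16)                # stand-in negative prompt embeds
_pos = torch.randn(1, 8, 16)                # stand-in positive prompt embeds
_combined = torch.cat([_neg, _pos], dim=0)  # one batched forward pass through the connectors
_batch_size = _pos.shape[0]
assert torch.equal(_combined[:_batch_size], _neg) and torch.equal(_combined[_batch_size:], _pos)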
class LTX2ConditionEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Condition encoder step that VAE-encodes conditioning frames for I2V and conditional generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("conditions", type_hint=list, description="List of LTX2VideoCondition objects"),
|
||||
InputParam("image", type_hint=PIL.Image.Image, description="Sugar for I2V: image to condition on frame 0"),
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("condition_latents", type_hint=list, description="List of packed condition latent tensors"),
|
||||
OutputParam("condition_strengths", type_hint=list, description="List of conditioning strengths"),
|
||||
OutputParam("condition_indices", type_hint=list, description="List of latent frame indices"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
conditions = block_state.conditions
|
||||
image = block_state.image
|
||||
|
||||
# Convert image sugar to conditions list
|
||||
if image is not None and conditions is None:
|
||||
conditions = [LTX2VideoCondition(frames=image, index=0, strength=1.0)]
|
||||
|
||||
if conditions is None:
|
||||
block_state.condition_latents = []
|
||||
block_state.condition_strengths = []
|
||||
block_state.condition_indices = []
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
if isinstance(conditions, LTX2VideoCondition):
|
||||
conditions = [conditions]
|
||||
|
||||
height = block_state.height
|
||||
width = block_state.width
|
||||
num_frames = block_state.num_frames
|
||||
device = components._execution_device
|
||||
generator = block_state.generator
|
||||
|
||||
vae_temporal_compression_ratio = components.vae.temporal_compression_ratio
|
||||
vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
|
||||
transformer_spatial_patch_size = 1
|
||||
transformer_temporal_patch_size = 1
|
||||
|
||||
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
|
||||
|
||||
conditioning_frames, conditioning_strengths, conditioning_indices = [], [], []
|
||||
|
||||
for i, condition in enumerate(conditions):
|
||||
if isinstance(condition.frames, PIL.Image.Image):
|
||||
video_like_cond = [condition.frames]
|
||||
elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3:
|
||||
video_like_cond = np.expand_dims(condition.frames, axis=0)
|
||||
elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3:
|
||||
video_like_cond = condition.frames.unsqueeze(0)
|
||||
else:
|
||||
video_like_cond = condition.frames
|
||||
|
||||
condition_pixels = components.video_processor.preprocess_video(
|
||||
video_like_cond, height, width, resize_mode="crop"
|
||||
)
|
||||
|
||||
latent_start_idx = condition.index
|
||||
if latent_start_idx < 0:
|
||||
latent_start_idx = latent_start_idx % latent_num_frames
|
||||
if latent_start_idx >= latent_num_frames:
|
||||
logger.warning(
|
||||
f"The starting latent index {latent_start_idx} of condition {i} is too big for {latent_num_frames} "
|
||||
f"latent frames. This condition will be skipped."
|
||||
)
|
||||
continue
|
||||
|
||||
cond_num_frames = condition_pixels.size(2)
|
||||
start_idx = max((latent_start_idx - 1) * vae_temporal_compression_ratio + 1, 0)
|
||||
frame_num = min(cond_num_frames, num_frames - start_idx)
|
||||
frame_num = (frame_num - 1) // vae_temporal_compression_ratio * vae_temporal_compression_ratio + 1
|
||||
condition_pixels = condition_pixels[:, :, :frame_num]
|
||||
|
||||
conditioning_frames.append(condition_pixels.to(dtype=components.vae.dtype, device=device))
|
||||
conditioning_strengths.append(condition.strength)
|
||||
conditioning_indices.append(latent_start_idx)
|
||||
|
||||
condition_latents = []
|
||||
for condition_tensor in conditioning_frames:
|
||||
condition_latent = retrieve_latents(
|
||||
components.vae.encode(condition_tensor), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = _normalize_latents(
|
||||
condition_latent, components.vae.latents_mean, components.vae.latents_std
|
||||
).to(device=device, dtype=torch.float32)
|
||||
condition_latent = _pack_latents(
|
||||
condition_latent, transformer_spatial_patch_size, transformer_temporal_patch_size
|
||||
)
|
||||
condition_latents.append(condition_latent)
|
||||
|
||||
block_state.condition_latents = condition_latents
|
||||
block_state.condition_strengths = conditioning_strengths
|
||||
block_state.condition_indices = conditioning_indices
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
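# Hedged usage sketch, not part of the diff: the `image` input is shorthand for a single
# LTX2VideoCondition anchored at latent frame 0, exactly as __call__ above builds it. A
# conditional (e.g. first/last-frame) call passes explicit conditions instead; negative
# indices wrap around to the end of the latent sequence, as handled above. The PIL images
# below are placeholders sized to the default width/height.
import PIL.Image

_first = PIL.Image.new("RGB", (768, 512))
_last = PIL.Image.new("RGB", (768, 512))
_conditions = [
    LTX2VideoCondition(frames=_first, index=0, strength=1.0),
    LTX2VideoCondition(frames=_last, index=-1, strength=1.0),
]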
436
src/diffusers/modular_pipelines/ltx2/modular_blocks_ltx2.py
Normal file
@@ -0,0 +1,436 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...guiders import ClassifierFreeGuidance
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import ComponentSpec, OutputParam
|
||||
from .before_denoise import (
|
||||
LTX2DisableAdapterStep,
|
||||
LTX2EnableAdapterStep,
|
||||
LTX2InputStep,
|
||||
LTX2PrepareAudioLatentsStep,
|
||||
LTX2PrepareCoordinatesStep,
|
||||
LTX2PrepareLatentsStep,
|
||||
LTX2SetTimestepsStep,
|
||||
LTX2Stage2PrepareLatentsStep,
|
||||
LTX2Stage2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import LTX2AudioDecoderStep, LTX2VideoDecoderStep
|
||||
from .denoise import LTX2DenoiseLoopWrapper, LTX2DenoiseStep, LTX2LoopAfterDenoiser, LTX2LoopBeforeDenoiser, LTX2LoopDenoiser
|
||||
from .encoders import LTX2ConditionEncoderStep, LTX2ConnectorStep, LTX2TextEncoderStep
|
||||
from .modular_blocks_ltx2_upsample import LTX2UpsampleCoreBlocks
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. AUTO CONDITION ENCODER (skip if no conditions)
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2AutoConditionEncoderStep(AutoPipelineBlocks):
|
||||
"""Auto block that runs condition encoding when conditions or image inputs are provided.
|
||||
|
||||
- When `conditions` is provided: runs condition encoder for arbitrary frame conditioning
|
||||
- When `image` is provided: runs condition encoder (converts image to condition at frame 0)
|
||||
- When neither is provided: step is skipped (T2V mode)
|
||||
"""
|
||||
|
||||
block_classes = [LTX2ConditionEncoderStep, LTX2ConditionEncoderStep]
|
||||
block_names = ["conditional_encoder", "image_encoder"]
|
||||
block_trigger_inputs = ["conditions", "image"]
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. CORE DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""Core denoising block: input prep -> timesteps -> latents -> audio latents -> coordinates -> denoise loop."""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2InputStep,
|
||||
LTX2SetTimestepsStep,
|
||||
LTX2PrepareLatentsStep,
|
||||
LTX2PrepareAudioLatentsStep,
|
||||
LTX2PrepareCoordinatesStep,
|
||||
LTX2DenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_audio_latents",
|
||||
"prepare_coordinates",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise block that takes encoded conditions and runs the full denoising process."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam("latents"),
|
||||
OutputParam("audio_latents"),
|
||||
]
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. BLOCKS (T2V only)
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2Blocks(SequentialPipelineBlocks):
|
||||
"""Modular pipeline blocks for LTX2 text-to-video generation."""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2TextEncoderStep,
|
||||
LTX2ConnectorStep,
|
||||
LTX2CoreDenoiseStep,
|
||||
LTX2VideoDecoderStep,
|
||||
LTX2AudioDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "connector", "denoise", "video_decode", "audio_decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Modular pipeline blocks for LTX2 text-to-video generation."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("videos"), OutputParam("audio")]
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. AUTO BLOCKS (T2V + I2V + Conditional)
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2AutoBlocks(SequentialPipelineBlocks):
|
||||
"""Modular pipeline blocks for LTX2 with unified T2V, I2V, and conditional generation.
|
||||
|
||||
Workflow map:
|
||||
- text2video: prompt only
|
||||
- image2video: image + prompt (auto-converts to condition at frame 0)
|
||||
- conditional: conditions + prompt (arbitrary frame conditioning)
|
||||
"""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2TextEncoderStep,
|
||||
LTX2ConnectorStep,
|
||||
LTX2AutoConditionEncoderStep,
|
||||
LTX2CoreDenoiseStep,
|
||||
LTX2VideoDecoderStep,
|
||||
LTX2AudioDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "connector", "condition_encoder", "denoise", "video_decode", "audio_decode"]
|
||||
|
||||
_workflow_map = {
|
||||
"text2video": {"prompt": True},
|
||||
"image2video": {"image": True, "prompt": True},
|
||||
"conditional": {"conditions": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Unified modular pipeline blocks for LTX2 supporting text-to-video, "
|
||||
"image-to-video, and conditional/FLF2V generation."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("videos"), OutputParam("audio")]
|
||||
|
||||
|
||||
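# Hedged sketch, not part of the diff: the workflow map above means one block sequence
# serves T2V, I2V, and conditional generation depending on which inputs are present at
# call time; the assertion below only restates the ordering declared in the class.
assert LTX2AutoBlocks.block_names == ["text_encoder", "connector", "condition_encoder", "denoise", "video_decode", "audio_decode"]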
# ====================
|
||||
# 5. STAGE 2 CORE DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2Stage2CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""Core denoise for Stage 2: uses distilled sigmas with no dynamic shifting."""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2InputStep,
|
||||
LTX2Stage2SetTimestepsStep,
|
||||
LTX2Stage2PrepareLatentsStep,
|
||||
LTX2PrepareAudioLatentsStep,
|
||||
LTX2PrepareCoordinatesStep,
|
||||
LTX2DenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_audio_latents",
|
||||
"prepare_coordinates",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Stage 2 core denoise block using distilled sigmas and no dynamic shifting."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam("latents"),
|
||||
OutputParam("audio_latents"),
|
||||
]
|
||||
|
||||
|
||||
# ====================
|
||||
# 6. STAGE 1 BLOCKS
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2Stage1Blocks(SequentialPipelineBlocks):
|
||||
"""Stage 1 blocks: text encoding -> conditioning -> denoise -> latent output.
|
||||
|
||||
Outputs latents and audio_latents for downstream processing (upsample + stage2).
|
||||
Supports T2V, I2V, and conditional generation modes.
|
||||
"""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2TextEncoderStep,
|
||||
LTX2ConnectorStep,
|
||||
LTX2AutoConditionEncoderStep,
|
||||
LTX2CoreDenoiseStep,
|
||||
]
|
||||
block_names = ["text_encoder", "connector", "condition_encoder", "denoise"]
|
||||
|
||||
_workflow_map = {
|
||||
"text2video": {"prompt": True},
|
||||
"image2video": {"image": True, "prompt": True},
|
||||
"conditional": {"conditions": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Stage 1 modular pipeline blocks for LTX2: text encoding, conditioning, "
|
||||
"and denoising. Outputs latents for upsample + stage2 workflow."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latents"), OutputParam("audio_latents")]
|
||||
|
||||
|
||||
# ====================
|
||||
# 7. STAGE 2 BLOCKS
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2Stage2Blocks(SequentialPipelineBlocks):
|
||||
"""Stage 2 blocks: text encoding -> denoise (distilled) -> decode video + audio.
|
||||
|
||||
Expects pre-computed latents (from upsample) and audio_latents (from stage1).
|
||||
Uses fixed distilled sigmas with no dynamic shifting and guidance_scale=1.0.
|
||||
"""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2TextEncoderStep,
|
||||
LTX2ConnectorStep,
|
||||
LTX2Stage2CoreDenoiseStep,
|
||||
LTX2VideoDecoderStep,
|
||||
LTX2AudioDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "connector", "denoise", "video_decode", "audio_decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Stage 2 modular pipeline blocks for LTX2: re-encodes text, "
|
||||
"denoises with distilled sigmas, and decodes video + audio."
|
||||
)
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
# Override guider default for stage2: guidance_scale=1.0 (no CFG)
|
||||
components = [
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 1.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
for block in self.sub_blocks.values():
|
||||
for component in block.expected_components:
|
||||
if component not in components:
|
||||
components.append(component)
|
||||
return components
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("videos"), OutputParam("audio")]
|
||||
|
||||
|
||||
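# Hedged note, not part of the diff: the expected_components override above pins the guider
# to guidance_scale=1.0 for stage 2. Assuming the usual ClassifierFreeGuidance semantics,
# that scale reports a single condition, so requires_unconditional_embeds (defined on
# LTX2ModularPipeline) stays False and no negative embeddings are prepared for this stage.
_stage2_guider = ClassifierFreeGuidance(guidance_scale=1.0)
# expectation under that assumption: _stage2_guider.num_conditions == 1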
# ====================
|
||||
# 8. STAGE 2 FULL DENOISE (uses stage2_guider)
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2Stage2FullDenoiseStep(LTX2DenoiseLoopWrapper):
|
||||
"""Denoise step for Stage 2 within the full pipeline, using stage2_guider (guidance_scale=1.0)."""
|
||||
|
||||
block_classes = [
|
||||
LTX2LoopBeforeDenoiser,
|
||||
LTX2LoopDenoiser(
|
||||
guider_name="stage2_guider",
|
||||
guider_config=FrozenDict({"guidance_scale": 1.0}),
|
||||
guider_input_fields={
|
||||
"encoder_hidden_states": ("connector_prompt_embeds", "connector_negative_prompt_embeds"),
|
||||
"audio_encoder_hidden_states": (
|
||||
"connector_audio_prompt_embeds",
|
||||
"connector_audio_negative_prompt_embeds",
|
||||
),
|
||||
"encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
|
||||
"audio_encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
|
||||
},
|
||||
),
|
||||
LTX2LoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Stage 2 denoise step using stage2_guider (guidance_scale=1.0).\n"
|
||||
"Used within LTX2FullPipelineBlocks to avoid conflict with the Stage 1 guider."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 9. STAGE 2 FULL CORE DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2Stage2FullCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""Core denoise for Stage 2 within the full pipeline: distilled sigmas, no dynamic shifting, stage2_guider."""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2InputStep,
|
||||
LTX2Stage2SetTimestepsStep,
|
||||
LTX2Stage2PrepareLatentsStep,
|
||||
LTX2PrepareAudioLatentsStep,
|
||||
LTX2PrepareCoordinatesStep,
|
||||
LTX2Stage2FullDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"set_timesteps",
|
||||
"prepare_latents",
|
||||
"prepare_audio_latents",
|
||||
"prepare_coordinates",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Stage 2 core denoise for full pipeline: distilled sigmas, no dynamic shifting, stage2_guider."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam("latents"),
|
||||
OutputParam("audio_latents"),
|
||||
]
|
||||
|
||||
|
||||
# ====================
|
||||
# 10. STAGE 2 INTERNAL BLOCKS (no text encoder/connector)
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2Stage2InternalBlocks(SequentialPipelineBlocks):
|
||||
"""Stage 2 blocks without text encoder/connector — reads connector_* embeddings from state.
|
||||
|
||||
Used within LTX2FullPipelineBlocks where Stage 1 already encoded text.
|
||||
"""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2Stage2FullCoreDenoiseStep,
|
||||
LTX2VideoDecoderStep,
|
||||
LTX2AudioDecoderStep,
|
||||
]
|
||||
block_names = ["denoise", "video_decode", "audio_decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Stage 2 internal blocks (no text encoding): denoise with stage2_guider + decode."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("videos"), OutputParam("audio")]
|
||||
|
||||
|
||||
# ====================
|
||||
# 11. FULL PIPELINE BLOCKS (all-in-one)
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2FullPipelineBlocks(SequentialPipelineBlocks):
|
||||
"""All-in-one mega-block: stage1 -> upsample -> stage2 in a single pipe() call.
|
||||
|
||||
LoRA adapters are automatically disabled for stage1 and re-enabled for stage2.
|
||||
Uses two guiders: `guider` (guidance_scale=4.0) for stage1 and
|
||||
`stage2_guider` (guidance_scale=1.0) for stage2.
|
||||
|
||||
Required components: text_encoder, tokenizer, transformer, connectors, vae, audio_vae,
|
||||
vocoder, scheduler, guider, stage2_guider, latent_upsampler.
|
||||
"""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2DisableAdapterStep,
|
||||
LTX2Stage1Blocks,
|
||||
LTX2UpsampleCoreBlocks,
|
||||
LTX2EnableAdapterStep,
|
||||
LTX2Stage2InternalBlocks,
|
||||
]
|
||||
block_names = ["disable_lora", "stage1", "upsample", "enable_lora", "stage2"]
|
||||
|
||||
_workflow_map = {
|
||||
"text2video": {"prompt": True},
|
||||
"image2video": {"image": True, "prompt": True},
|
||||
"conditional": {"conditions": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"All-in-one LTX2 pipeline: stage1 (denoise) -> upsample -> stage2 (distilled denoise + decode). "
|
||||
"LoRA adapters toggled automatically via stage_2_adapter parameter."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("videos"), OutputParam("audio")]
|
||||
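# Hedged sketch, not part of the diff: stage1 emits latents/audio_latents, the upsampler
# consumes and replaces latents, and stage2 denoises and decodes both streams; the
# assertion below only restates the block ordering declared above.
assert LTX2FullPipelineBlocks.block_names == ["disable_lora", "stage1", "upsample", "enable_lora", "stage2"]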
@@ -0,0 +1,373 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models.autoencoders import AutoencoderKLLTX2Video
|
||||
from ...pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
|
||||
class LTX2UpsamplePrepareStep(ModularPipelineBlocks):
|
||||
"""Prepare latents for upsampling: accepts either video frames or pre-computed latents."""
|
||||
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare latents for the latent upsampler, from either video input or pre-computed latents"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("video", description="Video frames to encode and upsample"),
|
||||
InputParam("latents", type_hint=torch.Tensor, description="Pre-computed latents to upsample"),
|
||||
InputParam("latents_normalized", default=False, type_hint=bool),
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("spatial_patch_size", default=1, type_hint=int),
|
||||
InputParam("temporal_patch_size", default=1, type_hint=int),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("latents", type_hint=torch.Tensor, description="Prepared latents for upsampling"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
device = components._execution_device
|
||||
|
||||
video = block_state.video
|
||||
latents = block_state.latents
|
||||
height = block_state.height
|
||||
width = block_state.width
|
||||
num_frames = block_state.num_frames
|
||||
generator = block_state.generator
|
||||
|
||||
vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
|
||||
vae_temporal_compression_ratio = components.vae.temporal_compression_ratio
|
||||
|
||||
if latents is not None:
|
||||
if latents.ndim == 3:
|
||||
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
|
||||
latent_height = height // vae_spatial_compression_ratio
|
||||
latent_width = width // vae_spatial_compression_ratio
|
||||
latents = _unpack_latents(
|
||||
latents,
|
||||
latent_num_frames,
|
||||
latent_height,
|
||||
latent_width,
|
||||
block_state.spatial_patch_size,
|
||||
block_state.temporal_patch_size,
|
||||
)
|
||||
if block_state.latents_normalized:
|
||||
latents = _denormalize_latents(
|
||||
latents,
|
||||
components.vae.latents_mean,
|
||||
components.vae.latents_std,
|
||||
components.vae.config.scaling_factor,
|
||||
)
|
||||
block_state.latents = latents.to(device=device, dtype=torch.float32)
|
||||
elif video is not None:
|
||||
if isinstance(video, list):
|
||||
num_frames = len(video)
|
||||
if num_frames % vae_temporal_compression_ratio != 1:
|
||||
num_frames = num_frames // vae_temporal_compression_ratio * vae_temporal_compression_ratio + 1
|
||||
if isinstance(video, list):
|
||||
video = video[:num_frames]
|
||||
|
||||
video = components.video_processor.preprocess_video(video, height=height, width=width)
|
||||
video = video.to(device=device, dtype=torch.float32)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(components.vae.encode(vid.unsqueeze(0)), generator) for vid in video
|
||||
]
|
||||
block_state.latents = torch.cat(init_latents, dim=0).to(torch.float32)
|
||||
else:
|
||||
raise ValueError("One of `video` or `latents` must be provided.")
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2UpsampleStep(ModularPipelineBlocks):
|
||||
"""Run the latent upsampler model with optional AdaIN and tone mapping."""
|
||||
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Run the latent upsampler model with optional AdaIN filtering and tone mapping"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("latent_upsampler", LTX2LatentUpsamplerModel),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("adain_factor", default=0.0, type_hint=float),
|
||||
InputParam("tone_map_compression_ratio", default=0.0, type_hint=float),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("latents", type_hint=torch.Tensor, description="Upsampled latents"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
|
||||
result = latents.clone()
|
||||
for i in range(latents.size(0)):
|
||||
for c in range(latents.size(1)):
|
||||
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
|
||||
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
|
||||
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
|
||||
result = torch.lerp(latents, result, factor)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def tone_map_latents(latents: torch.Tensor, compression: float) -> torch.Tensor:
|
||||
scale_factor = compression * 0.75
|
||||
abs_latents = torch.abs(latents)
|
||||
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
|
||||
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
|
||||
return latents * scales
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents.to(components.latent_upsampler.dtype)
|
||||
reference_latents = latents
|
||||
|
||||
latents_upsampled = components.latent_upsampler(latents)
|
||||
|
||||
if block_state.adain_factor > 0.0:
|
||||
latents = self.adain_filter_latent(latents_upsampled, reference_latents, block_state.adain_factor)
|
||||
else:
|
||||
latents = latents_upsampled
|
||||
|
||||
if block_state.tone_map_compression_ratio > 0.0:
|
||||
latents = self.tone_map_latents(latents, block_state.tone_map_compression_ratio)
|
||||
|
||||
block_state.latents = latents
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
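# Hedged numeric check, not part of the diff, of tone_map_latents above: with
# compression=1.0 the scale factor is 0.75, so near-zero latents are left almost
# untouched while large-magnitude latents are compressed toward ~0.4x.
import torch

_x = torch.tensor([0.1, 1.0, 4.0])
_scale_factor = 1.0 * 0.75
_scales = 1.0 - 0.8 * _scale_factor * torch.sigmoid(4.0 * _scale_factor * (_x.abs() - 1.0))
# _scales is approximately [0.96, 0.70, 0.40]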
class LTX2UpsamplePostprocessStep(ModularPipelineBlocks):
|
||||
"""Decode upsampled latents to video frames."""
|
||||
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Decode upsampled latents into video frames"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("output_type", default="pil", type_hint=str),
|
||||
InputParam("decode_timestep", default=0.0),
|
||||
InputParam("decode_noise_scale", default=None),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("videos", description="Decoded video frames"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
|
||||
if block_state.output_type == "latent":
|
||||
block_state.videos = latents
|
||||
else:
|
||||
batch_size = latents.shape[0]
|
||||
device = latents.device
|
||||
|
||||
if not components.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(
|
||||
latents.shape, generator=block_state.generator, device=device, dtype=latents.dtype
|
||||
)
|
||||
decode_timestep = block_state.decode_timestep
|
||||
decode_noise_scale = block_state.decode_noise_scale
|
||||
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = components.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
block_state.videos = components.video_processor.postprocess_video(
|
||||
video, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
# ====================
|
||||
# UPSAMPLE BLOCKS
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2UpsampleBlocks(SequentialPipelineBlocks):
|
||||
"""Modular pipeline blocks for LTX2 latent upsampling."""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2UpsamplePrepareStep,
|
||||
LTX2UpsampleStep,
|
||||
LTX2UpsamplePostprocessStep,
|
||||
]
|
||||
block_names = ["prepare", "upsample", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Modular pipeline blocks for LTX2 latent upsampling (stage1 -> upsample -> stage2)."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("videos")]
|
||||
|
||||
|
||||
class LTX2UpsampleCorePrepareStep(LTX2UpsamplePrepareStep):
|
||||
"""Upsample prepare step for the full pipeline: latents_normalized defaults to True."""
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("video", description="Video frames to encode and upsample"),
|
||||
InputParam("latents", type_hint=torch.Tensor, description="Pre-computed latents to upsample"),
|
||||
InputParam("latents_normalized", default=True, type_hint=bool),
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("spatial_patch_size", default=1, type_hint=int),
|
||||
InputParam("temporal_patch_size", default=1, type_hint=int),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
|
||||
class LTX2UpsampleCoreBlocks(SequentialPipelineBlocks):
|
||||
"""Upsample blocks for the full pipeline: prepare + upsample only (no decode).
|
||||
|
||||
Outputs 5D latents (not decoded video), suitable for chaining into Stage2.
|
||||
"""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2UpsampleCorePrepareStep,
|
||||
LTX2UpsampleStep,
|
||||
]
|
||||
block_names = ["prepare", "upsample"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Latent upsample blocks (no decode) for use within the full pipeline."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latents")]
|
||||
112
src/diffusers/modular_pipelines/ltx2/modular_pipeline.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...loaders import LTX2LoraLoaderMixin
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class LTX2ModularPipeline(
|
||||
ModularPipeline,
|
||||
LTX2LoraLoaderMixin,
|
||||
):
|
||||
"""
|
||||
A ModularPipeline for LTX2 video generation (T2V, I2V, Conditional/FLF2V).
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "LTX2AutoBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return 512
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return 768
|
||||
|
||||
@property
|
||||
def default_num_frames(self):
|
||||
return 121
|
||||
|
||||
@property
|
||||
def vae_scale_factor_spatial(self):
|
||||
vae_scale_factor = 32
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = self.vae.spatial_compression_ratio
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def vae_scale_factor_temporal(self):
|
||||
vae_scale_factor = 8
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = self.vae.temporal_compression_ratio
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def transformer_spatial_patch_size(self):
|
||||
patch_size = 1
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
patch_size = self.transformer.config.patch_size
|
||||
return patch_size
|
||||
|
||||
@property
|
||||
def transformer_temporal_patch_size(self):
|
||||
patch_size = 1
|
||||
if hasattr(self, "transformer") and self.transformer is not None:
|
||||
patch_size = self.transformer.config.patch_size_t
|
||||
return patch_size
|
||||
|
||||
@property
|
||||
def requires_unconditional_embeds(self):
|
||||
requires = False
|
||||
if hasattr(self, "guider") and self.guider is not None:
|
||||
requires = self.guider._enabled and self.guider.num_conditions > 1
|
||||
return requires
|
||||
|
||||
|
||||
class LTX2UpsampleModularPipeline(ModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for LTX2 latent upsampling.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "LTX2UpsampleBlocks"
|
||||
|
||||
@property
|
||||
def default_height(self):
|
||||
return 512
|
||||
|
||||
@property
|
||||
def default_width(self):
|
||||
return 768
|
||||
|
||||
@property
|
||||
def vae_scale_factor_spatial(self):
|
||||
vae_scale_factor = 32
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = self.vae.spatial_compression_ratio
|
||||
return vae_scale_factor
|
||||
|
||||
@property
|
||||
def vae_scale_factor_temporal(self):
|
||||
vae_scale_factor = 8
|
||||
if hasattr(self, "vae") and self.vae is not None:
|
||||
vae_scale_factor = self.vae.temporal_compression_ratio
|
||||
return vae_scale_factor
|
||||
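# Hedged arithmetic check, not part of the diff: with the defaults above (height=512,
# width=768, num_frames=121, spatial compression 32, temporal compression 8), stage-1
# video latents are 16 x 24 spatially with (121 - 1) // 8 + 1 = 16 latent frames.
_latent_height, _latent_width = 512 // 32, 768 // 32   # 16, 24
_latent_num_frames = (121 - 1) // 8 + 1                # 16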
@@ -132,6 +132,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("z-image", _create_default_map_fn("ZImageModularPipeline")),
|
||||
("helios", _create_default_map_fn("HeliosModularPipeline")),
|
||||
("helios-pyramid", _helios_pyramid_map_fn),
|
||||
("ltx2", _create_default_map_fn("LTX2ModularPipeline")),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -324,17 +324,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
The sequence of generated hidden-states.
|
||||
"""
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
if hasattr(self.language_model, "_get_initial_cache_position"):
|
||||
cache_position_kwargs = {}
|
||||
if is_transformers_version("<", "4.52.1"):
|
||||
cache_position_kwargs["input_ids"] = inputs_embeds
|
||||
else:
|
||||
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
|
||||
cache_position_kwargs["device"] = (
|
||||
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
|
||||
)
|
||||
cache_position_kwargs["model_kwargs"] = model_kwargs
|
||||
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
# prepare model inputs
|
||||
|
||||
@@ -123,6 +123,7 @@ from .stable_diffusion_xl import (
|
||||
StableDiffusionXLInpaintPipeline,
|
||||
StableDiffusionXLPipeline,
|
||||
)
|
||||
from .ltx2 import LTX2Pipeline
|
||||
from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
|
||||
from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
|
||||
from .z_image import (
|
||||
@@ -247,6 +248,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
|
||||
|
||||
AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
|
||||
[
|
||||
("ltx2", LTX2Pipeline),
|
||||
("wan", WanPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -720,6 +720,7 @@ class LDMBertModel(LDMBertPreTrainedModel):
|
||||
super().__init__(config)
|
||||
self.model = LDMBertEncoder(config)
|
||||
self.to_logits = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
self.post_init()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -28,7 +28,7 @@ else:
|
||||
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
|
||||
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
@@ -44,7 +44,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .pipeline_ltx2_condition import LTX2ConditionPipeline
|
||||
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
|
||||
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
|
||||
from .vocoder import LTX2Vocoder
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
@@ -9,6 +11,79 @@ from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
def per_layer_masked_mean_norm(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
):
|
||||
"""
|
||||
Performs per-batch per-layer normalization using a masked mean and range on per-layer text encoder hidden_states.
|
||||
Respects the padding of the hidden states.
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [batch_size, seq_len]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions; result has shape (batch_size, 1, 1, num_layers)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions; each has shape (batch_size, 1, 1, num_layers)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
|
||||
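# Hedged shape check, not part of the diff, for per_layer_masked_mean_norm above, using
# tiny hypothetical dimensions: per-layer hidden states of shape (batch, seq_len,
# hidden_dim, num_layers) come back flattened to (batch, seq_len, hidden_dim * num_layers),
# with left-padded positions zeroed out.
import torch

_hs = torch.randn(2, 5, 4, 3)
_seq_lens = torch.tensor([5, 3])  # second sample has two left-padded tokens
_out = per_layer_masked_mean_norm(_hs, _seq_lens, device=_hs.device, padding_side="left")
assert _out.shape == (2, 5, 12)
assert torch.all(_out[1, :2] == 0)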
def per_token_rms_norm(text_encoder_hidden_states: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
|
||||
variance = torch.mean(text_encoder_hidden_states**2, dim=2, keepdim=True)
|
||||
norm_text_encoder_hidden_states = text_encoder_hidden_states * torch.rsqrt(variance + eps)
|
||||
return norm_text_encoder_hidden_states
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
@@ -106,6 +181,7 @@ class LTX2TransformerBlock1d(nn.Module):
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
rope_type: str = "interleaved",
|
||||
apply_gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -115,8 +191,9 @@ class LTX2TransformerBlock1d(nn.Module):
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=apply_gated_attention,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
@@ -160,6 +237,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
gated_attention: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
@@ -188,6 +266,7 @@ class LTX2ConnectorTransformer1d(nn.Module):
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
apply_gated_attention=gated_attention,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
@@ -260,24 +339,36 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
caption_channels: int,
|
||||
text_proj_in_factor: int,
|
||||
video_connector_num_attention_heads: int,
|
||||
video_connector_attention_head_dim: int,
|
||||
video_connector_num_layers: int,
|
||||
video_connector_num_learnable_registers: int | None,
|
||||
audio_connector_num_attention_heads: int,
|
||||
audio_connector_attention_head_dim: int,
|
||||
audio_connector_num_layers: int,
|
||||
audio_connector_num_learnable_registers: int | None,
|
||||
connector_rope_base_seq_len: int,
|
||||
rope_theta: float,
|
||||
rope_double_precision: bool,
|
||||
causal_temporal_positioning: bool,
|
||||
caption_channels: int = 3840, # default Gemma-3-12B text encoder hidden_size
|
||||
text_proj_in_factor: int = 49, # num_layers + 1 for embedding layer = 48 + 1 for Gemma-3-12B
|
||||
video_connector_num_attention_heads: int = 30,
|
||||
video_connector_attention_head_dim: int = 128,
|
||||
video_connector_num_layers: int = 2,
|
||||
video_connector_num_learnable_registers: int | None = 128,
|
||||
video_gated_attn: bool = False,
|
||||
audio_connector_num_attention_heads: int = 30,
|
||||
audio_connector_attention_head_dim: int = 128,
|
||||
audio_connector_num_layers: int = 2,
|
||||
audio_connector_num_learnable_registers: int | None = 128,
|
||||
audio_gated_attn: bool = False,
|
||||
connector_rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
per_modality_projections: bool = False,
|
||||
video_hidden_dim: int = 4096,
|
||||
audio_hidden_dim: int = 2048,
|
||||
proj_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
|
||||
text_encoder_dim = caption_channels * text_proj_in_factor
|
||||
if per_modality_projections:
|
||||
self.video_text_proj_in = nn.Linear(text_encoder_dim, video_hidden_dim, bias=proj_bias)
|
||||
self.audio_text_proj_in = nn.Linear(text_encoder_dim, audio_hidden_dim, bias=proj_bias)
|
||||
else:
|
||||
self.text_proj_in = nn.Linear(text_encoder_dim, caption_channels, bias=proj_bias)
|
||||
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
@@ -288,6 +379,7 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
gated_attention=video_gated_attn,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
@@ -299,26 +391,86 @@ class LTX2TextConnectors(ModelMixin, PeftAdapterMixin, ConfigMixin):
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
gated_attention=audio_gated_attn,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
|
||||
):
|
||||
# Convert to additive attention mask, if necessary
|
||||
if not additive_mask:
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
self,
|
||||
text_encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Given per-layer text encoder hidden_states, extracts features and runs per-modality connectors to get text
|
||||
embeddings for the LTX-2.X DiT models.
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
Args:
|
||||
text_encoder_hidden_states (`torch.Tensor`):
|
||||
Per-layer text encoder hidden_states. Can either be 4D with shape `(batch_size, seq_len,
|
||||
caption_channels, text_proj_in_factor)` or 3D with the last two dimensions flattened.
|
||||
attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`):
|
||||
Multiplicative binary attention mask where 1s indicate unmasked positions and 0s indicate masked
|
||||
positions.
|
||||
padding_side (`str`, *optional*, defaults to `"left"`):
|
||||
The padding side used by the text encoder's tokenizer (either `"left"` or `"right"`). Defaults to
|
||||
`"left"` as this is what the default Gemma3-12B text encoder uses. Only used if
|
||||
`per_modality_projections` is `False` (LTX-2.0 models).
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scale factor for masked mean/range normalization. Only used if `per_modality_projections` is `False`
|
||||
(LTX-2.0 models).
|
||||
"""
|
||||
if text_encoder_hidden_states.ndim == 3:
|
||||
# Ensure shape is [batch_size, seq_len, caption_channels, text_proj_in_factor]
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.unflatten(2, (self.config.caption_channels, -1))
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
|
||||
if self.config.per_modality_projections:
|
||||
# LTX-2.3
|
||||
norm_text_encoder_hidden_states = per_token_rms_norm(text_encoder_hidden_states)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.flatten(2, 3)
|
||||
bool_mask = attention_mask.bool().unsqueeze(-1)
|
||||
norm_text_encoder_hidden_states = torch.where(
|
||||
bool_mask, norm_text_encoder_hidden_states, torch.zeros_like(norm_text_encoder_hidden_states)
|
||||
)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
|
||||
# Rescale norms with respect to video and audio dims for feature extractors
|
||||
video_scale_factor = math.sqrt(self.config.video_hidden_dim / self.config.caption_channels)
|
||||
video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
|
||||
audio_scale_factor = math.sqrt(self.config.audio_hidden_dim / self.config.caption_channels)
|
||||
audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
# Per-Modality Feature extractors
|
||||
video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
|
||||
audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
|
||||
else:
|
||||
# LTX-2.0
|
||||
sequence_lengths = attention_mask.sum(dim=-1)
|
||||
norm_text_encoder_hidden_states = per_layer_masked_mean_norm(
|
||||
text_hidden_states=text_encoder_hidden_states,
|
||||
sequence_lengths=sequence_lengths,
|
||||
device=text_encoder_hidden_states.device,
|
||||
padding_side=padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
|
||||
text_emb_proj = self.text_proj_in(norm_text_encoder_hidden_states)
|
||||
video_text_emb_proj = text_emb_proj
|
||||
audio_text_emb_proj = text_emb_proj
|
||||
|
||||
# Convert to additive attention mask for connectors
|
||||
text_dtype = video_text_emb_proj.dtype
|
||||
attention_mask = (attention_mask.to(torch.int64) - 1).to(text_dtype)
|
||||
attention_mask = attention_mask.reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
add_attn_mask = attention_mask * torch.finfo(text_dtype).max
|
||||
|
||||
video_text_embedding, video_attn_mask = self.video_connector(video_text_emb_proj, add_attn_mask)
|
||||
|
||||
# Convert video attn mask to binary (multiplicative) mask and mask video text embedding
|
||||
binary_attn_mask = (video_attn_mask < 1e-6).to(torch.int64)
|
||||
binary_attn_mask = binary_attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * binary_attn_mask
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(audio_text_emb_proj, add_attn_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, binary_attn_mask.squeeze(-1)
|
||||
|
||||
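The connectors above turn the pipeline's binary (multiplicative) mask into an additive bias, where valid tokens get `0.0` and padded tokens get a large negative value that zeroes them out after softmax. A minimal standalone sketch of that conversion, assuming a `[batch, seq_len]` 0/1 mask (the function name and shapes here are illustrative, not from the file above):

```python
import torch


def to_additive_bias(binary_mask: torch.Tensor, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
    # 1 -> 0.0 (attend), 0 -> -finfo.max (effectively masked after softmax),
    # mirroring the `(mask - 1) * torch.finfo(dtype).max` pattern used above.
    bias = (binary_mask.to(torch.int64) - 1).to(dtype) * torch.finfo(dtype).max
    # Reshape to something broadcastable over [batch, heads, query_len, key_len].
    return bias.reshape(bias.shape[0], 1, 1, -1)


mask = torch.tensor([[1, 1, 1, 0, 0]])
bias = to_additive_bias(mask)  # zeros for real tokens, large negatives for padding
```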
@@ -145,7 +145,7 @@ def encode_video(
        # Pipeline output_type="np"
        is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
        if np.all(is_denormalized):
            video = (video * 255).round().astype("uint8")
            video = (video * 255).astype("uint8")
        else:
            logger.warning(
                "Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "

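The change from a plain cast to `.round()` before the `uint8` cast matters near the top of the [0, 1] range; a small illustration (not taken from the pipeline itself) of the difference:

```python
import numpy as np

video = np.array([0.0, 0.5, 0.999, 1.0])
print((video * 255).astype("uint8"))          # [  0 127 254 255] -- truncation
print((video * 255).round().astype("uint8"))  # [  0 128 255 255] -- rounding
```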
@@ -195,7 +195,8 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
        dims: int = 3,
        spatial_upsample: bool = True,
        temporal_upsample: bool = False,
        rational_spatial_scale: float | None = 2.0,
        rational_spatial_scale: float = 2.0,
        use_rational_resampler: bool = True,
    ):
        super().__init__()

@@ -220,7 +221,7 @@ class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
                PixelShuffleND(3),
            )
        elif spatial_upsample:
            if rational_spatial_scale is not None:
            if use_rational_resampler:
                self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
            else:
                self.upsampler = torch.nn.Sequential(

@@ -18,7 +18,7 @@ from typing import Any, Callable

import numpy as np
import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...loaders import FromSingleFileMixin, LTX2LoraLoaderMixin
@@ -31,7 +31,7 @@ from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .connectors import LTX2TextConnectors
from .pipeline_output import LTX2PipelineOutput
from .vocoder import LTX2Vocoder
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE


if is_torch_xla_available():
@@ -209,7 +209,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
    """

    model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
    _optional_components = []
    _optional_components = ["processor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

    def __init__(
@@ -221,7 +221,8 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
        tokenizer: GemmaTokenizer | GemmaTokenizerFast,
        connectors: LTX2TextConnectors,
        transformer: LTX2VideoTransformer3DModel,
        vocoder: LTX2Vocoder,
        vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
        processor: Gemma3Processor | None = None,
    ):
        super().__init__()

@@ -234,6 +235,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
            transformer=transformer,
            vocoder=vocoder,
            scheduler=scheduler,
            processor=processor,
        )

        self.vae_spatial_compression_ratio = (
@@ -268,73 +270,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
            self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
        )

    @staticmethod
    def _pack_text_embeds(
        text_hidden_states: torch.Tensor,
        sequence_lengths: torch.Tensor,
        device: str | torch.device,
        padding_side: str = "left",
        scale_factor: int = 8,
        eps: float = 1e-6,
    ) -> torch.Tensor:
        """
        Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
        per-layer in a masked fashion (only over non-padded positions).

        Args:
            text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
                Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
            sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
                The number of valid (non-padded) tokens for each batch instance.
            device (`str` or `torch.device`, *optional*):
                torch device to place the resulting embeddings on.
            padding_side (`str`, *optional*, defaults to `"left"`):
                Whether the text tokenizer performs padding on the `"left"` or `"right"`.
            scale_factor (`int`, *optional*, defaults to `8`):
                Scaling factor to multiply the normalized hidden states by.
            eps (`float`, *optional*, defaults to `1e-6`):
                A small positive value for numerical stability when performing normalization.

        Returns:
            `torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
                Normed and flattened text encoder hidden states.
        """
        batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
        original_dtype = text_hidden_states.dtype

        # Create padding mask
        token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
        if padding_side == "right":
            # For right padding, valid tokens are from 0 to sequence_length-1
            mask = token_indices < sequence_lengths[:, None]  # [batch_size, seq_len]
        elif padding_side == "left":
            # For left padding, valid tokens are from (T - sequence_length) to T-1
            start_indices = seq_len - sequence_lengths[:, None]  # [batch_size, 1]
            mask = token_indices >= start_indices  # [B, T]
        else:
            raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
        mask = mask[:, :, None, None]  # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]

        # Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
        masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
        num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
        masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)

        # Compute min/max over non-padding positions of shape (batch_size, 1, 1, seq_len)
        x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
        x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)

        # Normalization
        normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
        normalized_hidden_states = normalized_hidden_states * scale_factor

        # Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
        normalized_hidden_states = normalized_hidden_states.flatten(2)
        mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
        normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
        normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
        return normalized_hidden_states

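`_pack_text_embeds` (removed above; the LTX-2.0 normalization now lives in the connectors and the prompt embeds are simply flattened) centers each batch element by a masked mean and divides by the masked min/max range. A toy numeric sketch of the same idea, simplified to a single layer with right padding and a per-token count rather than the per-position count used above (all names here are illustrative):

```python
import torch

x = torch.tensor([[1.0, 3.0, 5.0, 0.0]])      # [batch=1, seq_len=4], last token is padding
mask = torch.tensor([[1, 1, 1, 0]]).bool()

valid = x.masked_fill(~mask, 0.0)
mean = valid.sum(dim=1, keepdim=True) / mask.sum(dim=1, keepdim=True)   # 3.0
x_min = x.masked_fill(~mask, float("inf")).amin(dim=1, keepdim=True)    # 1.0
x_max = x.masked_fill(~mask, float("-inf")).amax(dim=1, keepdim=True)   # 5.0

scale_factor = 8
normalized = (x - mean) / (x_max - x_min + 1e-6) * scale_factor
normalized = normalized.masked_fill(~mask, 0.0)  # approximately [-4, 0, 4, 0]
```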
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: str | list[str],
|
||||
@@ -387,16 +322,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -494,6 +420,50 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
    @torch.no_grad()
    def enhance_prompt(
        self,
        prompt: str,
        system_prompt: str,
        max_new_tokens: int = 512,
        seed: int = 10,
        generator: torch.Generator | None = None,
        generation_kwargs: dict[str, Any] | None = None,
        device: str | torch.device | None = None,
    ):
        """
        Enhances the supplied `prompt` by using the current text encoder (by default a
        `transformers.Gemma3ForConditionalGeneration` model) to generate a new prompt from it and a system prompt.
        """
        device = device or self._execution_device
        if generation_kwargs is None:
            # Set to default generation kwargs
            generation_kwargs = {"do_sample": True, "temperature": 0.7}

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"user prompt: {prompt}"},
        ]
        template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = self.processor(text=template, images=None, return_tensors="pt").to(device)
        self.text_encoder.to(device)

        # `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
        # so manually apply a seed for reproducible generation.
        if generator is not None:
            # Overwrite seed to generator's initial seed
            seed = generator.initial_seed()
        torch.manual_seed(seed)
        generated_sequences = self.text_encoder.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            **generation_kwargs,
        )  # tensor of shape [batch_size, seq_len]

        generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
        enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return enhanced_prompt

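A hedged usage sketch of the new `enhance_prompt` helper. It assumes the pipeline was loaded with the optional `Gemma3Processor` component and that the pipeline class is importable from `diffusers` on this branch; the checkpoint id and system prompt below are placeholders, not values from this PR:

```python
import torch
from diffusers import LTX2Pipeline

# Placeholder repo id -- substitute the actual LTX-2 checkpoint you are using.
pipe = LTX2Pipeline.from_pretrained("<ltx2-checkpoint-repo-id>", torch_dtype=torch.bfloat16)

enhanced = pipe.enhance_prompt(
    prompt="a cat playing piano",
    system_prompt="Rewrite the user prompt as a detailed, cinematic video description.",
    max_new_tokens=256,
    generator=torch.Generator().manual_seed(42),
)
print(enhanced)  # list containing one enhanced prompt string
```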
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -504,6 +474,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -547,6 +520,12 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
        if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
            raise ValueError(
                "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
                " block indices at which to apply STG in `spatio_temporal_guidance_blocks`."
            )

    @staticmethod
    def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
        # Unpacked latents of shape [B, C, F, H, W] are patched into tokens of shape [B, C, F // p_t, p_t, H // p, p, W // p, p].
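The comment above describes the patchify step that turns a latent video into a token sequence; a minimal sketch of the reshape/permute/flatten pattern it refers to (an illustration of the idea, not the exact implementation in the file):

```python
import torch


def pack_latents(latents: torch.Tensor, p: int = 1, p_t: int = 1) -> torch.Tensor:
    # [B, C, F, H, W] -> [B, F//p_t * H//p * W//p, C * p_t * p * p]
    b, c, f, h, w = latents.shape
    latents = latents.reshape(b, c, f // p_t, p_t, h // p, p, w // p, p)
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7)
    return latents.flatten(4, 7).flatten(1, 3)


tokens = pack_latents(torch.randn(1, 128, 8, 16, 16))  # -> [1, 2048, 128] with p = p_t = 1
```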
@@ -734,7 +713,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
latents = self._create_noised_state(latents, noise_scale, generator)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
# TODO: confirm whether this logic is correct
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
|
||||
@@ -749,6 +727,24 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
latents = self._pack_audio_latents(latents)
|
||||
return latents
|
||||
|
||||
    def convert_velocity_to_x0(
        self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
    ) -> torch.Tensor:
        if scheduler is None:
            scheduler = self.scheduler

        sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx]
        return sample_x0

    def convert_x0_to_velocity(
        self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
    ) -> torch.Tensor:
        if scheduler is None:
            scheduler = self.scheduler

        sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx]
        return sample_v

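These two helpers convert between the flow-matching velocity prediction and the clean-sample (x0) estimate at the sigma of the current step, and they are exact inverses of each other. A quick standalone check (illustrative tensors, not pipeline code):

```python
import torch

sigma = 0.8
sample = torch.randn(2, 4)
velocity = torch.randn(2, 4)

x0 = sample - velocity * sigma         # convert_velocity_to_x0
velocity_back = (sample - x0) / sigma  # convert_x0_to_velocity
assert torch.allclose(velocity, velocity_back, atol=1e-6)
```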
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -757,9 +753,41 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def stg_scale(self):
|
||||
return self._stg_scale
|
||||
|
||||
@property
|
||||
def modality_scale(self):
|
||||
return self._modality_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_scale(self):
|
||||
return self._audio_guidance_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_rescale(self):
|
||||
return self._audio_guidance_rescale
|
||||
|
||||
@property
|
||||
def audio_stg_scale(self):
|
||||
return self._audio_stg_scale
|
||||
|
||||
@property
|
||||
def audio_modality_scale(self):
|
||||
return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
|
||||
def do_modality_isolation_guidance(self):
|
||||
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -791,7 +819,14 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -803,6 +838,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
system_prompt: str | None = None,
|
||||
prompt_max_new_tokens: int = 512,
|
||||
prompt_enhancement_kwargs: dict[str, Any] | None = None,
|
||||
prompt_enhancement_seed: int = 10,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -841,13 +881,47 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional* defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continuing denoising.
|
||||
@@ -878,6 +952,24 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool` *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
system_prompt (`str`, *optional*, defaults to `None`):
|
||||
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
|
||||
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
|
||||
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
|
||||
performed.
|
||||
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
|
||||
The maximum number of new tokens to generate when performing prompt enhancement.
|
||||
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
|
||||
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
|
||||
`do_sample=True` and `temperature=0.7` will be used. See
|
||||
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
for more details.
|
||||
prompt_enhancement_seed (`int`, *optional*, default to `10`):
|
||||
Random seed for any random operations during prompt enhancement.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
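With CFG, STG, and modality-isolation guidance all expressed as deltas on the conditional x0 prediction, the update performed later in the denoising loop reduces to a single sum of terms. A hedged sketch of that combination (tensor names are illustrative; the real loop computes separate video and audio variants):

```python
import torch

guidance_scale, stg_scale, modality_scale = 3.0, 1.0, 3.0
x0_cond, x0_uncond, x0_stg, x0_isolated = (torch.randn(1, 8) for _ in range(4))

cfg_delta = (guidance_scale - 1) * (x0_cond - x0_uncond)         # classifier-free guidance
stg_delta = stg_scale * (x0_cond - x0_stg)                       # spatio-temporal guidance
modality_delta = (modality_scale - 1) * (x0_cond - x0_isolated)  # modality isolation guidance

x0_guided = x0_cond + cfg_delta + stg_delta + modality_delta
```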
@@ -910,6 +1002,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
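The audio-specific scales fall back to the video values via Python's `or`, which means `None` falls back as intended but an explicit `0.0` does as well. A tiny illustration of that behavior:

```python
guidance_scale, stg_scale = 3.0, 1.0

audio_guidance_scale = None
audio_guidance_scale = audio_guidance_scale or guidance_scale  # 3.0 (None falls back)

audio_stg_scale = 0.0
audio_stg_scale = audio_stg_scale or stg_scale                 # 1.0 -- note: 0.0 also falls back
```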
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -920,10 +1017,21 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -939,6 +1047,17 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
if system_prompt is not None and prompt is not None:
|
||||
prompt = self.enhance_prompt(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_new_tokens=prompt_max_new_tokens,
|
||||
seed=prompt_enhancement_seed,
|
||||
generator=generator,
|
||||
generation_kwargs=prompt_enhancement_kwargs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -960,9 +1079,11 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -984,7 +1105,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
|
||||
)
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
# video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
@@ -1041,7 +1162,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_image_seq_len", 1024),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.95),
|
||||
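`mu` is produced by interpolating the timestep shift between the scheduler's base and max sequence lengths as a function of the video token count. A sketch of that interpolation under the assumption that `calculate_shift` follows the usual diffusers formulation (the helper itself lives elsewhere in the codebase; the `max_shift` default below is a placeholder, only `base_shift` and the seq-len defaults appear in the config reads above):

```python
def calculate_shift_sketch(seq_len, base_seq_len=1024, max_seq_len=4096, base_shift=0.95, max_shift=2.05):
    # Linear interpolation of the shift as a function of token count.
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return seq_len * slope + intercept
```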
@@ -1069,11 +1190,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1111,6 +1227,7 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
encoder_hidden_states=connector_prompt_embeds,
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
@@ -1120,7 +1237,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1128,24 +1248,155 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_video_uncond_text = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_text, i, self.scheduler
|
||||
)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
noise_pred_audio_uncond_text = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler
|
||||
)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
noise_pred_video_uncond_stg = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_stg, i, self.scheduler
|
||||
)
|
||||
noise_pred_audio_uncond_stg = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
noise_pred_video_uncond_modality = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_modality, i, self.scheduler
|
||||
)
|
||||
noise_pred_audio_uncond_modality = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler
|
||||
)
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale
|
||||
)
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# Convert back to velocity for scheduler
|
||||
noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
|
||||
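After the guidance terms are summed, the result can optionally be rescaled toward the statistics of the conditional prediction to counteract over-exposure. A sketch of that standard-deviation-matching rescale, assuming the common diffusers formulation of `rescale_noise_cfg` (sketched here for reference, not copied from this PR):

```python
import torch


def rescale_noise_cfg_sketch(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0):
    # Match the std of the guided prediction to the std of the conditional prediction,
    # then blend by `guidance_rescale` (0.0 keeps the guided prediction unchanged).
    dims = list(range(1, noise_pred_text.ndim))
    std_text = noise_pred_text.std(dim=dims, keepdim=True)
    std_cfg = noise_cfg.std(dim=dims, keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
```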
@@ -1177,9 +1428,6 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
|
||||
audio_latents = self._denormalize_audio_latents(
|
||||
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
|
||||
@@ -1187,6 +1435,9 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
audio = audio_latents
|
||||
else:
|
||||
@@ -1209,6 +1460,10 @@ class LTX2Pipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoaderMixin):
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
@@ -33,7 +33,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -254,7 +254,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -300,74 +300,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, seq_len)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1 seq_len)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
@@ -421,16 +353,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -541,6 +464,9 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask=None,
|
||||
latents=None,
|
||||
audio_latents=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -597,6 +523,12 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
f" using the `_unpack_audio_latents` method)."
|
||||
)
|
||||
|
||||
        if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
            raise ValueError(
                "Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
                " block indices at which to apply STG in `spatio_temporal_guidance_blocks`."
            )
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
@@ -984,6 +916,24 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
latents = self._pack_audio_latents(latents)
|
||||
return latents
|
||||
|
||||
def convert_velocity_to_x0(
|
||||
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
|
||||
) -> torch.Tensor:
|
||||
if scheduler is None:
|
||||
scheduler = self.scheduler
|
||||
|
||||
sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx]
|
||||
return sample_x0
|
||||
|
||||
def convert_x0_to_velocity(
|
||||
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
|
||||
) -> torch.Tensor:
|
||||
if scheduler is None:
|
||||
scheduler = self.scheduler
|
||||
|
||||
sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx]
|
||||
return sample_v
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -992,9 +942,41 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def stg_scale(self):
|
||||
return self._stg_scale
|
||||
|
||||
@property
|
||||
def modality_scale(self):
|
||||
return self._modality_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_scale(self):
|
||||
return self._audio_guidance_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_rescale(self):
|
||||
return self._audio_guidance_rescale
|
||||
|
||||
@property
|
||||
def audio_stg_scale(self):
|
||||
return self._audio_stg_scale
|
||||
|
||||
@property
|
||||
def audio_modality_scale(self):
|
||||
return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
|
||||
def do_modality_isolation_guidance(self):
|
||||
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -1027,7 +1009,14 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[float] | None = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float | None = None,
|
||||
num_videos_per_prompt: int | None = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -1039,6 +1028,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -1079,13 +1069,47 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weak sample from a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
|
||||
weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio)
|
||||
cross attention disabled using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional* defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
noise_scale (`float`, *optional*, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Applying noise to
|
||||
the `latents` and `audio_latents` before continuing denoising. If not set, it will be inferred from the
|
||||
@@ -1117,6 +1141,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool` *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -1149,6 +1177,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -1161,10 +1194,21 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
latents=latents,
|
||||
audio_latents=audio_latents,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -1208,9 +1252,11 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1222,7 +1268,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
"Got latents of shape [batch_size, latent_dim, latent_frames, latent_height, latent_width], `latent_num_frames`, `latent_height`, `latent_width` will be inferred."
|
||||
)
|
||||
_, _, latent_num_frames, latent_height, latent_width = latents.shape # [B, C, F, H, W]
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
# video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents, conditioning_mask, clean_latents = self.prepare_latents(
|
||||
@@ -1272,7 +1318,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_image_seq_len", 1024),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.95),
|
||||
@@ -1301,11 +1347,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1344,6 +1385,7 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
@@ -1353,7 +1395,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1361,41 +1406,172 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_video_uncond_text = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_text, i, self.scheduler
|
||||
)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
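# (With only CFG active this delta form is equivalent to the classic update, since
# x0_uncond + w * (x0_text - x0_uncond) == x0_text + (w - 1) * (x0_text - x0_uncond); writing every
# guidance term as a delta lets the STG and modality-isolation corrections below be summed onto the
# same positive x0 prediction.)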
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
noise_pred_audio_uncond_text = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler
|
||||
)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
video_timestep = video_timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
noise_pred_video_uncond_stg = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_stg, i, self.scheduler
|
||||
)
|
||||
noise_pred_audio_uncond_stg = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
noise_pred_video_uncond_modality = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_modality, i, self.scheduler
|
||||
)
|
||||
noise_pred_audio_uncond_modality = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler
|
||||
)
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale
|
||||
)
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# NOTE: use only the first chunk of conditioning mask in case it is duplicated for CFG
|
||||
bsz = noise_pred_video.size(0)
|
||||
sigma = self.scheduler.sigmas[i]
|
||||
# Convert the noise_pred_video velocity model prediction into a sample (x0) prediction
|
||||
denoised_sample = latents - noise_pred_video * sigma
|
||||
# Apply the (packed) conditioning mask to the denoised (x0) sample and clean conditioning. The
|
||||
# conditioning mask contains conditioning strengths from 0 (always use denoised sample) to 1 (always
|
||||
# use conditions), with intermediate values specifying how strongly to follow the conditions.
|
||||
# NOTE: this operation should be applied in sample (x0) space and not velocity space (which is the
|
||||
# space the denoising model outputs are in)
|
||||
denoised_sample_cond = (
|
||||
denoised_sample * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz]
|
||||
noise_pred_video * (1 - conditioning_mask[:bsz]) + clean_latents.float() * conditioning_mask[:bsz]
|
||||
).to(noise_pred_video.dtype)
|
||||
|
||||
# Convert the denoised (x0) sample back to a velocity for the scheduler
|
||||
denoised_latents_cond = ((latents - denoised_sample_cond) / sigma).to(noise_pred_video.dtype)
|
||||
noise_pred_video = self.convert_x0_to_velocity(latents, denoised_sample_cond, i, self.scheduler)
|
||||
noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
|
||||
# Compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(denoised_latents_cond, t, latents, return_dict=False)[0]
|
||||
latents = self.scheduler.step(noise_pred_video, t, latents, return_dict=False)[0]
|
||||
|
||||
# NOTE: for now duplicate scheduler for audio latents in case self.scheduler sets internal state in
|
||||
# the step method (such as _step_index)
|
||||
@@ -1425,9 +1601,6 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
|
||||
audio_latents = self._denormalize_audio_latents(
|
||||
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
|
||||
@@ -1435,6 +1608,9 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
audio = audio_latents
|
||||
else:
|
||||
@@ -1457,6 +1633,10 @@ class LTX2ConditionPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraLoad
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import Any, Callable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast
|
||||
from transformers import Gemma3ForConditionalGeneration, Gemma3Processor, GemmaTokenizer, GemmaTokenizerFast
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
@@ -32,7 +32,7 @@ from ...video_processor import VideoProcessor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .pipeline_output import LTX2PipelineOutput
|
||||
from .vocoder import LTX2Vocoder
|
||||
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
@@ -212,7 +212,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->connectors->transformer->vae->audio_vae->vocoder"
|
||||
_optional_components = []
|
||||
_optional_components = ["processor"]
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
@@ -224,7 +224,8 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
connectors: LTX2TextConnectors,
|
||||
transformer: LTX2VideoTransformer3DModel,
|
||||
vocoder: LTX2Vocoder,
|
||||
vocoder: LTX2Vocoder | LTX2VocoderWithBWE,
|
||||
processor: Gemma3Processor | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -237,6 +238,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
scheduler=scheduler,
|
||||
processor=processor,
|
||||
)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
@@ -271,74 +273,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self.tokenizer.model_max_length if getattr(self, "tokenizer", None) is not None else 1024
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_text_embeds
|
||||
def _pack_text_embeds(
|
||||
text_hidden_states: torch.Tensor,
|
||||
sequence_lengths: torch.Tensor,
|
||||
device: str | torch.device,
|
||||
padding_side: str = "left",
|
||||
scale_factor: int = 8,
|
||||
eps: float = 1e-6,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Packs and normalizes text encoder hidden states, respecting padding. Normalization is performed per-batch and
|
||||
per-layer in a masked fashion (only over non-padded positions).
|
||||
|
||||
Args:
|
||||
text_hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim, num_layers)`):
|
||||
Per-layer hidden_states from a text encoder (e.g. `Gemma3ForConditionalGeneration`).
|
||||
sequence_lengths (`torch.Tensor` of shape `(batch_size,)`):
|
||||
The number of valid (non-padded) tokens for each batch instance.
|
||||
device: (`str` or `torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
padding_side: (`str`, *optional*, defaults to `"left"`):
|
||||
Whether the text tokenizer performs padding on the `"left"` or `"right"`.
|
||||
scale_factor (`int`, *optional*, defaults to `8`):
|
||||
Scaling factor to multiply the normalized hidden states by.
|
||||
eps (`float`, *optional*, defaults to `1e-6`):
|
||||
A small positive value for numerical stability when performing normalization.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(batch_size, seq_len, hidden_dim * num_layers)`:
|
||||
Normed and flattened text encoder hidden states.
|
||||
"""
|
||||
batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
|
||||
original_dtype = text_hidden_states.dtype
|
||||
|
||||
# Create padding mask
|
||||
token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
|
||||
if padding_side == "right":
|
||||
# For right padding, valid tokens are from 0 to sequence_length-1
|
||||
mask = token_indices < sequence_lengths[:, None] # [batch_size, seq_len]
|
||||
elif padding_side == "left":
|
||||
# For left padding, valid tokens are from (T - sequence_length) to T-1
|
||||
start_indices = seq_len - sequence_lengths[:, None] # [batch_size, 1]
|
||||
mask = token_indices >= start_indices # [B, T]
|
||||
else:
|
||||
raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
|
||||
mask = mask[:, :, None, None] # [batch_size, seq_len] --> [batch_size, seq_len, 1, 1]
|
||||
|
||||
# Compute masked mean over non-padding positions of shape (batch_size, 1, 1, num_layers)
|
||||
masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
|
||||
num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
|
||||
masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
|
||||
|
||||
# Compute min/max over non-padding positions of shape (batch_size, 1, 1, num_layers)
|
||||
x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
|
||||
x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
|
||||
|
||||
# Normalization
|
||||
normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
|
||||
normalized_hidden_states = normalized_hidden_states * scale_factor
|
||||
|
||||
# Pack the hidden states to a 3D tensor (batch_size, seq_len, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.flatten(2)
|
||||
mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
|
||||
normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
|
||||
normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
|
||||
return normalized_hidden_states
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
@@ -392,16 +326,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
)
|
||||
text_encoder_hidden_states = text_encoder_outputs.hidden_states
|
||||
text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
|
||||
sequence_lengths = prompt_attention_mask.sum(dim=-1)
|
||||
|
||||
prompt_embeds = self._pack_text_embeds(
|
||||
text_encoder_hidden_states,
|
||||
sequence_lengths,
|
||||
device=device,
|
||||
padding_side=self.tokenizer.padding_side,
|
||||
scale_factor=scale_factor,
|
||||
)
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype)
|
||||
prompt_embeds = text_encoder_hidden_states.flatten(2, 3).to(dtype=dtype) # Pack to 3D
|
||||
|
||||
# duplicate text embeddings for each generation per prompt, using mps friendly method
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
@@ -500,6 +425,57 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
|
||||
@torch.no_grad()
|
||||
def enhance_prompt(
|
||||
self,
|
||||
image: PipelineImageInput,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
max_new_tokens: int = 512,
|
||||
seed: int = 10,
|
||||
generator: torch.Generator | None = None,
|
||||
generation_kwargs: dict[str, Any] | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
):
|
||||
"""
|
||||
Enhances the supplied `prompt` by using the current text encoder (by default a
`transformers.Gemma3ForConditionalGeneration` model) to generate a new prompt from it and a system prompt.
|
||||
"""
|
||||
device = device or self._execution_device
|
||||
if generation_kwargs is None:
|
||||
# Set to default generation kwargs
|
||||
generation_kwargs = {"do_sample": True, "temperature": 0.7}
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image"},
|
||||
{"type": "text", "text": f"User Raw Input Prompt: {prompt}."},
|
||||
],
|
||||
},
|
||||
]
|
||||
template = self.processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
||||
model_inputs = self.processor(text=template, images=image, return_tensors="pt").to(device)
|
||||
self.text_encoder.to(device)
|
||||
|
||||
# `transformers.GenerationMixin.generate` does not support using a `torch.Generator` to control randomness,
|
||||
# so manually apply a seed for reproducible generation.
|
||||
if generator is not None:
|
||||
# Overwrite seed to generator's initial seed
|
||||
seed = generator.initial_seed()
|
||||
torch.manual_seed(seed)
|
||||
generated_sequences = self.text_encoder.generate(
|
||||
**model_inputs,
|
||||
max_new_tokens=max_new_tokens,
|
||||
**generation_kwargs,
|
||||
) # tensor of shape [batch_size, seq_len]
|
||||
|
||||
generated_ids = [seq[len(model_inputs.input_ids[i]) :] for i, seq in enumerate(generated_sequences)]
|
||||
enhanced_prompt = self.processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
return enhanced_prompt
|
||||
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
@@ -511,6 +487,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
stg_scale=None,
|
||||
audio_stg_scale=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -554,6 +533,12 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
if ((stg_scale > 0.0) or (audio_stg_scale > 0.0)) and not spatio_temporal_guidance_blocks:
|
||||
raise ValueError(
|
||||
"Spatio-Temporal Guidance (STG) is specified but no STG blocks are supplied. Please supply a list of"
|
||||
"block indices at which to apply STG in `spatio_temporal_guidance_blocks`"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._pack_latents
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
@@ -788,7 +773,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
latents = self._create_noised_state(latents, noise_scale, generator)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
# TODO: confirm whether this logic is correct
|
||||
latent_mel_bins = num_mel_bins // self.audio_vae_mel_compression_ratio
|
||||
|
||||
shape = (batch_size, num_channels_latents, audio_latent_length, latent_mel_bins)
|
||||
@@ -803,6 +787,24 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
latents = self._pack_audio_latents(latents)
|
||||
return latents
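# The two helpers below convert between the model's velocity prediction and a sample (x0) prediction
# using the linear relation assumed by the flow-matching scheduler, x_t = x_0 + sigma_t * v, so that
# x_0 = x_t - sigma_t * v and v = (x_t - x_0) / sigma_t. Guidance terms are combined in x0 space and
# converted back to a velocity before `scheduler.step` is called.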
|
||||
|
||||
def convert_velocity_to_x0(
|
||||
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
|
||||
) -> torch.Tensor:
|
||||
if scheduler is None:
|
||||
scheduler = self.scheduler
|
||||
|
||||
sample_x0 = sample - denoised_output * scheduler.sigmas[step_idx]
|
||||
return sample_x0
|
||||
|
||||
def convert_x0_to_velocity(
|
||||
self, sample: torch.Tensor, denoised_output: torch.Tensor, step_idx: int, scheduler: Any | None = None
|
||||
) -> torch.Tensor:
|
||||
if scheduler is None:
|
||||
scheduler = self.scheduler
|
||||
|
||||
sample_v = (sample - denoised_output) / scheduler.sigmas[step_idx]
|
||||
return sample_v
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -811,9 +813,41 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
def guidance_rescale(self):
|
||||
return self._guidance_rescale
|
||||
|
||||
@property
|
||||
def stg_scale(self):
|
||||
return self._stg_scale
|
||||
|
||||
@property
|
||||
def modality_scale(self):
|
||||
return self._modality_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_scale(self):
|
||||
return self._audio_guidance_scale
|
||||
|
||||
@property
|
||||
def audio_guidance_rescale(self):
|
||||
return self._audio_guidance_rescale
|
||||
|
||||
@property
|
||||
def audio_stg_scale(self):
|
||||
return self._audio_stg_scale
|
||||
|
||||
@property
|
||||
def audio_modality_scale(self):
|
||||
return self._audio_modality_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
return (self._guidance_scale > 1.0) or (self._audio_guidance_scale > 1.0)
|
||||
|
||||
@property
|
||||
def do_spatio_temporal_guidance(self):
|
||||
return (self._stg_scale > 0.0) or (self._audio_stg_scale > 0.0)
|
||||
|
||||
@property
|
||||
def do_modality_isolation_guidance(self):
|
||||
return (self._modality_scale > 1.0) or (self._audio_modality_scale > 1.0)
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
@@ -846,7 +880,14 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
sigmas: list[float] | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
guidance_scale: float = 4.0,
|
||||
stg_scale: float = 0.0,
|
||||
modality_scale: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
audio_guidance_scale: float | None = None,
|
||||
audio_stg_scale: float | None = None,
|
||||
audio_modality_scale: float | None = None,
|
||||
audio_guidance_rescale: float | None = None,
|
||||
spatio_temporal_guidance_blocks: list[int] | None = None,
|
||||
noise_scale: float = 0.0,
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
@@ -858,6 +899,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_attention_mask: torch.Tensor | None = None,
|
||||
decode_timestep: float | list[float] = 0.0,
|
||||
decode_noise_scale: float | list[float] | None = None,
|
||||
use_cross_timestep: bool = False,
|
||||
system_prompt: str | None = None,
|
||||
prompt_max_new_tokens: int = 512,
|
||||
prompt_enhancement_kwargs: dict[str, Any] | None = None,
|
||||
prompt_enhancement_seed: int = 10,
|
||||
output_type: str = "pil",
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
@@ -898,13 +944,47 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
the text `prompt`, usually at the expense of lower image quality. Used for the video modality (there is
|
||||
a separate value `audio_guidance_scale` for the audio modality).
|
||||
stg_scale (`float`, *optional*, defaults to `0.0`):
|
||||
Video guidance scale for Spatio-Temporal Guidance (STG), proposed in [Spatiotemporal Skip Guidance for
|
||||
Enhanced Video Diffusion Sampling](https://arxiv.org/abs/2411.18664). STG uses a CFG-like estimate
|
||||
where we move the sample away from a weaker sample produced by a perturbed version of the denoising model.
|
||||
Enabling STG will result in an additional denoising model forward pass; the default value of `0.0`
|
||||
means that STG is disabled.
|
||||
modality_scale (`float`, *optional*, defaults to `1.0`):
|
||||
Video guidance scale for LTX-2.X modality isolation guidance, where we move the sample away from a
weaker sample generated by the denoising model with cross-modality (audio-to-video and video-to-audio)
cross attention disabled, using a CFG-like estimate. Enabling modality guidance will result in an
|
||||
additional denoising model forward pass; the default value of `1.0` means that modality guidance is
|
||||
disabled.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891) `guidance_scale` is defined as `φ` in equation 16. of
|
||||
[Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891). Guidance rescale factor should fix overexposure when
|
||||
using zero terminal SNR.
|
||||
using zero terminal SNR. Used for the video modality.
|
||||
audio_guidance_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for CFG with respect to the negative prompt. The CFG update rule is the same for
|
||||
video and audio, but they can use different values for the guidance scale. The LTX-2.X authors suggest
|
||||
that the `audio_guidance_scale` should be higher relative to the video `guidance_scale` (e.g. for
|
||||
LTX-2.3 they suggest 3.0 for video and 7.0 for audio). If `None`, defaults to the video value
|
||||
`guidance_scale`.
|
||||
audio_stg_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for STG. As with CFG, the STG update rule is otherwise the same for video and
|
||||
audio. For LTX-2.3, a value of 1.0 is suggested for both video and audio. If `None`, defaults to the
|
||||
video value `stg_scale`.
|
||||
audio_modality_scale (`float`, *optional*, defaults to `None`):
|
||||
Audio guidance scale for LTX-2.X modality isolation guidance. As with CFG, the modality guidance rule
|
||||
is otherwise the same for video and audio. For LTX-2.3, a value of 3.0 is suggested for both video and
|
||||
audio. If `None`, defaults to the video value `modality_scale`.
|
||||
audio_guidance_rescale (`float`, *optional*, defaults to `None`):
|
||||
A separate guidance rescale factor for the audio modality. If `None`, defaults to the video value
|
||||
`guidance_rescale`.
|
||||
spatio_temporal_guidance_blocks (`list[int]`, *optional*, defaults to `None`):
|
||||
The zero-indexed transformer block indices at which to apply STG. Must be supplied if STG is used
|
||||
(`stg_scale` or `audio_stg_scale` is greater than `0`). A value of `[29]` is recommended for LTX-2.0
|
||||
and `[28]` is recommended for LTX-2.3.
|
||||
noise_scale (`float`, *optional*, defaults to `0.0`):
|
||||
The interpolation factor between random noise and denoised latents at each timestep. Noise is applied to
the `latents` and `audio_latents` before denoising continues.
|
||||
@@ -935,6 +1015,24 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
use_cross_timestep (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use the cross modality (audio is the cross modality of video, and vice versa) sigma when
|
||||
calculating the cross attention modulation parameters. `True` is the newer (e.g. LTX-2.3) behavior;
|
||||
`False` is the legacy LTX-2.0 behavior.
|
||||
system_prompt (`str`, *optional*, defaults to `None`):
|
||||
Optional system prompt to use for prompt enhancement. The system prompt will be used by the current
|
||||
text encoder (by default, a `Gemma3ForConditionalGeneration` model) to generate an enhanced prompt from
|
||||
the original `prompt` to condition generation. If not supplied, prompt enhancement will not be
|
||||
performed.
|
||||
prompt_max_new_tokens (`int`, *optional*, defaults to `512`):
|
||||
The maximum number of new tokens to generate when performing prompt enhancement.
|
||||
prompt_enhancement_kwargs (`dict[str, Any]`, *optional*, defaults to `None`):
|
||||
Keyword arguments for `self.text_encoder.generate`. If not supplied, default arguments of
|
||||
`do_sample=True` and `temperature=0.7` will be used. See
|
||||
https://huggingface.co/docs/transformers/main/en/main_classes/text_generation#transformers.GenerationMixin.generate
|
||||
for more details.
|
||||
prompt_enhancement_seed (`int`, *optional*, defaults to `10`):
|
||||
Random seed for any random operations during prompt enhancement.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
@@ -967,6 +1065,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
audio_guidance_scale = audio_guidance_scale or guidance_scale
|
||||
audio_stg_scale = audio_stg_scale or stg_scale
|
||||
audio_modality_scale = audio_modality_scale or modality_scale
|
||||
audio_guidance_rescale = audio_guidance_rescale or guidance_rescale
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
@@ -977,10 +1080,21 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
stg_scale=stg_scale,
|
||||
audio_stg_scale=audio_stg_scale,
|
||||
)
|
||||
|
||||
# Per-modality guidance scales (video, audio)
|
||||
self._guidance_scale = guidance_scale
|
||||
self._stg_scale = stg_scale
|
||||
self._modality_scale = modality_scale
|
||||
self._guidance_rescale = guidance_rescale
|
||||
self._audio_guidance_scale = audio_guidance_scale
|
||||
self._audio_stg_scale = audio_stg_scale
|
||||
self._audio_modality_scale = audio_modality_scale
|
||||
self._audio_guidance_rescale = audio_guidance_rescale
|
||||
|
||||
self._attention_kwargs = attention_kwargs
|
||||
self._interrupt = False
|
||||
self._current_timestep = None
|
||||
@@ -996,6 +1110,18 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
device = self._execution_device
|
||||
|
||||
# 3. Prepare text embeddings
|
||||
if system_prompt is not None and prompt is not None:
|
||||
prompt = self.enhance_prompt(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
max_new_tokens=prompt_max_new_tokens,
|
||||
seed=prompt_enhancement_seed,
|
||||
generator=generator,
|
||||
generation_kwargs=prompt_enhancement_kwargs,
|
||||
device=device,
|
||||
)
|
||||
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
@@ -1017,9 +1143,11 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
|
||||
additive_attention_mask = (1 - prompt_attention_mask.to(prompt_embeds.dtype)) * -1000000.0
|
||||
tokenizer_padding_side = "left" # Padding side for default Gemma3-12B text encoder
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
tokenizer_padding_side = getattr(self.tokenizer, "padding_side", "left")
|
||||
connector_prompt_embeds, connector_audio_prompt_embeds, connector_attention_mask = self.connectors(
|
||||
prompt_embeds, additive_attention_mask, additive_mask=True
|
||||
prompt_embeds, prompt_attention_mask, padding_side=tokenizer_padding_side
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
@@ -1041,7 +1169,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
raise ValueError(
|
||||
f"Provided `latents` tensor has shape {latents.shape}, but the expected shape is either [batch_size, seq_len, num_features] or [batch_size, latent_dim, latent_frames, latent_height, latent_width]."
|
||||
)
|
||||
video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
# video_sequence_length = latent_num_frames * latent_height * latent_width
|
||||
|
||||
if latents is None:
|
||||
image = self.video_processor.preprocess(image, height=height, width=width)
|
||||
@@ -1105,7 +1233,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
# 5. Prepare timesteps
|
||||
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
|
||||
mu = calculate_shift(
|
||||
video_sequence_length,
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_image_seq_len", 1024),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.95),
|
||||
@@ -1134,11 +1262,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 6. Prepare micro-conditions
|
||||
rope_interpolation_scale = (
|
||||
self.vae_temporal_compression_ratio / frame_rate,
|
||||
self.vae_spatial_compression_ratio,
|
||||
self.vae_spatial_compression_ratio,
|
||||
)
|
||||
# Pre-compute video and audio positional ids as they will be the same at each step of the denoising loop
|
||||
video_coords = self.transformer.rope.prepare_video_coords(
|
||||
latents.shape[0], latent_num_frames, latent_height, latent_width, latents.device, fps=frame_rate
|
||||
@@ -1177,6 +1300,7 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
audio_encoder_hidden_states=connector_audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=connector_attention_mask,
|
||||
audio_encoder_attention_mask=connector_attention_mask,
|
||||
num_frames=latent_num_frames,
|
||||
@@ -1186,7 +1310,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_coords,
|
||||
audio_coords=audio_coords,
|
||||
# rope_interpolation_scale=rope_interpolation_scale,
|
||||
isolate_modalities=False,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
@@ -1194,24 +1321,154 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
noise_pred_audio = noise_pred_audio.float()
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_pred_video_uncond, noise_pred_video_text = noise_pred_video.chunk(2)
|
||||
noise_pred_video = noise_pred_video_uncond + self.guidance_scale * (
|
||||
noise_pred_video_text - noise_pred_video_uncond
|
||||
noise_pred_video_uncond_text, noise_pred_video = noise_pred_video.chunk(2)
|
||||
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_video_uncond_text = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_text, i, self.scheduler
|
||||
)
|
||||
# Use delta formulation as it works more nicely with multiple guidance terms
|
||||
video_cfg_delta = (self.guidance_scale - 1) * (noise_pred_video - noise_pred_video_uncond_text)
|
||||
|
||||
noise_pred_audio_uncond_text, noise_pred_audio = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
noise_pred_audio_uncond_text = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_text, i, audio_scheduler
|
||||
)
|
||||
audio_cfg_delta = (self.audio_guidance_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_text
|
||||
)
|
||||
|
||||
noise_pred_audio_uncond, noise_pred_audio_text = noise_pred_audio.chunk(2)
|
||||
noise_pred_audio = noise_pred_audio_uncond + self.guidance_scale * (
|
||||
noise_pred_audio_text - noise_pred_audio_uncond
|
||||
# Get positive values from merged CFG inputs in case we need to do other DiT forward passes
|
||||
if self.do_spatio_temporal_guidance or self.do_modality_isolation_guidance:
|
||||
if i == 0:
|
||||
# Only split values that remain constant throughout the loop once
|
||||
video_prompt_embeds = connector_prompt_embeds.chunk(2, dim=0)[1]
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds.chunk(2, dim=0)[1]
|
||||
prompt_attn_mask = connector_attention_mask.chunk(2, dim=0)[1]
|
||||
|
||||
video_pos_ids = video_coords.chunk(2, dim=0)[0]
|
||||
audio_pos_ids = audio_coords.chunk(2, dim=0)[0]
|
||||
|
||||
# Split values that vary each denoising loop iteration
|
||||
timestep = timestep.chunk(2, dim=0)[0]
|
||||
video_timestep = video_timestep.chunk(2, dim=0)[0]
|
||||
else:
|
||||
video_cfg_delta = audio_cfg_delta = 0
|
||||
|
||||
video_prompt_embeds = connector_prompt_embeds
|
||||
audio_prompt_embeds = connector_audio_prompt_embeds
|
||||
prompt_attn_mask = connector_attention_mask
|
||||
|
||||
video_pos_ids = video_coords
|
||||
audio_pos_ids = audio_coords
|
||||
|
||||
noise_pred_video = self.convert_velocity_to_x0(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_audio = self.convert_velocity_to_x0(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
|
||||
if self.do_spatio_temporal_guidance:
|
||||
with self.transformer.cache_context("uncond_stg"):
|
||||
noise_pred_video_uncond_stg, noise_pred_audio_uncond_stg = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
isolate_modalities=False,
|
||||
# Use STG at given blocks to perturb model
|
||||
spatio_temporal_guidance_blocks=spatio_temporal_guidance_blocks,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_stg = noise_pred_video_uncond_stg.float()
|
||||
noise_pred_audio_uncond_stg = noise_pred_audio_uncond_stg.float()
|
||||
noise_pred_video_uncond_stg = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_stg, i, self.scheduler
|
||||
)
|
||||
noise_pred_audio_uncond_stg = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_stg, i, audio_scheduler
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0:
|
||||
# Based on 3.4. in https://huggingface.co/papers/2305.08891
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video, noise_pred_video_text, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio, noise_pred_audio_text, guidance_rescale=self.guidance_rescale
|
||||
video_stg_delta = self.stg_scale * (noise_pred_video - noise_pred_video_uncond_stg)
|
||||
audio_stg_delta = self.audio_stg_scale * (noise_pred_audio - noise_pred_audio_uncond_stg)
|
||||
else:
|
||||
video_stg_delta = audio_stg_delta = 0
|
||||
|
||||
if self.do_modality_isolation_guidance:
|
||||
with self.transformer.cache_context("uncond_modality"):
|
||||
noise_pred_video_uncond_modality, noise_pred_audio_uncond_modality = self.transformer(
|
||||
hidden_states=latents.to(dtype=prompt_embeds.dtype),
|
||||
audio_hidden_states=audio_latents.to(dtype=prompt_embeds.dtype),
|
||||
encoder_hidden_states=video_prompt_embeds,
|
||||
audio_encoder_hidden_states=audio_prompt_embeds,
|
||||
timestep=video_timestep,
|
||||
audio_timestep=timestep,
|
||||
sigma=timestep, # Used by LTX-2.3
|
||||
encoder_attention_mask=prompt_attn_mask,
|
||||
audio_encoder_attention_mask=prompt_attn_mask,
|
||||
num_frames=latent_num_frames,
|
||||
height=latent_height,
|
||||
width=latent_width,
|
||||
fps=frame_rate,
|
||||
audio_num_frames=audio_num_frames,
|
||||
video_coords=video_pos_ids,
|
||||
audio_coords=audio_pos_ids,
|
||||
# Turn off A2V and V2A cross attn to isolate video and audio modalities
|
||||
isolate_modalities=True,
|
||||
spatio_temporal_guidance_blocks=None,
|
||||
perturbation_mask=None,
|
||||
use_cross_timestep=use_cross_timestep,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)
|
||||
noise_pred_video_uncond_modality = noise_pred_video_uncond_modality.float()
|
||||
noise_pred_audio_uncond_modality = noise_pred_audio_uncond_modality.float()
|
||||
noise_pred_video_uncond_modality = self.convert_velocity_to_x0(
|
||||
latents, noise_pred_video_uncond_modality, i, self.scheduler
|
||||
)
|
||||
noise_pred_audio_uncond_modality = self.convert_velocity_to_x0(
|
||||
audio_latents, noise_pred_audio_uncond_modality, i, audio_scheduler
|
||||
)
|
||||
|
||||
video_modality_delta = (self.modality_scale - 1) * (
|
||||
noise_pred_video - noise_pred_video_uncond_modality
|
||||
)
|
||||
audio_modality_delta = (self.audio_modality_scale - 1) * (
|
||||
noise_pred_audio - noise_pred_audio_uncond_modality
|
||||
)
|
||||
else:
|
||||
video_modality_delta = audio_modality_delta = 0
|
||||
|
||||
# Now apply all guidance terms
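# Sketch of the combined x0-space update performed below (each delta is 0 when its guidance is disabled):
#   x0_guided = x0_text
#             + (guidance_scale - 1) * (x0_text - x0_uncond_text)      # CFG
#             + stg_scale * (x0_text - x0_uncond_stg)                  # STG
#             + (modality_scale - 1) * (x0_text - x0_uncond_modality)  # modality isolation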
|
||||
noise_pred_video_g = noise_pred_video + video_cfg_delta + video_stg_delta + video_modality_delta
|
||||
noise_pred_audio_g = noise_pred_audio + audio_cfg_delta + audio_stg_delta + audio_modality_delta
|
||||
|
||||
# Apply LTX-2.X guidance rescaling
|
||||
if self.guidance_rescale > 0:
|
||||
noise_pred_video = rescale_noise_cfg(
|
||||
noise_pred_video_g, noise_pred_video, guidance_rescale=self.guidance_rescale
|
||||
)
|
||||
else:
|
||||
noise_pred_video = noise_pred_video_g
|
||||
|
||||
if self.audio_guidance_rescale > 0:
|
||||
noise_pred_audio = rescale_noise_cfg(
|
||||
noise_pred_audio_g, noise_pred_audio, guidance_rescale=self.audio_guidance_rescale
|
||||
)
|
||||
else:
|
||||
noise_pred_audio = noise_pred_audio_g
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
noise_pred_video = self._unpack_latents(
|
||||
@@ -1231,6 +1488,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
|
||||
# Convert back to velocity for scheduler
|
||||
noise_pred_video = self.convert_x0_to_velocity(latents, noise_pred_video, i, self.scheduler)
|
||||
noise_pred_audio = self.convert_x0_to_velocity(audio_latents, noise_pred_audio, i, audio_scheduler)
|
||||
|
||||
noise_pred_video = noise_pred_video[:, :, 1:]
|
||||
noise_latents = latents[:, :, 1:]
|
||||
pred_latents = self.scheduler.step(noise_pred_video, t, noise_latents, return_dict=False)[0]
|
||||
@@ -1268,9 +1529,6 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
self.transformer_spatial_patch_size,
|
||||
self.transformer_temporal_patch_size,
|
||||
)
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
|
||||
audio_latents = self._denormalize_audio_latents(
|
||||
audio_latents, self.audio_vae.latents_mean, self.audio_vae.latents_std
|
||||
@@ -1278,6 +1536,9 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
audio_latents = self._unpack_audio_latents(audio_latents, audio_num_frames, num_mel_bins=latent_mel_bins)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
audio = audio_latents
|
||||
else:
|
||||
@@ -1300,6 +1561,10 @@ class LTX2ImageToVideoPipeline(DiffusionPipeline, FromSingleFileMixin, LTX2LoraL
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
|
||||
latents = latents.to(self.vae.dtype)
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
@@ -8,6 +8,209 @@ from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
def kaiser_sinc_filter1d(cutoff: float, half_width: float, kernel_size: int) -> torch.Tensor:
|
||||
"""
|
||||
Creates a Kaiser sinc kernel for low-pass filtering.
|
||||
|
||||
Args:
|
||||
cutoff (`float`):
|
||||
Normalized frequency cutoff (relative to the sampling rate). Must be between 0 and 0.5 (the Nyquist
|
||||
frequency).
|
||||
half_width (`float`):
|
||||
Used to determine the Kaiser window's beta parameter.
|
||||
kernel_size (`int`):
|
||||
Size of the Kaiser window (and ultimately the Kaiser sinc kernel).
|
||||
|
||||
Returns:
|
||||
`torch.Tensor` of shape `(kernel_size,)`:
|
||||
The Kaiser sinc kernel.
|
||||
"""
|
||||
delta_f = 4 * half_width
|
||||
half_size = kernel_size // 2
|
||||
amplitude = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
|
||||
if amplitude > 50.0:
|
||||
beta = 0.1102 * (amplitude - 8.7)
|
||||
elif amplitude >= 21.0:
|
||||
beta = 0.5842 * (amplitude - 21) ** 0.4 + 0.07886 * (amplitude - 21.0)
|
||||
else:
|
||||
beta = 0.0
|
||||
|
||||
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
|
||||
|
||||
even = kernel_size % 2 == 0
|
||||
time = torch.arange(-half_size, half_size) + 0.5 if even else torch.arange(kernel_size) - half_size
|
||||
|
||||
if cutoff == 0.0:
|
||||
filter = torch.zeros_like(time)
|
||||
else:
|
||||
time = 2 * cutoff * time
|
||||
sinc = torch.where(
|
||||
time == 0,
|
||||
torch.ones_like(time),
|
||||
torch.sin(math.pi * time) / math.pi / time,
|
||||
)
|
||||
filter = 2 * cutoff * window * sinc
|
||||
filter = filter / filter.sum()
|
||||
return filter
|
||||
|
||||
|
||||
class DownSample1d(nn.Module):
|
||||
"""1D low-pass filter for antialias downsampling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ratio: int = 2,
|
||||
kernel_size: int | None = None,
|
||||
use_padding: bool = True,
|
||||
padding_mode: str = "replicate",
|
||||
persistent: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = kernel_size or int(6 * ratio // 2) * 2
|
||||
self.pad_left = self.kernel_size // 2 + (self.kernel_size % 2) - 1
|
||||
self.pad_right = self.kernel_size // 2
|
||||
self.use_padding = use_padding
|
||||
self.padding_mode = padding_mode
|
||||
|
||||
cutoff = 0.5 / ratio
|
||||
half_width = 0.6 / ratio
|
||||
low_pass_filter = kaiser_sinc_filter1d(cutoff, half_width, self.kernel_size)
|
||||
self.register_buffer("filter", low_pass_filter.view(1, 1, self.kernel_size), persistent=persistent)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# x expected shape: [batch_size, num_channels, hidden_dim]
|
||||
num_channels = x.shape[1]
|
||||
if self.use_padding:
|
||||
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
|
||||
x_filtered = F.conv1d(x, self.filter.expand(num_channels, -1, -1), stride=self.ratio, groups=num_channels)
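# The low-pass kernel is shared across channels and applied depthwise (groups=num_channels);
# stride=ratio performs the decimation in the same convolution, so the signal is band-limited
# before being subsampled.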
|
||||
return x_filtered
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
    def __init__(
        self,
        ratio: int = 2,
        kernel_size: int | None = None,
        window_type: str = "kaiser",
        padding_mode: str = "replicate",
        persistent: bool = True,
    ):
        super().__init__()
        self.ratio = ratio
        self.padding_mode = padding_mode

        if window_type == "hann":
            rolloff = 0.99
            lowpass_filter_width = 6
            width = math.ceil(lowpass_filter_width / rolloff)
            self.kernel_size = 2 * width * ratio + 1
            self.pad = width
            self.pad_left = 2 * width * ratio
            self.pad_right = self.kernel_size - ratio

            time_axis = (torch.arange(self.kernel_size) / ratio - width) * rolloff
            time_clamped = time_axis.clamp(-lowpass_filter_width, lowpass_filter_width)
            window = torch.cos(time_clamped * math.pi / lowpass_filter_width / 2) ** 2
            sinc_filter = (torch.sinc(time_axis) * window * rolloff / ratio).view(1, 1, -1)
        else:
            # Kaiser sinc filter is BigVGAN default
            self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
            self.pad = self.kernel_size // ratio - 1
            self.pad_left = self.pad * self.ratio + (self.kernel_size - self.ratio) // 2
            self.pad_right = self.pad * self.ratio + (self.kernel_size - self.ratio + 1) // 2

            sinc_filter = kaiser_sinc_filter1d(
                cutoff=0.5 / ratio,
                half_width=0.6 / ratio,
                kernel_size=self.kernel_size,
            )

        self.register_buffer("filter", sinc_filter.view(1, 1, self.kernel_size), persistent=persistent)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x expected shape: [batch_size, num_channels, hidden_dim]
        num_channels = x.shape[1]
        x = F.pad(x, (self.pad, self.pad), mode=self.padding_mode)
        low_pass_filter = self.filter.to(dtype=x.dtype, device=x.device).expand(num_channels, -1, -1)
        x = self.ratio * F.conv_transpose1d(x, low_pass_filter, stride=self.ratio, groups=num_channels)
        return x[..., self.pad_left : -self.pad_right]

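# Illustrative usage sketch (not part of the diff): with the default Kaiser window,
# UpSample1d multiplies the time axis by `ratio` while keeping channels intact, thanks
# to the symmetric crop applied after the transposed convolution.
up = UpSample1d(ratio=2)
y = up(torch.randn(1, 2, 64))
assert y.shape == (1, 2, 128)
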
class AntiAliasAct1d(nn.Module):
    """
    Antialiasing activation for a 1D signal: upsamples, applies an activation (usually snakebeta), and then
    downsamples to avoid aliasing.
    """

    def __init__(
        self,
        act_fn: str | nn.Module,
        ratio: int = 2,
        kernel_size: int = 12,
        **kwargs,
    ):
        super().__init__()
        self.upsample = UpSample1d(ratio=ratio, kernel_size=kernel_size)
        if isinstance(act_fn, str):
            if act_fn == "snakebeta":
                act_fn = SnakeBeta(**kwargs)
            elif act_fn == "snake":
                act_fn = SnakeBeta(**kwargs)
            else:
                act_fn = nn.LeakyReLU(**kwargs)
        self.act = act_fn
        self.downsample = DownSample1d(ratio=ratio, kernel_size=kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.upsample(x)
        x = self.act(x)
        x = self.downsample(x)
        return x

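# Illustrative sketch (not part of the diff): the upsample -> activation -> downsample
# round trip preserves the input shape, since UpSample1d multiplies the length by `ratio`
# and DownSample1d divides it again with the same kernel size.
aa_act = AntiAliasAct1d(nn.LeakyReLU(0.1), ratio=2, kernel_size=12)
out = aa_act(torch.randn(1, 4, 64))
assert out.shape == (1, 4, 64)
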
class SnakeBeta(nn.Module):
    """
    Implements the Snake and SnakeBeta activations, which help with learning periodic patterns.
    """

    def __init__(
        self,
        channels: int,
        alpha: float = 1.0,
        eps: float = 1e-9,
        trainable_params: bool = True,
        logscale: bool = True,
        use_beta: bool = True,
    ):
        super().__init__()
        self.eps = eps
        self.logscale = logscale
        self.use_beta = use_beta

        self.alpha = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
        self.alpha.requires_grad = trainable_params
        if use_beta:
            self.beta = nn.Parameter(torch.zeros(channels) if self.logscale else torch.ones(channels) * alpha)
            self.beta.requires_grad = trainable_params

    def forward(self, hidden_states: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
        broadcast_shape = [1] * hidden_states.ndim
        broadcast_shape[channel_dim] = -1
        alpha = self.alpha.view(broadcast_shape)
        if self.use_beta:
            beta = self.beta.view(broadcast_shape)

        if self.logscale:
            alpha = torch.exp(alpha)
            if self.use_beta:
                beta = torch.exp(beta)

        amplitude = beta if self.use_beta else alpha
        hidden_states = hidden_states + (1.0 / (amplitude + self.eps)) * torch.sin(hidden_states * alpha).pow(2)
        return hidden_states

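# Illustrative sketch (not part of the diff): at initialization with logscale=True, both
# alpha and beta are exp(0) == 1, so SnakeBeta reduces to the plain snake activation
# x + sin(x) ** 2 (up to the eps stabilizer in the denominator).
snake = SnakeBeta(channels=3)
x = torch.randn(2, 3, 16)
assert torch.allclose(snake(x), x + torch.sin(x) ** 2, atol=1e-6)
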
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -15,12 +218,15 @@ class ResBlock(nn.Module):
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilations: tuple[int, ...] = (1, 3, 5),
|
||||
act_fn: str = "leaky_relu",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = False,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
padding_mode: str = "same",
|
||||
):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
@@ -28,6 +234,18 @@ class ResBlock(nn.Module):
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
self.acts1 = nn.ModuleList()
|
||||
for _ in range(len(self.convs1)):
|
||||
if act_fn == "snakebeta":
|
||||
act = SnakeBeta(channels, use_beta=True)
|
||||
elif act_fn == "snake":
|
||||
act = SnakeBeta(channels, use_beta=False)
|
||||
else:
|
||||
act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
|
||||
|
||||
if antialias:
|
||||
act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
|
||||
self.acts1.append(act)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
@@ -35,12 +253,24 @@ class ResBlock(nn.Module):
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
self.acts2 = nn.ModuleList()
for _ in range(len(self.convs2)):
    if act_fn == "snakebeta":
        act = SnakeBeta(channels, use_beta=True)
    elif act_fn == "snake":
        act = SnakeBeta(channels, use_beta=False)
    else:
        # Presumed fix: assign to `act` (not `act_fn`), mirroring the acts1 loop above,
        # so a fresh LeakyReLU is what gets appended below instead of a stale module.
        act = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)

    if antialias:
        act = AntiAliasAct1d(act, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
    self.acts2.append(act)
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
for act1, conv1, act2, conv2 in zip(self.acts1, self.convs1, self.acts2, self.convs2):
|
||||
xt = act1(x)
|
||||
xt = conv1(xt)
|
||||
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
|
||||
xt = act2(xt)
|
||||
xt = conv2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
@@ -61,7 +291,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
upsample_factors: list[int] = [6, 5, 2, 2, 2],
|
||||
resnet_kernel_sizes: list[int] = [3, 7, 11],
|
||||
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
act_fn: str = "leaky_relu",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = False,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
final_act_fn: str | None = "tanh", # tanh, clamp, None
|
||||
final_bias: bool = True,
|
||||
output_sampling_rate: int = 24000,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -69,7 +305,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
self.resnets_per_upsample = len(resnet_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.total_upsample_factor = math.prod(upsample_factors)
|
||||
self.act_fn = act_fn
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
self.final_act_fn = final_act_fn
|
||||
|
||||
if self.num_upsample_layers != len(upsample_factors):
|
||||
raise ValueError(
|
||||
@@ -83,6 +321,13 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
|
||||
)
|
||||
|
||||
supported_act_fns = ["snakebeta", "snake", "leaky_relu"]
|
||||
if self.act_fn not in supported_act_fns:
|
||||
raise ValueError(
|
||||
f"Unsupported activation function: {self.act_fn}. Currently supported values of `act_fn` are "
|
||||
f"{supported_act_fns}."
|
||||
)
|
||||
|
||||
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsamplers = nn.ModuleList()
|
||||
@@ -103,15 +348,27 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
|
||||
self.resnets.append(
|
||||
ResBlock(
|
||||
output_channels,
|
||||
kernel_size,
|
||||
channels=output_channels,
|
||||
kernel_size=kernel_size,
|
||||
dilations=dilations,
|
||||
act_fn=act_fn,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
antialias=antialias,
|
||||
antialias_ratio=antialias_ratio,
|
||||
antialias_kernel_size=antialias_kernel_size,
|
||||
)
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
|
||||
if act_fn == "snakebeta" or act_fn == "snake":
|
||||
# Always use antialiasing
|
||||
act_out = SnakeBeta(channels=output_channels, use_beta=True)
|
||||
self.act_out = AntiAliasAct1d(act_out, ratio=antialias_ratio, kernel_size=antialias_kernel_size)
|
||||
elif act_fn == "leaky_relu":
|
||||
# NOTE: does NOT use self.negative_slope, following the original code
|
||||
self.act_out = nn.LeakyReLU()
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3, bias=final_bias)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
@@ -139,7 +396,9 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for i in range(self.num_upsample_layers):
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
if self.act_fn == "leaky_relu":
|
||||
# Other activations are inside each upsampling block
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = self.upsamplers[i](hidden_states)
|
||||
|
||||
# Run all resnets in parallel on hidden_states
|
||||
@@ -149,10 +408,190 @@ class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
|
||||
hidden_states = torch.mean(resnet_outputs, dim=0)
|
||||
|
||||
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
|
||||
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
|
||||
hidden_states = self.act_out(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
if self.final_act_fn == "tanh":
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
elif self.final_act_fn == "clamp":
|
||||
hidden_states = torch.clamp(hidden_states, -1, 1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CausalSTFT(nn.Module):
    """
    Performs a causal short-time Fourier transform (STFT) using causal Hann windows on a waveform. The DFT bases
    multiplied by the Hann windows are pre-calculated and stored as buffers. For exact parity with training, the exact
    buffers should be loaded from the checkpoint in bfloat16.
    """

    def __init__(self, filter_length: int = 512, hop_length: int = 80, window_length: int = 512):
        super().__init__()
        self.hop_length = hop_length
        self.window_length = window_length
        n_freqs = filter_length // 2 + 1

        self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)
        self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length), persistent=True)

    def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        if waveform.ndim == 2:
            waveform = waveform.unsqueeze(1)  # [B, num_channels, num_samples]

        left_pad = max(0, self.window_length - self.hop_length)  # causal: left-only
        waveform = F.pad(waveform, (left_pad, 0))

        spec = F.conv1d(waveform, self.forward_basis, stride=self.hop_length, padding=0)
        n_freqs = spec.shape[1] // 2
        real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
        magnitude = torch.sqrt(real**2 + imag**2)
        phase = torch.atan2(imag.float(), real.float()).to(dtype=real.dtype)
        return magnitude, phase

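# Illustrative sketch (not part of the diff): with left-only padding of
# (window_length - hop_length) samples, the number of STFT frames is exactly
# num_samples // hop_length whenever num_samples is a multiple of hop_length.
# This is why LTX2VocoderWithBWE.forward pads the waveform before re-analysis.
stft = CausalSTFT(filter_length=512, hop_length=80, window_length=512)
mag, _ = stft(torch.randn(1, 1, 80 * 100))
assert mag.shape[-1] == 100
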
class MelSTFT(nn.Module):
    """
    Calculates a causal log-mel spectrogram from a waveform. Uses a pre-calculated mel filterbank, which should be
    loaded from the checkpoint in bfloat16.
    """

    def __init__(
        self,
        filter_length: int = 512,
        hop_length: int = 80,
        window_length: int = 512,
        num_mel_channels: int = 64,
    ):
        super().__init__()
        self.stft_fn = CausalSTFT(filter_length, hop_length, window_length)

        num_freqs = filter_length // 2 + 1
        self.register_buffer("mel_basis", torch.zeros(num_mel_channels, num_freqs), persistent=True)

    def forward(self, waveform: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        magnitude, phase = self.stft_fn(waveform)
        energy = torch.norm(magnitude, dim=1)
        mel = torch.matmul(self.mel_basis.to(magnitude.dtype), magnitude)
        log_mel = torch.log(torch.clamp(mel, min=1e-5))
        return log_mel, magnitude, phase, energy

class LTX2VocoderWithBWE(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
LTX-2.X vocoder with bandwidth extension (BWE) upsampling. The vocoder and the BWE module run in sequence, with the
|
||||
BWE module upsampling the vocoder output waveform to a higher sampling rate. The BWE module itself has the same
|
||||
architecture as the original vocoder.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
hidden_channels: int = 1536,
|
||||
out_channels: int = 2,
|
||||
upsample_kernel_sizes: list[int] = [11, 4, 4, 4, 4, 4],
|
||||
upsample_factors: list[int] = [5, 2, 2, 2, 2, 2],
|
||||
resnet_kernel_sizes: list[int] = [3, 7, 11],
|
||||
resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
act_fn: str = "snakebeta",
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
antialias: bool = True,
|
||||
antialias_ratio: int = 2,
|
||||
antialias_kernel_size: int = 12,
|
||||
final_act_fn: str | None = None,
|
||||
final_bias: bool = False,
|
||||
bwe_in_channels: int = 128,
|
||||
bwe_hidden_channels: int = 512,
|
||||
bwe_out_channels: int = 2,
|
||||
bwe_upsample_kernel_sizes: list[int] = [12, 11, 4, 4, 4],
|
||||
bwe_upsample_factors: list[int] = [6, 5, 2, 2, 2],
|
||||
bwe_resnet_kernel_sizes: list[int] = [3, 7, 11],
|
||||
bwe_resnet_dilations: list[list[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
bwe_act_fn: str = "snakebeta",
|
||||
bwe_leaky_relu_negative_slope: float = 0.1,
|
||||
bwe_antialias: bool = True,
|
||||
bwe_antialias_ratio: int = 2,
|
||||
bwe_antialias_kernel_size: int = 12,
|
||||
bwe_final_act_fn: str | None = None,
|
||||
bwe_final_bias: bool = False,
|
||||
filter_length: int = 512,
|
||||
hop_length: int = 80,
|
||||
window_length: int = 512,
|
||||
num_mel_channels: int = 64,
|
||||
input_sampling_rate: int = 16000,
|
||||
output_sampling_rate: int = 48000,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.vocoder = LTX2Vocoder(
|
||||
in_channels=in_channels,
|
||||
hidden_channels=hidden_channels,
|
||||
out_channels=out_channels,
|
||||
upsample_kernel_sizes=upsample_kernel_sizes,
|
||||
upsample_factors=upsample_factors,
|
||||
resnet_kernel_sizes=resnet_kernel_sizes,
|
||||
resnet_dilations=resnet_dilations,
|
||||
act_fn=act_fn,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
antialias=antialias,
|
||||
antialias_ratio=antialias_ratio,
|
||||
antialias_kernel_size=antialias_kernel_size,
|
||||
final_act_fn=final_act_fn,
|
||||
final_bias=final_bias,
|
||||
output_sampling_rate=input_sampling_rate,
|
||||
)
|
||||
self.bwe_generator = LTX2Vocoder(
|
||||
in_channels=bwe_in_channels,
|
||||
hidden_channels=bwe_hidden_channels,
|
||||
out_channels=bwe_out_channels,
|
||||
upsample_kernel_sizes=bwe_upsample_kernel_sizes,
|
||||
upsample_factors=bwe_upsample_factors,
|
||||
resnet_kernel_sizes=bwe_resnet_kernel_sizes,
|
||||
resnet_dilations=bwe_resnet_dilations,
|
||||
act_fn=bwe_act_fn,
|
||||
leaky_relu_negative_slope=bwe_leaky_relu_negative_slope,
|
||||
antialias=bwe_antialias,
|
||||
antialias_ratio=bwe_antialias_ratio,
|
||||
antialias_kernel_size=bwe_antialias_kernel_size,
|
||||
final_act_fn=bwe_final_act_fn,
|
||||
final_bias=bwe_final_bias,
|
||||
output_sampling_rate=output_sampling_rate,
|
||||
)
|
||||
|
||||
self.mel_stft = MelSTFT(
|
||||
filter_length=filter_length,
|
||||
hop_length=hop_length,
|
||||
window_length=window_length,
|
||||
num_mel_channels=num_mel_channels,
|
||||
)
|
||||
|
||||
self.resampler = UpSample1d(
|
||||
ratio=output_sampling_rate // input_sampling_rate,
|
||||
window_type="hann",
|
||||
persistent=False,
|
||||
)
|
||||
|
||||
    def forward(self, mel_spec: torch.Tensor) -> torch.Tensor:
        # 1. Run stage 1 vocoder to get low sampling rate waveform
        x = self.vocoder(mel_spec)
        batch_size, num_channels, num_samples = x.shape

        # Pad to exact multiple of hop_length for exact mel frame count
        remainder = num_samples % self.config.hop_length
        if remainder != 0:
            # Presumed fix: read hop_length from the registered config; the module defines no `self.hop_length`.
            x = F.pad(x, (0, self.config.hop_length - remainder))

        # 2. Compute mel spectrogram on vocoder output
        mel, _, _, _ = self.mel_stft(x.flatten(0, 1))
        mel = mel.unflatten(0, (-1, num_channels))

        # 3. Run bandwidth extender (BWE) on new mel spectrogram
        mel_for_bwe = mel.transpose(2, 3)  # [B, C, num_mel_bins, num_frames] --> [B, C, num_frames, num_mel_bins]
        residual = self.bwe_generator(mel_for_bwe)

        # 4. Residual connection with resampler
        skip = self.resampler(x)
        waveform = torch.clamp(residual + skip, -1, 1)
        output_samples = num_samples * self.config.output_sampling_rate // self.config.input_sampling_rate
        waveform = waveform[..., :output_samples]
        return waveform

@@ -35,6 +35,8 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
|
||||
# uncondition for scaling
|
||||
self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
|
||||
|
||||
self.post_init()
|
||||
|
||||
def forward(self, pixel_values, return_uncond_vector=False):
|
||||
clip_output = self.model(pixel_values=pixel_values)
|
||||
latent_states = clip_output.pooler_output
|
||||
|
||||
@@ -24,7 +24,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...loaders import SanaLoraLoaderMixin
|
||||
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...schedulers import DPMSolverMultistepScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
@@ -194,7 +194,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
The tokenizer used to tokenize the prompt.
|
||||
text_encoder ([`Gemma2PreTrainedModel`]):
|
||||
Text encoder model to encode the input prompts.
|
||||
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
|
||||
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer ([`SanaVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
@@ -213,7 +213,7 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
self,
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC | AutoencoderKLWan,
|
||||
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
|
||||
transformer: SanaVideoTransformer3DModel,
|
||||
scheduler: DPMSolverMultistepScheduler,
|
||||
):
|
||||
@@ -223,8 +223,19 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
if getattr(self, "vae", None):
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
|
||||
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
|
||||
self.vae_scale_factor = self.vae_scale_factor_spatial
|
||||
|
||||
@@ -985,14 +996,21 @@ class SanaVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if is_torch_version(">=", "2.5.0")
|
||||
else torch_accelerator_module.OutOfMemoryError
|
||||
)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
latents_mean = self.vae.latents_mean
|
||||
latents_std = self.vae.latents_std
|
||||
z_dim = self.vae.config.latent_channels
|
||||
elif isinstance(self.vae, AutoencoderKLWan):
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std)
|
||||
z_dim = self.vae.config.z_dim
|
||||
else:
|
||||
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
z_dim = latents.shape[1]
|
||||
|
||||
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
try:
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -26,7 +26,7 @@ from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFa
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...loaders import SanaLoraLoaderMixin
|
||||
from ...models import AutoencoderDC, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...models import AutoencoderDC, AutoencoderKLLTX2Video, AutoencoderKLWan, SanaVideoTransformer3DModel
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import (
|
||||
BACKENDS_MAPPING,
|
||||
@@ -184,7 +184,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
The tokenizer used to tokenize the prompt.
|
||||
text_encoder ([`Gemma2PreTrainedModel`]):
|
||||
Text encoder model to encode the input prompts.
|
||||
vae ([`AutoencoderKLWan` or `AutoencoderDCAEV`]):
|
||||
vae ([`AutoencoderKLWan`, `AutoencoderDC`, or `AutoencoderKLLTX2Video`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
transformer ([`SanaVideoTransformer3DModel`]):
|
||||
Conditional Transformer to denoise the input latents.
|
||||
@@ -203,7 +203,7 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
self,
|
||||
tokenizer: GemmaTokenizer | GemmaTokenizerFast,
|
||||
text_encoder: Gemma2PreTrainedModel,
|
||||
vae: AutoencoderDC | AutoencoderKLWan,
|
||||
vae: AutoencoderDC | AutoencoderKLLTX2Video | AutoencoderKLWan,
|
||||
transformer: SanaVideoTransformer3DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
):
|
||||
@@ -213,8 +213,19 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
|
||||
if getattr(self, "vae", None):
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
self.vae_scale_factor_temporal = self.vae.config.temporal_compression_ratio
|
||||
self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio
|
||||
elif isinstance(self.vae, (AutoencoderDC, AutoencoderKLWan)):
|
||||
self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal
|
||||
self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
else:
|
||||
self.vae_scale_factor_temporal = 4
|
||||
self.vae_scale_factor_spatial = 8
|
||||
|
||||
self.vae_scale_factor = self.vae_scale_factor_spatial
|
||||
|
||||
@@ -687,14 +698,18 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
image_latents = retrieve_latents(self.vae.encode(image), sample_mode="argmax")
|
||||
image_latents = image_latents.repeat(batch_size, 1, 1, 1, 1)
|
||||
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, -1, 1, 1, 1)
|
||||
.to(image_latents.device, image_latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, -1, 1, 1, 1).to(
|
||||
image_latents.device, image_latents.dtype
|
||||
)
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
_latents_mean = self.vae.latents_mean
|
||||
_latents_std = self.vae.latents_std
|
||||
elif isinstance(self.vae, AutoencoderKLWan):
|
||||
_latents_mean = torch.tensor(self.vae.config.latents_mean)
|
||||
_latents_std = torch.tensor(self.vae.config.latents_std)
|
||||
else:
|
||||
_latents_mean = torch.zeros(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
|
||||
_latents_std = torch.ones(image_latents.shape[1], device=image_latents.device, dtype=image_latents.dtype)
|
||||
|
||||
latents_mean = _latents_mean.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
latents_std = 1.0 / _latents_std.view(1, -1, 1, 1, 1).to(image_latents.device, image_latents.dtype)
|
||||
image_latents = (image_latents - latents_mean) * latents_std
|
||||
|
||||
latents[:, :, 0:1] = image_latents.to(dtype)
|
||||
@@ -1034,14 +1049,21 @@ class SanaImageToVideoPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if is_torch_version(">=", "2.5.0")
|
||||
else torch_accelerator_module.OutOfMemoryError
|
||||
)
|
||||
latents_mean = (
|
||||
torch.tensor(self.vae.config.latents_mean)
|
||||
.view(1, self.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
|
||||
latents.device, latents.dtype
|
||||
)
|
||||
if isinstance(self.vae, AutoencoderKLLTX2Video):
|
||||
latents_mean = self.vae.latents_mean
|
||||
latents_std = self.vae.latents_std
|
||||
z_dim = self.vae.config.latent_channels
|
||||
elif isinstance(self.vae, AutoencoderKLWan):
|
||||
latents_mean = torch.tensor(self.vae.config.latents_mean)
|
||||
latents_std = torch.tensor(self.vae.config.latents_std)
|
||||
z_dim = self.vae.config.z_dim
|
||||
else:
|
||||
latents_mean = torch.zeros(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
latents_std = torch.ones(latents.shape[1], device=latents.device, dtype=latents.dtype)
|
||||
z_dim = latents.shape[1]
|
||||
|
||||
latents_mean = latents_mean.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = 1.0 / latents_std.view(1, z_dim, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
try:
|
||||
video = self.vae.decode(latents, return_dict=False)[0]
|
||||
|
||||
@@ -274,10 +274,14 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
`torch.Tensor`:
A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
"""
one_minus_z = 1 - t
scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
stretched_t = 1 - (one_minus_z / scale_factor)
return stretched_t
# Compute in float32 (matching reference ltx_core scheduler) to avoid
# float64 intermediates from numpy scalar / Python float promotion.
is_numpy = isinstance(t, np.ndarray)
t_tensor = torch.as_tensor(t, dtype=torch.float32)
one_minus_z = 1.0 - t_tensor
scale_factor = one_minus_z[-1] / (1.0 - self.config.shift_terminal)
stretched_t = 1.0 - (one_minus_z / scale_factor)
return stretched_t.numpy() if is_numpy else stretched_t
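# Illustrative sketch (not part of the diff; assumes this file's torch import): the
# terminal stretch rescales 1 - t so the final timestep lands exactly on `shift_terminal`
# while leaving the first timestep untouched, and the float32 path above keeps numpy
# inputs from promoting to float64.
def _stretch_to_terminal_demo(t, shift_terminal):
    one_minus_z = 1.0 - t
    scale_factor = one_minus_z[-1] / (1.0 - shift_terminal)
    return 1.0 - one_minus_z / scale_factor

t_demo = torch.linspace(1.0, 0.05, steps=10, dtype=torch.float32)
out_demo = _stretch_to_terminal_demo(t_demo, shift_terminal=0.1)
assert torch.isclose(out_demo[-1], torch.tensor(0.1))
assert out_demo.dtype == torch.float32
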
|
||||
def set_timesteps(
|
||||
self,
|
||||
@@ -510,7 +514,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
noise = torch.randn_like(sample)
|
||||
prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
|
||||
else:
|
||||
prev_sample = sample + dt * model_output
|
||||
prev_sample = sample + model_output.to(sample.dtype) * dt
|
||||
|
||||
# upon completion increase step index by one
|
||||
self._step_index += 1
|
||||
@@ -646,7 +650,12 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
return sigmas

def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
# Compute in float32 (matching reference ltx_core scheduler) to avoid
# float64 intermediate precision from math.exp() + numpy promotion.
t_tensor = torch.as_tensor(t, dtype=torch.float32)
exp_mu = math.exp(mu)
result = exp_mu / (exp_mu + (1 / t_tensor - 1) ** sigma)
return result.numpy() if isinstance(t, np.ndarray) else result

def _time_shift_linear(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
return mu / (mu + (1 / t - 1) ** sigma)
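# Illustrative sketch (not part of the diff; assumes this file's math/torch imports):
# the exponential shift is a logistic reparameterization of t. At mu = 0 and sigma = 1
# it is the identity, and larger mu pushes timesteps toward 1.
t_identity = torch.linspace(0.1, 0.9, steps=5, dtype=torch.float32)
shifted_identity = math.exp(0.0) / (math.exp(0.0) + (1 / t_identity - 1) ** 1.0)
assert torch.allclose(shifted_identity, t_identity, atol=1e-6)
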
|
||||
@@ -19,11 +19,16 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import os
|
||||
from typing import Callable, ParamSpec, TypeVar
|
||||
|
||||
from . import logging
|
||||
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch.fft import fftn, fftshift, ifftn, ifftshift
|
||||
@@ -333,5 +338,23 @@ def disable_full_determinism():
    torch.use_deterministic_algorithms(False)


@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
    def outer_wrapper(fn: Callable[P, T]):
        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
        if is_torch_version("<", "2.7.0"):
            return cached

        @functools.wraps(fn)
        def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
            if torch.compiler.is_exporting():
                return fn(*args, **kwargs)
            return cached(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper


if is_torch_available():
    torch_device = get_device()

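# Illustrative usage sketch (not part of the diff): outside of torch.export, the decorated
# function behaves like a functools.lru_cache'd function; while exporting, it is called
# directly so no Python-level cache state is baked into the exported graph.
# `cached_arange` is a hypothetical example function, not part of diffusers.
@lru_cache_unless_export(maxsize=8)
def cached_arange(n: int) -> "torch.Tensor":
    return torch.arange(n)

first = cached_arange(4)
second = cached_arange(4)
assert first is second  # served from the cache on the second call (when not exporting)
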
|
||||
@@ -28,7 +28,6 @@ from diffusers.utils.import_utils import is_peft_available
|
||||
|
||||
from ..testing_utils import (
|
||||
floats_tensor,
|
||||
is_flaky,
|
||||
require_peft_backend,
|
||||
require_peft_version_greater,
|
||||
skip_mps,
|
||||
@@ -46,7 +45,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
|
||||
|
||||
@require_peft_backend
|
||||
@skip_mps
|
||||
@is_flaky(max_attempts=10, description="very flaky class")
|
||||
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
pipeline_class = WanVACEPipeline
|
||||
scheduler_cls = FlowMatchEulerDiscreteScheduler
|
||||
@@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
|
||||
"base_dim": 3,
|
||||
"z_dim": 4,
|
||||
"dim_mult": [1, 1, 1, 1],
|
||||
"latents_mean": torch.randn(4).numpy().tolist(),
|
||||
"latents_std": torch.randn(4).numpy().tolist(),
|
||||
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
|
||||
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
|
||||
"num_res_blocks": 1,
|
||||
"temperal_downsample": [False, True, True],
|
||||
}
|
||||
|
||||
@@ -41,7 +41,6 @@ from ..testing_utils import (
|
||||
ModelOptCompileTesterMixin,
|
||||
ModelOptTesterMixin,
|
||||
ModelTesterMixin,
|
||||
PyramidAttentionBroadcastTesterMixin,
|
||||
QuantoCompileTesterMixin,
|
||||
QuantoTesterMixin,
|
||||
SingleFileTesterMixin,
|
||||
@@ -219,6 +218,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
|
||||
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"FluxTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux Transformer."""
|
||||
@@ -412,10 +415,6 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
|
||||
"""BitsAndBytes + compile tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
|
||||
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
|
||||
|
||||
|
||||
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
|
||||
"""FirstBlockCache tests for Flux Transformer."""
|
||||
|
||||
|
||||
@@ -13,48 +13,95 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import Flux2Transformer2DModel, attention_backend
|
||||
from diffusers import Flux2Transformer2DModel
|
||||
from diffusers.models.transformers.transformer_flux2 import (
|
||||
Flux2KVAttnProcessor,
|
||||
Flux2KVCache,
|
||||
Flux2KVLayerCache,
|
||||
Flux2KVParallelSelfAttnProcessor,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
from ..testing_utils import (
|
||||
AttentionTesterMixin,
|
||||
BaseModelTesterConfig,
|
||||
BitsAndBytesTesterMixin,
|
||||
ContextParallelTesterMixin,
|
||||
GGUFCompileTesterMixin,
|
||||
GGUFTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
MemoryTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchAoCompileTesterMixin,
|
||||
TorchAoTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = Flux2Transformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.7, 0.6, 0.6]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return Flux2Transformer2DModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
def prepare_dummy_input(self, height=4, width=4):
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
"num_single_layers": 1,
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
@@ -82,8 +129,286 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
|
||||
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
|
||||
"""Memory optimization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
|
||||
"""Training tests for Flux2 Transformer."""
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"Flux2Transformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
|
||||
"""Attention processor tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
|
||||
"""Context Parallel inference tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
|
||||
"""LoRA adapter tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
"""LoRA hot-swapping tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for LoRA hotswap tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
"""Override to support dynamic height/width for compilation tests."""
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
|
||||
"""BitsAndBytes quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
|
||||
"""TorchAO quantization tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
|
||||
"""GGUF quantization tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin):
|
||||
"""TorchAO + compile tests for Flux2 Transformer."""
|
||||
|
||||
|
||||
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
|
||||
"""GGUF + compile tests for Flux2 Transformer."""
|
||||
|
||||
@property
|
||||
def gguf_filename(self):
|
||||
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
|
||||
|
||||
@property
|
||||
def torch_dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
def get_dummy_inputs(self):
|
||||
"""Override to provide inputs matching the real FLUX2 model dimensions.
|
||||
|
||||
Flux2 defaults: in_channels=128, joint_attention_dim=15360
|
||||
"""
|
||||
batch_size = 1
|
||||
height = 64
|
||||
width = 64
|
||||
sequence_length = 512
|
||||
|
||||
hidden_states = randn_tensor(
|
||||
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
|
||||
)
|
||||
|
||||
# Flux2 uses 4D image/text IDs (t, h, w, l)
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
|
||||
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
|
||||
num_ref_tokens = 4
|
||||
|
||||
@property
|
||||
def model_class(self):
|
||||
return Flux2Transformer2DModel
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple[int, int]:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.7, 0.6, 0.6]
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
|
||||
@property
|
||||
def uses_custom_attn_processor(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict[str, int | list[int]]:
|
||||
return {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -91,72 +416,210 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
"attention_head_dim": 16,
|
||||
"num_attention_heads": 2,
|
||||
"joint_attention_dim": 32,
|
||||
"timestep_guidance_channels": 256, # Hardcoded in original code
|
||||
"timestep_guidance_channels": 256,
|
||||
"axes_dims_rope": [4, 4, 4, 4],
|
||||
}
|
||||
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
num_ref_tokens = self.num_ref_tokens
|
||||
|
||||
# TODO (Daniel, Sayak): We can remove this test.
|
||||
def test_flux2_consistency(self, seed=0):
|
||||
torch.manual_seed(seed)
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
ref_hidden_states = randn_tensor(
|
||||
(batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
img_hidden_states = randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
)
|
||||
hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
model = self.model_class(**init_dict)
|
||||
# state_dict = model.state_dict()
|
||||
# for key, param in state_dict.items():
|
||||
# print(f"{key} | {param.shape}")
|
||||
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
|
||||
ref_t_coords = torch.arange(1)
|
||||
ref_h_coords = torch.arange(num_ref_tokens)
|
||||
ref_w_coords = torch.arange(1)
|
||||
ref_l_coords = torch.arange(1)
|
||||
ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords)
|
||||
ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
t_coords = torch.arange(1)
|
||||
h_coords = torch.arange(height)
|
||||
w_coords = torch.arange(width)
|
||||
l_coords = torch.arange(1)
|
||||
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
|
||||
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
image_ids = torch.cat([ref_ids, image_ids], dim=1)
|
||||
|
||||
text_t_coords = torch.arange(1)
|
||||
text_h_coords = torch.arange(1)
|
||||
text_w_coords = torch.arange(1)
|
||||
text_l_coords = torch.arange(sequence_length)
|
||||
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
|
||||
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
|
||||
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig):
|
||||
"""KV cache tests for Flux2 Transformer."""
|
||||
|
||||
def test_kv_layer_cache_store_and_get(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
k = torch.randn(1, 4, 2, 16)
|
||||
v = torch.randn(1, 4, 2, 16)
|
||||
cache.store(k, v)
|
||||
k_out, v_out = cache.get()
|
||||
assert torch.equal(k, k_out)
|
||||
assert torch.equal(v, v_out)
|
||||
|
||||
def test_kv_layer_cache_get_before_store_raises(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
try:
|
||||
cache.get()
|
||||
assert False, "Expected RuntimeError"
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
def test_kv_layer_cache_clear(self):
|
||||
cache = Flux2KVLayerCache()
|
||||
cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
|
||||
cache.clear()
|
||||
assert cache.k_ref is None
|
||||
assert cache.v_ref is None
|
||||
|
||||
def test_kv_cache_structure(self):
|
||||
num_double = 3
|
||||
num_single = 2
|
||||
cache = Flux2KVCache(num_double, num_single)
|
||||
assert len(cache.double_block_caches) == num_double
|
||||
assert len(cache.single_block_caches) == num_single
|
||||
assert cache.num_ref_tokens == 0
|
||||
|
||||
for i in range(num_double):
|
||||
assert isinstance(cache.get_double(i), Flux2KVLayerCache)
|
||||
for i in range(num_single):
|
||||
assert isinstance(cache.get_single(i), Flux2KVLayerCache)
|
||||
|
||||
def test_kv_cache_clear(self):
|
||||
cache = Flux2KVCache(2, 1)
|
||||
cache.num_ref_tokens = 4
|
||||
cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
|
||||
cache.clear()
|
||||
assert cache.num_ref_tokens == 0
|
||||
assert cache.get_double(0).k_ref is None
|
||||
|
||||
def _set_kv_attn_processors(self, model):
|
||||
        for block in model.transformer_blocks:
            block.attn.set_processor(Flux2KVAttnProcessor())
        for block in model.single_transformer_blocks:
            block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())

    @torch.no_grad()
    def test_extract_mode_returns_cache(self):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()
        self._set_kv_attn_processors(model)

        output = model(
            **self.get_dummy_inputs(),
            kv_cache_mode="extract",
            num_ref_tokens=self.num_ref_tokens,
            ref_fixed_timestep=0.0,
        )

        assert output.kv_cache is not None
        assert isinstance(output.kv_cache, Flux2KVCache)
        assert output.kv_cache.num_ref_tokens == self.num_ref_tokens

        for layer_cache in output.kv_cache.double_block_caches:
            assert layer_cache.k_ref is not None
            assert layer_cache.v_ref is not None

        for layer_cache in output.kv_cache.single_block_caches:
            assert layer_cache.k_ref is not None
            assert layer_cache.v_ref is not None

    @torch.no_grad()
    def test_extract_mode_output_shape(self):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        height, width = 4, 4
        output = model(
            **self.get_dummy_inputs(height=height, width=width),
            kv_cache_mode="extract",
            num_ref_tokens=self.num_ref_tokens,
            ref_fixed_timestep=0.0,
        )

        assert output.sample.shape == (1, height * width, 4)

    @torch.no_grad()
    def test_cached_mode_uses_cache(self):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        height, width = 4, 4
        extract_output = model(
            **self.get_dummy_inputs(height=height, width=width),
            kv_cache_mode="extract",
            num_ref_tokens=self.num_ref_tokens,
            ref_fixed_timestep=0.0,
        )

        base_config = Flux2TransformerTesterConfig()
        cached_inputs = base_config.get_dummy_inputs(height=height, width=width)
        cached_output = model(
            **cached_inputs,
            kv_cache=extract_output.kv_cache,
            kv_cache_mode="cached",
        )

        assert cached_output.sample.shape == (1, height * width, 4)
        assert cached_output.kv_cache is None

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"Flux2Transformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

    @torch.no_grad()
    def test_extract_return_dict_false(self):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        output = model(
            **self.get_dummy_inputs(),
            kv_cache_mode="extract",
            num_ref_tokens=self.num_ref_tokens,
            ref_fixed_timestep=0.0,
            return_dict=False,
        )

        assert isinstance(output, tuple)
        assert len(output) == 2
        assert isinstance(output[1], Flux2KVCache)

    @torch.no_grad()
    def test_no_kv_cache_mode_returns_no_cache(self):
        model = self.model_class(**self.get_init_dict())
        model.to(torch_device)
        model.eval()

        base_config = Flux2TransformerTesterConfig()
        output = model(**base_config.get_dummy_inputs())

        assert output.kv_cache is None


class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
    model_class = Flux2Transformer2DModel
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def prepare_init_args_and_inputs_for_common(self):
        return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()

    def prepare_dummy_input(self, height, width):
        return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)


class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
    model_class = Flux2Transformer2DModel
    different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]

    def prepare_init_args_and_inputs_for_common(self):
        return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()

    def prepare_dummy_input(self, height, width):
        return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
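Read together, the tests above pin down a two-pass calling convention for the Flux2 KV cache: a forward pass with `kv_cache_mode="extract"` returns an output whose `kv_cache` is a populated `Flux2KVCache`, and later passes reuse it with `kv_cache_mode="cached"` (which returns no cache of its own). A minimal sketch of that flow, assuming the same keyword arguments the tests pass and that `ref_inputs`/`gen_inputs` are prepared like the dummy inputs here:

```python
import torch


@torch.no_grad()
def run_with_flux2_kv_cache(model, ref_inputs: dict, gen_inputs: dict, num_ref_tokens: int):
    # Pass 1: extract reference K/V once; the returned object carries a Flux2KVCache
    # with per-layer k_ref/v_ref for both double and single transformer blocks.
    extract_out = model(
        **ref_inputs,
        kv_cache_mode="extract",
        num_ref_tokens=num_ref_tokens,
        ref_fixed_timestep=0.0,
    )
    kv_cache = extract_out.kv_cache

    # Pass 2 (and any later step): reuse the cached reference tokens instead of recomputing them.
    cached_out = model(
        **gen_inputs,
        kv_cache=kv_cache,
        kv_cache_mode="cached",
    )
    return cached_out.sample
```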
@@ -1,4 +1,3 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,49 +12,84 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import warnings

import torch

from diffusers import QwenImageTransformer2DModel
from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
    AttentionTesterMixin,
    BaseModelTesterConfig,
    BitsAndBytesTesterMixin,
    ContextParallelTesterMixin,
    LoraHotSwappingForModelTesterMixin,
    LoraTesterMixin,
    MemoryTesterMixin,
    ModelTesterMixin,
    TorchAoTesterMixin,
    TorchCompileTesterMixin,
    TrainingTesterMixin,
)


enable_full_determinism()


class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
    @property
    def model_class(self):
        return QwenImageTransformer2DModel

    @property
    def output_shape(self) -> tuple[int, int]:
        return (16, 16)

    @property
    def input_shape(self) -> tuple[int, int]:
        return (16, 16)

    @property
    def model_split_percents(self) -> list:
        # Overridden because the transformer under test is small.
        return [0.7, 0.6, 0.6]

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def generator(self):
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int]]:
        return {
            "patch_size": 2,
            "in_channels": 16,
            "out_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 16,
            "num_attention_heads": 4,
            "joint_attention_dim": 16,
            "guidance_embeds": False,
            "axes_dims_rope": (8, 4, 4),
        }

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        batch_size = 1
        num_latent_channels = embedding_dim = 16
        height = width = 4
        sequence_length = 8
        vae_scale_factor = 4

        hidden_states = randn_tensor(
            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
        )
        encoder_hidden_states = randn_tensor(
            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
        )
        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
        orig_height = height * 2 * vae_scale_factor
@@ -70,89 +104,57 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
            "img_shapes": img_shapes,
        }
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
    def test_infers_text_seq_len_from_mask(self):
        """Test that compute_text_seq_len_from_mask correctly infers sequence lengths and returns tensors."""
        init_dict = self.get_init_dict()
        inputs = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        # Test 1: Contiguous mask with padding at the end (only the first 2 tokens are valid).
        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
        encoder_hidden_states_mask[:, 2:] = 0

        rope_text_seq_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask
        )

        # rope_text_seq_len must be a Python int so torch.compile does not specialize on a tensor.
        assert isinstance(rope_text_seq_len, int)
        assert isinstance(per_sample_len, torch.Tensor)
        assert int(per_sample_len.max().item()) == 2
        assert normalized_mask.dtype == torch.bool
        assert normalized_mask.sum().item() == 2
        assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]

        # Test 2: The model runs successfully with the inferred values.
        inputs["encoder_hidden_states_mask"] = normalized_mask
        with torch.no_grad():
            output = model(**inputs)
        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]

        # Test 3: Different mask pattern (padding at the beginning).
        encoder_hidden_states_mask2 = inputs["encoder_hidden_states_mask"].clone()
        encoder_hidden_states_mask2[:, :3] = 0
        encoder_hidden_states_mask2[:, 3:] = 1

        rope_text_seq_len2, per_sample_len2, normalized_mask2 = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask2
        )

        # The last valid position is index 7, so per_sample_len is 8 and 5 tokens are True.
        assert int(per_sample_len2.max().item()) == 8
        assert normalized_mask2.sum().item() == 5

        # Test 4: No mask provided (None case).
        rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], None
        )
        assert rope_text_seq_len_none == inputs["encoder_hidden_states"].shape[1]
        assert isinstance(rope_text_seq_len_none, int)
        assert per_sample_len_none is None
        assert normalized_mask_none is None

    def test_non_contiguous_attention_mask(self):
        """Test that non-contiguous masks work correctly (e.g., [1, 0, 1, 0, 1, 0, 0, 0])."""
        init_dict = self.get_init_dict()
        inputs = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        # Create a non-contiguous mask pattern: valid, padding, valid, padding, etc.
        # Pattern: [True, False, True, False, True, False, False, False]
        encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
        encoder_hidden_states_mask[:, 1] = 0
        encoder_hidden_states_mask[:, 3] = 0
        encoder_hidden_states_mask[:, 5:] = 0
@@ -160,95 +162,85 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
        inferred_rope_len, per_sample_len, normalized_mask = compute_text_seq_len_from_mask(
            inputs["encoder_hidden_states"], encoder_hidden_states_mask
        )
        assert int(per_sample_len.max().item()) == 5
        assert inferred_rope_len == inputs["encoder_hidden_states"].shape[1]
        assert isinstance(inferred_rope_len, int)
        assert normalized_mask.dtype == torch.bool

        inputs["encoder_hidden_states_mask"] = normalized_mask

        with torch.no_grad():
            output = model(**inputs)

        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]

    def test_txt_seq_lens_deprecation(self):
        """Test that passing txt_seq_lens raises a deprecation warning."""
        init_dict = self.get_init_dict()
        inputs = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)

        # Prepare inputs with txt_seq_lens (deprecated parameter) and drop the mask to hit the deprecated path.
        txt_seq_lens = [inputs["encoder_hidden_states"].shape[1]]
        inputs_with_deprecated = inputs.copy()
        inputs_with_deprecated.pop("encoder_hidden_states_mask")
        inputs_with_deprecated["txt_seq_lens"] = txt_seq_lens

        # A FutureWarning must be raised for the deprecated argument.
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            with torch.no_grad():
                output = model(**inputs_with_deprecated)

        future_warnings = [x for x in w if issubclass(x.category, FutureWarning)]
        assert len(future_warnings) > 0, "Expected FutureWarning to be raised"

        warning_message = str(future_warnings[0].message)
        assert "txt_seq_lens" in warning_message
        assert "deprecated" in warning_message

        # The model still works correctly despite the deprecation.
        assert output.sample.shape[1] == inputs["hidden_states"].shape[1]

    def test_layered_model_with_mask(self):
        """Test QwenImageTransformer2DModel with use_layer3d_rope=True (layered model)."""
        init_dict = {
            "patch_size": 2,
            "in_channels": 16,
            "out_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 16,
            "num_attention_heads": 4,
            "joint_attention_dim": 16,
            "axes_dims_rope": (8, 4, 4),  # Must match attention_head_dim (8 + 4 + 4 = 16)
            "use_layer3d_rope": True,
            "use_additional_t_cond": True,
        }

        model = self.model_class(**init_dict).to(torch_device)

        # Verify the model uses QwenEmbedLayer3DRope.
        from diffusers.models.transformers.transformer_qwenimage import QwenEmbedLayer3DRope

        assert isinstance(model.pos_embed, QwenEmbedLayer3DRope)

        # Test a single generation with the layered structure.
        batch_size = 1
        text_seq_len = 8
        img_h, img_w = 4, 4
        layers = 4

        # For the layered model: (layers + 1) because there are N layers + 1 combined image.
        hidden_states = torch.randn(batch_size, (layers + 1) * img_h * img_w, 16).to(torch_device)
        encoder_hidden_states = torch.randn(batch_size, text_seq_len, 16).to(torch_device)

        # Create a mask with some padding (only 5 valid tokens).
        encoder_hidden_states_mask = torch.ones(batch_size, text_seq_len).to(torch_device)
        encoder_hidden_states_mask[0, 5:] = 0

        timestep = torch.tensor([1.0]).to(torch_device)

        # additional_t_cond for use_additional_t_cond=True (0 or 1 index for the embedding).
        addition_t_cond = torch.tensor([0], dtype=torch.long).to(torch_device)

        # Layer structure: 4 layers + 1 condition image (the last entry gets special treatment).
        img_shapes = [
            [
                (1, img_h, img_w),
                (1, img_h, img_w),
                (1, img_h, img_w),
                (1, img_h, img_w),
                (1, img_h, img_w),
            ]
        ]
@@ -262,37 +254,113 @@ class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase):
            additional_t_cond=addition_t_cond,
        )

        assert output.sample.shape[1] == hidden_states.shape[1]
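The mask tests above exercise `compute_text_seq_len_from_mask`, which takes the text hidden states and an optional padding mask and returns a plain Python `int` RoPE sequence length (kept as an `int` for `torch.compile`), a per-sample length tensor, and a `bool`-normalized mask; with no mask it returns the sequence length and two `None`s. A small standalone sketch with illustrative shapes:

```python
import torch

from diffusers.models.transformers.transformer_qwenimage import compute_text_seq_len_from_mask

encoder_hidden_states = torch.randn(1, 8, 16)  # (batch, text_seq_len, dim); sizes are illustrative
mask = torch.ones(1, 8, dtype=torch.long)
mask[:, 6:] = 0  # last two tokens are padding

rope_len, per_sample_len, bool_mask = compute_text_seq_len_from_mask(encoder_hidden_states, mask)
assert isinstance(rope_len, int)              # int, not a 0-d tensor
assert int(per_sample_len.max().item()) == 6  # highest valid position + 1
assert bool_mask.dtype == torch.bool

# Without a mask, only the int length is meaningful; the other two values are None.
rope_len_none, per_sample_none, mask_none = compute_text_seq_len_from_mask(encoder_hidden_states, None)
assert rope_len_none == encoder_hidden_states.shape[1]
assert per_sample_none is None and mask_none is None
```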
class TestQwenImageTransformerMemory(QwenImageTransformerTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for QwenImage Transformer."""


class TestQwenImageTransformerTraining(QwenImageTransformerTesterConfig, TrainingTesterMixin):
    """Training tests for QwenImage Transformer."""

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"QwenImageTransformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestQwenImageTransformerAttention(QwenImageTransformerTesterConfig, AttentionTesterMixin):
    """Attention processor tests for QwenImage Transformer."""


class TestQwenImageTransformerContextParallel(QwenImageTransformerTesterConfig, ContextParallelTesterMixin):
    """Context Parallel inference tests for QwenImage Transformer."""


class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterMixin):
    """LoRA adapter tests for QwenImage Transformer."""


class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
    """LoRA hot-swapping tests for QwenImage Transformer."""

    @property
    def different_shapes_for_compilation(self):
        return [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        batch_size = 1
        num_latent_channels = embedding_dim = 16
        sequence_length = 8
        vae_scale_factor = 4

        hidden_states = randn_tensor(
            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
        )
        encoder_hidden_states = randn_tensor(
            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
        )
        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
        orig_height = height * 2 * vae_scale_factor
        orig_width = width * 2 * vae_scale_factor
        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
        }


class TestQwenImageTransformerCompile(QwenImageTransformerTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for QwenImage Transformer."""

    @property
    def different_shapes_for_compilation(self):
        return [(4, 4), (4, 8), (8, 8)]

    def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
        batch_size = 1
        num_latent_channels = embedding_dim = 16
        sequence_length = 8
        vae_scale_factor = 4

        hidden_states = randn_tensor(
            (batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
        )
        encoder_hidden_states = randn_tensor(
            (batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
        )
        encoder_hidden_states_mask = torch.ones((batch_size, sequence_length)).to(torch_device, torch.long)
        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
        orig_height = height * 2 * vae_scale_factor
        orig_width = width * 2 * vae_scale_factor
        img_shapes = [(1, orig_height // vae_scale_factor // 2, orig_width // vae_scale_factor // 2)] * batch_size

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_hidden_states_mask": encoder_hidden_states_mask,
            "timestep": timestep,
            "img_shapes": img_shapes,
        }

    def test_torch_compile_with_and_without_mask(self):
        """Test that torch.compile works with both a None mask and a padding mask."""
        init_dict = self.get_init_dict()
        inputs = self.get_dummy_inputs()
        model = self.model_class(**init_dict).to(torch_device)
        model.eval()
        model.compile(mode="default", fullgraph=True)

        # Test 1: Run with a None mask (no padding, all tokens are valid).
        inputs_no_mask = inputs.copy()
        inputs_no_mask["encoder_hidden_states_mask"] = None

        # First run to allow compilation.
        with torch.no_grad():
            output_no_mask = model(**inputs_no_mask)

        # Second run to verify no recompilation.
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
@@ -300,19 +368,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
        ):
            output_no_mask_2 = model(**inputs_no_mask)

        assert output_no_mask.sample.shape[1] == inputs["hidden_states"].shape[1]
        assert output_no_mask_2.sample.shape[1] == inputs["hidden_states"].shape[1]

        # Test 2: Run with an all-ones mask (should behave like None).
        inputs_all_ones = inputs.copy()
        assert inputs_all_ones["encoder_hidden_states_mask"].all().item()

        # First run to allow compilation.
        with torch.no_grad():
            output_all_ones = model(**inputs_all_ones)

        # Second run to verify no recompilation.
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
@@ -320,21 +384,18 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
        ):
            output_all_ones_2 = model(**inputs_all_ones)

        assert output_all_ones.sample.shape[1] == inputs["hidden_states"].shape[1]
        assert output_all_ones_2.sample.shape[1] == inputs["hidden_states"].shape[1]

        # Test 3: Run with an actual padding mask (has zeros).
        inputs_with_padding = inputs.copy()
        mask_with_padding = inputs["encoder_hidden_states_mask"].clone()
        mask_with_padding[:, 4:] = 0

        inputs_with_padding["encoder_hidden_states_mask"] = mask_with_padding

        # First run to allow compilation.
        with torch.no_grad():
            output_with_padding = model(**inputs_with_padding)

        # Second run to verify no recompilation.
        with (
            torch._inductor.utils.fresh_inductor_cache(),
            torch._dynamo.config.patch(error_on_recompile=True),
@@ -342,8 +403,15 @@ class QwenImageTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCas
        ):
            output_with_padding_2 = model(**inputs_with_padding)

        assert output_with_padding.sample.shape[1] == inputs["hidden_states"].shape[1]
        assert output_with_padding_2.sample.shape[1] == inputs["hidden_states"].shape[1]

        # Verify that the outputs differ (the mask should affect results).
        assert not torch.allclose(output_no_mask.sample, output_with_padding.sample, atol=1e-3)


class TestQwenImageTransformerBitsAndBytes(QwenImageTransformerTesterConfig, BitsAndBytesTesterMixin):
    """BitsAndBytes quantization tests for QwenImage Transformer."""


class TestQwenImageTransformerTorchAo(QwenImageTransformerTesterConfig, TorchAoTesterMixin):
    """TorchAO quantization tests for QwenImage Transformer."""
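The compile test above relies on a generic dynamo/inductor idiom: run once to compile, then run again under `error_on_recompile=True` so any unexpected recompilation fails loudly. A minimal sketch of the same idiom on a toy module rather than a diffusers model (note that `fresh_inductor_cache` is a private PyTorch helper; it is used here only because the test above uses it):

```python
import torch

model = torch.nn.Linear(16, 16).eval()
compiled = torch.compile(model, fullgraph=True)
x = torch.randn(2, 16)

# First call triggers compilation.
with torch.no_grad():
    _ = compiled(x)

# Second call must hit the existing graph; a dynamo recompile now raises instead of happening silently.
with (
    torch._inductor.utils.fresh_inductor_cache(),
    torch._dynamo.config.patch(error_on_recompile=True),
    torch.no_grad(),
):
    _ = compiled(x)
```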
@@ -5,6 +5,7 @@ from typing import Callable

import pytest
import torch
from huggingface_hub import hf_hub_download

import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
@@ -32,6 +33,33 @@ from ..testing_utils import (
)


def _get_specified_components(path_or_repo_id, cache_dir=None):
    # Read modular_model_index.json either from a local directory or from the Hub.
    if os.path.isdir(path_or_repo_id):
        config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
    else:
        try:
            config_path = hf_hub_download(
                repo_id=path_or_repo_id,
                filename="modular_model_index.json",
                local_dir=cache_dir,
            )
        except Exception:
            return None

    with open(config_path) as f:
        config = json.load(f)

    # A component counts as "specified" when one of its entries carries a loading spec.
    components = set()
    for k, v in config.items():
        if isinstance(v, (str, int, float, bool)):
            continue
        for entry in v:
            if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
                components.add(k)
                break
    return components


class ModularPipelineTesterMixin:
    """
    It provides a set of common tests for each modular pipeline,
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:

        assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3

    def test_load_expected_components_from_pretrained(self, tmp_path):
        pipe = self.get_pipeline()
        expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
        if not expected:
            pytest.skip("Skipping test as we couldn't fetch the expected components.")

        actual = {
            name
            for name in pipe.components
            if getattr(pipe, name, None) is not None
            and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
        }
        assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"

    def test_load_expected_components_from_save_pretrained(self, tmp_path):
        pipe = self.get_pipeline()
        save_dir = str(tmp_path / "saved-pipeline")
        pipe.save_pretrained(save_dir)

        expected = _get_specified_components(save_dir)
        loaded_pipe = ModularPipeline.from_pretrained(save_dir)
        loaded_pipe.load_components(torch_dtype=torch.float32)

        actual = {
            name
            for name in loaded_pipe.components
            if getattr(loaded_pipe, name, None) is not None
            and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
        }
        assert expected == actual, (
            f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
        )

    def test_modular_index_consistency(self, tmp_path):
        pipe = self.get_pipeline()
        components_spec = pipe._component_specs
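For context on `_get_specified_components`: a component is only counted as "specified" when its index entry carries a loading spec (a dict with `repo` or `pretrained_model_name_or_path`). The exact on-disk schema is not shown here, so the fragment below is a hypothetical illustration of how the helper would classify entries:

```python
import json
import os
import tempfile

# Hypothetical index layout: "unet" carries a repo spec, "scheduler" does not.
index = {
    "_class_name": "ModularPipeline",
    "unet": ["diffusers", "UNet2DConditionModel", {"repo": "some-org/some-repo", "subfolder": "unet"}],
    "scheduler": ["diffusers", "EulerDiscreteScheduler", {}],
}

tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "modular_model_index.json"), "w") as f:
    json.dump(index, f)

# Under this layout, _get_specified_components(tmp_dir) would return {"unet"}:
# plain string/number values are skipped, and only dict entries with "repo" or
# "pretrained_model_name_or_path" mark a component as specified.
```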
@@ -31,7 +31,41 @@ from diffusers.modular_pipelines import (
    WanModularPipeline,
)

from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device


def _create_tiny_model_dir(model_dir):
    # Write a minimal custom model (resolved through auto_map/trust_remote_code), its config, and weights.
    TINY_MODEL_CODE = (
        "import torch\n"
        "from diffusers import ModelMixin, ConfigMixin\n"
        "from diffusers.configuration_utils import register_to_config\n"
        "\n"
        "class TinyModel(ModelMixin, ConfigMixin):\n"
        "    @register_to_config\n"
        "    def __init__(self, hidden_size=4):\n"
        "        super().__init__()\n"
        "        self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
        "\n"
        "    def forward(self, x):\n"
        "        return self.linear(x)\n"
    )

    with open(os.path.join(model_dir, "modeling.py"), "w") as f:
        f.write(TINY_MODEL_CODE)

    config = {
        "_class_name": "TinyModel",
        "_diffusers_version": "0.0.0",
        "auto_map": {"AutoModel": "modeling.TinyModel"},
        "hidden_size": 4,
    }
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump(config, f)

    torch.save(
        {"linear.weight": torch.randn(4, 4), "linear.bias": torch.randn(4)},
        os.path.join(model_dir, "diffusion_pytorch_model.bin"),
    )


class DummyCustomBlockSimple(ModularPipelineBlocks):
@@ -341,6 +375,81 @@ class TestModularCustomBlocks:
        loaded_pipe.update_components(custom_model=custom_model)
        assert getattr(loaded_pipe, "custom_model", None) is not None

    def test_automodel_type_hint_preserves_torch_dtype(self, tmp_path):
        """Regression test for #13271: torch_dtype was incorrectly removed when type_hint is AutoModel."""
        from diffusers import AutoModel

        model_dir = str(tmp_path / "model")
        os.makedirs(model_dir)
        _create_tiny_model_dir(model_dir)

        class DtypeTestBlock(ModularPipelineBlocks):
            @property
            def expected_components(self):
                return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]

            @property
            def inputs(self) -> List[InputParam]:
                return [InputParam("prompt", type_hint=str, required=True)]

            @property
            def intermediate_inputs(self) -> List[InputParam]:
                return []

            @property
            def intermediate_outputs(self) -> List[OutputParam]:
                return [OutputParam("output", type_hint=str)]

            def __call__(self, components, state: PipelineState) -> PipelineState:
                block_state = self.get_block_state(state)
                block_state.output = "test"
                self.set_block_state(state, block_state)
                return components, state

        block = DtypeTestBlock()
        pipe = block.init_pipeline()
        pipe.load_components(torch_dtype=torch.float16, trust_remote_code=True)

        assert pipe.model.dtype == torch.float16

    @require_torch_accelerator
    def test_automodel_type_hint_preserves_device(self, tmp_path):
        """Test that ComponentSpec with AutoModel type_hint correctly passes device_map."""
        from diffusers import AutoModel

        model_dir = str(tmp_path / "model")
        os.makedirs(model_dir)
        _create_tiny_model_dir(model_dir)

        class DeviceTestBlock(ModularPipelineBlocks):
            @property
            def expected_components(self):
                return [ComponentSpec("model", AutoModel, pretrained_model_name_or_path=model_dir)]

            @property
            def inputs(self) -> List[InputParam]:
                return [InputParam("prompt", type_hint=str, required=True)]

            @property
            def intermediate_inputs(self) -> List[InputParam]:
                return []

            @property
            def intermediate_outputs(self) -> List[OutputParam]:
                return [OutputParam("output", type_hint=str)]

            def __call__(self, components, state: PipelineState) -> PipelineState:
                block_state = self.get_block_state(state)
                block_state.output = "test"
                self.set_block_state(state, block_state)
                return components, state

        block = DeviceTestBlock()
        pipe = block.init_pipeline()
        pipe.load_components(device_map=torch_device, trust_remote_code=True)

        assert pipe.model.device.type == torch_device

    def test_custom_block_loads_from_hub(self):
        repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
        block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)
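The two regression tests above only reach the tiny custom model through `ComponentSpec` and `load_components(...)`. As a usage note, the same directory should in principle be loadable directly, since its `config.json` maps `AutoModel` to `modeling.TinyModel` via `auto_map`; this sketch assumes `AutoModel.from_pretrained` honors that `auto_map`/`trust_remote_code` path, which the tests exercise only indirectly:

```python
import tempfile

import torch
from diffusers import AutoModel

model_dir = tempfile.mkdtemp()
_create_tiny_model_dir(model_dir)  # helper defined above: writes modeling.py, config.json, and weights

# trust_remote_code is required because TinyModel lives in modeling.py, not in diffusers itself.
model = AutoModel.from_pretrained(model_dir, trust_remote_code=True, torch_dtype=torch.float16)
print(model.dtype)  # expected: torch.float16
```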
@@ -139,9 +139,9 @@ class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    @unittest.skip("Helios uses a lot of mixed precision internally, which is not suitable for this test case")
    def test_save_load_float16(self):
        pass

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):

@@ -171,6 +171,7 @@ class LTX2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
            "processor": None,
        }

        return components

@@ -171,6 +171,7 @@ class LTX2ImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
            "tokenizer": tokenizer,
            "connectors": connectors,
            "vocoder": vocoder,
            "processor": None,
        }

        return components