Mirror of https://github.com/huggingface/diffusers.git (synced 2026-03-29 20:07:48 +08:00)

Compare commits: main...ltx23-pari (3 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 7a215eca60 | |
| | c3c9555db8 | |
| | 5dde9fc179 | |
@@ -13,12 +13,15 @@ Before writing any test code, gather:

1. **Which two implementations** are being compared (e.g. research repo → diffusers, standard → modular, or research → modular). Use `AskUserQuestion` with structured choices if not already clear.
2. **Two equivalent runnable scripts** -- one for each implementation, both expected to produce identical output given the same inputs. These scripts define what "parity" means concretely.
3. **Test directory**: Ask the user if they have a preferred directory for parity test scripts and artifacts. If not, create `parity-tests/` at the repo root.
4. **Lab book**: Ask the user if they want to maintain a `lab_book.md` in the test directory to track findings, fixes, and experiment results across sessions. This is especially useful for multi-session debugging where context gets lost.

When invoked from the `model-integration` skill, you already have context: the reference script comes from step 2 of setup, and the diffusers script is the one you just wrote. You just need to make sure both scripts are runnable and use the same inputs/seed/params.

## Test strategy

## Phase 1: CPU/float32 parity (always run)

### Component parity — test as you build

**Component parity (CPU/float32) -- always run, as you build.**

Test each component before assembling the pipeline. This is the foundation -- if individual pieces are wrong, the pipeline can't be right. Test each component in isolation against a strict tolerance (max_diff < 1e-3).

Test both freshly converted checkpoints and checkpoints saved to disk.
@@ -27,6 +30,22 @@ Test freshly converted checkpoints and saved checkpoints.

Keep component test scripts around -- you will need to re-run them during pipeline debugging with different inputs or config values.

**Write a model interface mapping** as you test each component. This documents every input difference between reference and diffusers models — format, dtype, shape, who computes what. Save it in the test directory (e.g., `parity-tests/model_interface_mapping.md`). This is critical: during pipeline testing, you MUST reference this mapping to verify the pipeline passes inputs in the correct format. Without it, you'll waste time rediscovering differences you already found.

Example mapping (from LTX-2.3):

```markdown
| Input | Reference | Diffusers | Notes |
|---|---|---|---|
| timestep | per-token bf16 sigma, scaled by 1000 internally | passed as sigma*1000 | shape (B,S) not (B,) |
| sigma (prompt_adaln) | raw f32 sigma, scaled internally | passed as sigma*1000 in f32 | NOT bf16 |
| positions/coords | computed inside model preprocessor | passed as kwarg video_coords | cast to model dtype |
| cross-attn timestep | always cross_modality.sigma | always audio_sigma | not conditional |
| encoder_attention_mask | None (no mask) | None or all-ones | all-ones triggers different SDPA kernel |
| RoPE | computed in model dtype (no upcast) | must match — no float32 upcast | cos/sin cast to input dtype |
| output format | X0Model returns x0 | transformer returns velocity | v→x0: (sample - vel * sigma) |
| audio output | .squeeze(0).float() | must match | (2,N) float32 not (1,2,N) bf16 |
```

Template -- one self-contained script per component, reference and diffusers side-by-side:

```python
@torch.inference_mode()
@@ -57,25 +76,25 @@ def test_my_component(mode="fresh", model_path=None):
```

Key points: (a) both reference and diffusers component in one script -- never split into separate scripts that save/load intermediates, (b) deterministic input via seeded generator, (c) load one model at a time to fit in CPU RAM, (d) `.clone()` the reference output before deleting the model.
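A minimal, self-contained sketch of this pattern (the real template is elided in the diff above; the component here is a hypothetical stand-in `nn.Linear`, seeded identically on both sides so the test passes):

```python
import torch
import torch.nn as nn


@torch.inference_mode()
def test_my_component(tol: float = 1e-3) -> float:
    # (b) deterministic input via a seeded generator
    x = torch.randn(1, 8, generator=torch.Generator().manual_seed(0))

    # (c) load one model at a time to fit in CPU RAM; both "implementations"
    # are stand-in Linears initialized from the same seed for illustration.
    torch.manual_seed(0)
    ref_model = nn.Linear(8, 8)
    ref_out = ref_model(x).clone()  # (d) .clone() before deleting the model
    del ref_model

    torch.manual_seed(0)
    diff_model = nn.Linear(8, 8)
    diff_out = diff_model(x)

    max_diff = (ref_out - diff_out).abs().max().item()
    assert max_diff < tol, f"component parity failed: max_diff={max_diff}"
    return max_diff


max_diff = test_my_component()
```

In the real script, the two stand-in models are replaced by the reference checkpoint and the converted diffusers component.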
**E2E visual (GPU/bfloat16) -- once the pipeline is assembled.**

Both pipelines generate independently with identical seeds/params. Save outputs and compare visually. If outputs look identical, you're done -- no need for deeper testing.

### Pipeline stage tests — encode, decode, then denoise

**Pipeline stage tests -- only if E2E fails and you need to isolate the bug.**

If the user already suspects where the divergence is, start there. Otherwise, work through stages in order.

Use the capture-inject checkpoint method (see [checkpoint-mechanism.md](checkpoint-mechanism.md)) to test each pipeline stage independently. This methodology is the same for both CPU/float32 and GPU/bf16.

First, **match noise generation**: the way initial noise/latents are constructed (seed handling, generator, randn call order) often differs between the two scripts. If the noise doesn't match, nothing downstream will match. Check how noise is initialized in the diffusers script — if it doesn't match the reference, temporarily change it to match. Note what you changed so it can be reverted after parity is confirmed.
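Matching noise concretely usually comes down to seeding a `torch.Generator` and preserving the randn call order. A minimal sketch (shape and dtype are placeholders):

```python
import torch


def make_initial_noise(shape, seed: int, device="cpu", dtype=torch.float32):
    # Same seed + same generator device + same randn call order
    # => bit-identical noise on both sides.
    generator = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=generator, device=device, dtype=dtype)


noise_a = make_initial_noise((1, 4, 8, 8), seed=42)
noise_b = make_initial_noise((1, 4, 8, 8), seed=42)
assert torch.equal(noise_a, noise_b)
```

Note that the generator's device matters: CPU and CUDA generators produce different streams even for the same seed, which is one common source of mismatch between scripts.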
Before writing pipeline tests, **review the model interface mapping** from the component test phase and verify each entry. The mapping tells you which differences between the two models are expected (e.g., the reference expects raw sigma but diffusers expects sigma*1000). Without it, you'll waste time investigating differences that are by design, not bugs.

For small models, run on CPU/float32 for strict comparison. For large models (e.g. 22B params), CPU/float32 is impractical -- use GPU/bfloat16 with `enable_model_cpu_offload()` and relax tolerances (max_diff < 1e-1 is typical for passing bfloat16 tests; cosine similarity > 0.9999 is a good secondary check).

Test the encode and decode stages first -- they're simpler, and bugs there are easier to fix. Only debug the denoising loop if encode and decode both pass.

The challenge: pipelines are monolithic `__call__` methods -- you can't just call "the encode part". See [checkpoint-mechanism.md](checkpoint-mechanism.md) for the checkpoint class that lets you stop, save, or inject tensors at named locations inside the pipeline.

**Stage test order — encode, decode, then denoise:**

- **`encode`** (test first): Stop both pipelines at `"preloop"`. Compare **every single variable** that will be consumed by the denoising loop -- not just latents and sigmas, but also prompt embeddings, attention masks, positional coordinates, connector outputs, and any conditioning inputs.
- **`decode`** (test second): Run the reference pipeline fully -- checkpoint the post-loop latents AND let it finish to get the **final output**. Feed those same post-loop latents through the diffusers decode path. Compare the **final output format** -- not raw tensors, but what the user actually gets:
  - **Image**: compare PIL.Image pixels
  - **Video**: compare through the pipeline's export function (e.g. `encode_video`)
  - **Video+Audio**: compare video frames AND audio waveform through `encode_video`
  - This catches postprocessing bugs like float→uint8 rounding, audio format, and codec settings.
- **`denoise`** (test last): Run both pipelines with realistic `num_steps` (e.g. 30) so the scheduler computes correct sigmas/timesteps. For float32, stop after 2 loop iterations using `after_step_1` (don't set `num_steps=2` -- that produces unrealistic sigma schedules). For bf16, run ALL steps (see Phase 2).

Start with coarse checkpoints (`after_step_{i}` — just the denoised latents at each step). If a step diverges, place finer checkpoints within that step (e.g. before/after the model call, after CFG, after the scheduler step). If the divergence is inside the model forward call, use PyTorch forward hooks (`register_forward_hook`) to capture intermediate outputs from sub-modules (e.g., attention output, feed-forward output) and compare them between the two models to find the first diverging operation.
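The forward-hook capture step can be sketched like this (`register_forward_hook` is standard PyTorch; the model here is a tiny stand-in, and in practice the `names` argument lists the real sub-modules from `model.named_modules()`):

```python
import torch
import torch.nn as nn


def capture_submodule_outputs(model: nn.Module, names: list[str]):
    """Attach hooks that record each named sub-module's output on the next forward."""
    captured, handles = {}, []
    modules = dict(model.named_modules())
    for name in names:
        # Bind `name` per-hook via a default argument.
        def hook(mod, args, output, name=name):
            captured[name] = output.detach().cpu().clone()

        handles.append(modules[name].register_forward_hook(hook))
    return captured, handles


# Tiny stand-in model to show the mechanics:
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
captured, handles = capture_submodule_outputs(model, ["0", "1"])
model(torch.randn(3, 4))
for h in handles:
    h.remove()  # always detach hooks so later runs are clean
# captured["0"] holds the first Linear's output, shape (3, 8)
```

Run the same capture on both models with identical inputs, then `compare_tensors` each captured name in forward order to find the first diverging operation.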
```python
# Encode stage -- stop before the loop, compare ALL inputs:
@@ -94,7 +113,27 @@ compare_tensors("prompt_embeds", ref_data["prompt_embeds"], diff_data["prompt_em
# ... every single tensor the transformer forward() will receive
```
**E2E-injected visual test**: Once you've identified a suspected root cause using stage tests, confirm it with an e2e-injected run -- inject the known-good tensor from the reference and generate a full video. If the output looks identical to the reference, you've confirmed the root cause.

### E2E visual — once stages pass

Both pipelines generate independently with identical seeds/params. Save outputs and compare visually. If outputs look identical, Phase 1 is done.

If CPU/float32 stage tests all pass and E2E outputs are identical → Phase 1 is done, move on.

If E2E outputs are NOT identical despite stage tests passing, **ask the user**: "CPU/float32 parity passes at the stage level but E2E output differs. The output in bf16/GPU may look slightly different from the reference due to precision casting, but the quality should be the same. Do you want to just vibe-check the output quality, or do you need 1:1 identical output with the reference in bf16?"

- If the user says quality looks fine → **done**.
- If the user needs 1:1 identical output in bf16 → Phase 2.

## Phase 2: GPU/bf16 parity (optional — only if user needs 1:1 output)

If CPU/float32 passes, the algorithm is correct. bf16 differences come from precision casting (e.g., float32 vs bf16 in RoPE, CFG arithmetic order, scheduler intermediates), not logic bugs. These can make the output look slightly different from the reference even though the quality is identical. Phase 2 eliminates these casting differences so the diffusers output is **bit-identical** to the reference in bf16.

Phase 2 uses the **exact same stage test methodology** as Phase 1 (encode → decode → denoise with progressive checkpoint refinement), with two differences:

1. **dtype=bf16, device=GPU** instead of float32/CPU
2. **Run the FULL denoising loop** (all steps, not just 2) — bf16 casting differences accumulate over steps and may only manifest after many iterations

See [pitfalls.md](pitfalls.md) #19-#27 for the catalog of bf16-specific gotchas.
## Debugging technique: Injection for root-cause isolation

@@ -145,6 +184,8 @@ extract_frames(diff_video, [0, 60, 120])

6. **Diff configs before debugging.** Before investigating any divergence, dump and compare all config values. A 30-second config diff prevents hours of debugging based on wrong assumptions.
7. **Never modify cached/downloaded model configs directly.** Don't edit files in `~/.cache/huggingface/`. Instead, save to a local directory or open a PR on the upstream repo.
8. **Compare ALL loop inputs in the encode test.** The preloop checkpoint must capture every single tensor the transformer forward() will receive.
9. **Don't contaminate test paths.** Each side (reference, diffusers) must use only its own code to generate outputs. For COMPARISON, save both outputs through the SAME function (so codec/format differences don't create false diffs). Example: don't use the reference's `encode_video` for one side and diffusers' for the other.
10. **Re-test the standalone model through the actual pipeline if divergence points to the model.** If pipeline stage tests show the divergence is at the model output (e.g., `cond_x0` differs despite identical inputs), re-run the model comparison using capture-inject with real pipeline-generated inputs. Standalone model tests use manually constructed kwargs which may have wrong config values, dtypes, or shapes — the pipeline generates the real ones.
## Comparison utilities

```python
@@ -165,6 +206,11 @@ def compare_tensors(name: str, a: torch.Tensor, b: torch.Tensor, tol: float = 1e
```

Cosine similarity is especially useful for GPU/bfloat16 tests where max_diff can be noisy -- `cos > 0.9999` is a strong signal even when max_diff exceeds tolerance.
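The helper's body is elided in the diff above; a minimal sketch consistent with the signature shown (the real implementation may print more detail):

```python
import torch
import torch.nn.functional as F


def compare_tensors(name: str, a: torch.Tensor, b: torch.Tensor, tol: float = 1e-3) -> bool:
    """Print max abs diff and cosine similarity; return pass/fail against tol."""
    a32, b32 = a.float(), b.float()
    max_diff = (a32 - b32).abs().max().item()
    cos = F.cosine_similarity(a32.flatten(), b32.flatten(), dim=0).item()
    status = "PASS" if max_diff < tol else "FAIL"
    print(f"[{status}] {name}: max_diff={max_diff:.3e} cos={cos:.6f} "
          f"shapes={tuple(a.shape)}/{tuple(b.shape)}")
    return max_diff < tol


x = torch.randn(4, 8)
same = compare_tensors("identity", x, x.clone())
diff = compare_tensors("shifted", x, x + 1.0)
```

For bf16 runs, log both numbers and treat cosine similarity as the tie-breaker when max_diff sits near the relaxed tolerance.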
## Example scripts

- [examples/test_component_parity_cpu.py](examples/test_component_parity_cpu.py) — Template for CPU/float32 component parity test
- [examples/test_e2e_bf16_parity.py](examples/test_e2e_bf16_parity.py) — Template for GPU/bf16 E2E parity test with capture-inject

## Gotchas

See [pitfalls.md](pitfalls.md) for the full list of gotchas to watch for during parity testing.
@@ -114,3 +114,41 @@ When running on GPU/bfloat16, multi-layer encoders (e.g. 8-layer connector trans

## 18. Stale test fixtures

When using saved tensors for cross-pipeline comparison, always ensure both sets of tensors were captured from the same run configuration (same seed, same config, same code version). Mixing fixtures from different runs (e.g. reference tensors from yesterday, diffusers tensors from today after a code change) creates phantom divergence that wastes debugging time. Regenerate both sides in a single test script execution.

## 19. RoPE float32 upcast changes bf16 output

If `apply_rotary_emb` upcasts input to float32 for the rotation computation (`x.float() * cos + x_rotated.float() * sin`), but the reference stays in bf16, the results differ after casting back. The float32 intermediate produces different rounding than native bf16 computation.

**Fix**: Remove the float32 upcast. Cast cos/sin to the input dtype instead: `cos, sin = cos.to(x.dtype), sin.to(x.dtype)`, then compute `x * cos + x_rotated * sin` in the model's native dtype.

## 20. CFG formula arithmetic order

`cond + (scale-1) * (cond - uncond)` and `uncond + scale * (cond - uncond)` are mathematically identical but produce different bf16 results because the multiplication factor (3 vs 4 for scale=4) and the base (cond vs uncond) differ. Match the reference's exact formula.
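A quick sanity check of the algebraic equivalence (exact to rounding error in float64; in bf16 the two orderings round their intermediates differently, which is the point of this pitfall):

```python
import torch

cond = torch.randn(4, 8, dtype=torch.float64)
uncond = torch.randn(4, 8, dtype=torch.float64)
scale = 4.0

# Two algebraically identical CFG formulas:
a = cond + (scale - 1.0) * (cond - uncond)   # base cond, factor 3
b = uncond + scale * (cond - uncond)         # base uncond, factor 4

# In float64 they agree to rounding error; in bf16 the larger
# intermediate products round differently, so match the reference's
# exact formula rather than an algebraic rearrangement.
assert torch.allclose(a, b)
```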
## 21. Scheduler float64 intermediates from numpy

`math.exp(mu) / (math.exp(mu) + (1/t - 1))` where `t` is a numpy float32 array promotes to float64 (because `math.exp` returns a Python float, i.e. float64, and numpy promotes). The reference uses torch float32. Fix: compute in `torch.float32` using `torch.as_tensor(t, dtype=torch.float32)`. Same for `np.linspace` vs `torch.linspace` — use `torch.linspace` for float32-native computation.
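The linspace half of this pitfall is easy to demonstrate: NumPy defaults to float64 while torch defaults to float32, so the same schedule computed via `np.linspace` carries extra precision that a float32 reference never sees:

```python
import numpy as np
import torch

np_sigmas = np.linspace(1.0, 0.0, 5)        # float64 by default
torch_sigmas = torch.linspace(1.0, 0.0, 5)  # float32 by default

assert np_sigmas.dtype == np.float64
assert torch_sigmas.dtype == torch.float32

# Force the whole schedule into torch.float32 to match a float32 reference:
t = torch.as_tensor(np_sigmas, dtype=torch.float32)
assert t.dtype == torch.float32
```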
## 22. Zero-dim tensor type promotion in Euler step

`dt * model_output` where `dt` is a 0-dim float32 tensor and `model_output` is bf16: PyTorch treats the 0-dim tensor as a "scalar" that adapts to the tensor's dtype. The result is **bf16**, not float32. The reference does `velocity.to(float32) * dt`, which is float32. Fix: explicitly upcast, e.g. `model_output.to(sample.dtype) * dt`.
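The promotion rule is easy to verify: against a dimensioned tensor of the same dtype category, a 0-dim tensor behaves like a scalar and the dimensioned tensor's dtype wins:

```python
import torch

dt = torch.tensor(0.1, dtype=torch.float32)          # 0-dim "scalar" tensor
model_output = torch.randn(2, 3, dtype=torch.bfloat16)

# 0-dim float32 * bf16 tensor -> bf16 (the dimensioned tensor's dtype wins):
assert (dt * model_output).dtype == torch.bfloat16

# Explicit upcast keeps the Euler step in float32, matching the reference:
assert (model_output.to(torch.float32) * dt).dtype == torch.float32
```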
## 23. Per-token vs per-batch timestep shape

Passing timestep as `(B,)` produces temb shape `(B, 1, D)` via the adaln. Passing `(B, S)` produces `(B, S, D)`. For T2V where all tokens share the same sigma, these are mathematically equivalent but use different CUDA kernels with different bf16 rounding. Match the reference's shape — typically per-token `(B, S)`.
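The shape consequence can be sketched with a toy sinusoidal embedding (the `embed` helper here is hypothetical, not the LTX-2 adaln; it just shows how the trailing frequency axis rides along with the timestep shape):

```python
import torch


def embed(timestep: torch.Tensor, dim: int = 8) -> torch.Tensor:
    # Toy sinusoidal embedding: appends a frequency axis of size `dim`.
    freqs = torch.arange(dim, dtype=torch.float32)
    return torch.sin(timestep.unsqueeze(-1).float() * freqs)


B, S = 2, 5
per_batch = embed(torch.ones(B))      # (B, dim) -> broadcasts as (B, 1, dim)
per_token = embed(torch.ones(B, S))   # (B, S, dim)

assert per_batch.shape == (B, 8)
assert per_token.shape == (B, S, 8)
# Same values when all tokens share one sigma, but the (B, S, D) path runs
# through different kernels with different bf16 rounding.
```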
## 24. Model config missing fields

The diffusers checkpoint config may be missing fields that the reference model has (e.g. `use_cross_timestep`, `prompt_modulation`). The code falls back to a default that may be wrong. Always check the ACTUAL runtime value, not the code default. Run `getattr(model.config, "field_name", "MISSING")` and compare against the reference model's config.

## 25. Cross-attention timestep conditional

The reference may always use `cross_modality.sigma` for the cross-attention timestep (e.g., video cross-attn uses audio sigma), but the diffusers model may conditionally use the main timestep based on `use_cross_timestep`. If the conditional is wrong or the config field is missing, the cross-attention receives a completely different timestep — different shape `(S,)` vs `(1,)`, different value, and different sinusoidal embedding. This is a model-level bug that standalone tests miss because they pass `use_cross_timestep` manually.

## 26. Audio/video output format mismatch

The reference may return audio as `(2, N)` float32 (after `.squeeze(0).float()`), while the diffusers pipeline returns `(1, 2, N)` bf16 from the vocoder. The `_write_audio` function in `encode_video` doesn't handle 3D tensors correctly. Fix: add `.squeeze(0).float()` after the vocoder call in the audio decoder step.
## 27. encode_video float-to-uint8 rounding

The reference converts float video to uint8 via `.to(torch.uint8)` (truncation), but diffusers' `encode_video` may use `(video * 255).round().astype("uint8")` (rounding). This causes a 1-count diff per channel at roughly 50% of pixels. Fix: use truncation (`.astype("uint8")`) to match the reference.
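The difference in a minimal NumPy check (a pixel value exactly at the rounding boundary shows the 1-count gap):

```python
import numpy as np

video = np.float32(0.5)  # pixel value exactly at the rounding boundary

truncated = (video * 255).astype("uint8")        # truncation: 127.5 -> 127
rounded = np.round(video * 255).astype("uint8")  # rounding:   127.5 -> 128

assert truncated == 127
assert rounded == 128
```

Any pixel whose scaled value has a fractional part of 0.5 or more lands on different uint8 values under the two conversions, which is why the diff shows up at about half of all pixels.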
@@ -434,6 +434,9 @@ else:
        "FluxKontextAutoBlocks",
        "FluxKontextModularPipeline",
        "FluxModularPipeline",
        "LTX2AutoBlocks",
        "LTX2Blocks",
        "LTX2ModularPipeline",
        "HeliosAutoBlocks",
        "HeliosModularPipeline",
        "HeliosPyramidAutoBlocks",
@@ -1195,6 +1198,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        FluxKontextAutoBlocks,
        FluxKontextModularPipeline,
        FluxModularPipeline,
        LTX2AutoBlocks,
        LTX2Blocks,
        LTX2ModularPipeline,
        HeliosAutoBlocks,
        HeliosModularPipeline,
        HeliosPyramidAutoBlocks,
@@ -550,19 +550,9 @@ class RMSNorm(nn.Module):
            if self.bias is not None:
                hidden_states = hidden_states + self.bias
        else:
            input_dtype = hidden_states.dtype
            variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
            hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

            if self.weight is not None:
                # convert into half-precision if necessary
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    hidden_states = hidden_states.to(self.weight.dtype)
                hidden_states = hidden_states * self.weight
                if self.bias is not None:
                    hidden_states = hidden_states + self.bias
            else:
                hidden_states = hidden_states.to(input_dtype)
            hidden_states = torch.nn.functional.rms_norm(hidden_states, self.dim, self.weight, self.eps)
            if self.bias is not None:
                hidden_states = hidden_states + self.bias

        return hidden_states
@@ -37,16 +37,16 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def apply_interleaved_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    cos, sin = freqs
    cos, sin = cos.to(x.dtype), sin.to(x.dtype)
    x_real, x_imag = x.unflatten(2, (-1, 2)).unbind(-1)  # [B, S, C // 2]
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(2)
    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
    out = x * cos + x_rotated * sin
    return out


def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    cos, sin = freqs

    x_dtype = x.dtype
    needs_reshape = False
    if x.ndim != 4 and cos.ndim == 4:
        # cos is (b, h, t, r) -> reshape x to (b, h, t, dim_per_head)
@@ -61,12 +61,12 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Ten
    r = last // 2

    # (..., 2, r)
    split_x = x.reshape(*x.shape[:-1], 2, r).float()  # Explicitly upcast to float
    split_x = x.reshape(*x.shape[:-1], 2, r)
    first_x = split_x[..., :1, :]  # (..., 1, r)
    second_x = split_x[..., 1:, :]  # (..., 1, r)

    cos_u = cos.unsqueeze(-2)  # broadcast to (..., 1, r) against (..., 2, r)
    sin_u = sin.unsqueeze(-2)
    cos_u = cos.to(x.dtype).unsqueeze(-2)  # broadcast to (..., 1, r) against (..., 2, r)
    sin_u = sin.to(x.dtype).unsqueeze(-2)

    out = split_x * cos_u
    first_out = out[..., :1, :]
@@ -80,7 +80,6 @@ def apply_split_rotary_emb(x: torch.Tensor, freqs: tuple[torch.Tensor, torch.Ten
    if needs_reshape:
        out = out.swapaxes(1, 2).reshape(b, t, -1)

    out = out.to(dtype=x_dtype)
    return out
@@ -1492,7 +1491,9 @@ class LTX2VideoTransformer3DModel(
        temb_prompt = temb_prompt_audio = None

        # 3.2. Prepare global modality cross attention modulation parameters
        video_ca_timestep = audio_sigma.flatten() if use_cross_timestep else timestep.flatten()
        # Reference always uses cross-modality sigma for cross-attention timestep:
        # video cross-attn uses audio_sigma, audio cross-attn uses sigma (video sigma).
        video_ca_timestep = audio_sigma.flatten()
        video_cross_attn_scale_shift, _ = self.av_cross_attn_video_scale_shift(
            video_ca_timestep,
            batch_size=batch_size,
@@ -1508,7 +1509,7 @@
        )
        video_cross_attn_a2v_gate = video_cross_attn_a2v_gate.view(batch_size, -1, video_cross_attn_a2v_gate.shape[-1])

        audio_ca_timestep = sigma.flatten() if use_cross_timestep else audio_timestep.flatten()
        audio_ca_timestep = sigma.flatten()
        audio_cross_attn_scale_shift, _ = self.av_cross_attn_audio_scale_shift(
            audio_ca_timestep,
            batch_size=batch_size,
@@ -70,6 +70,11 @@ else:
        "FluxKontextAutoBlocks",
        "FluxKontextModularPipeline",
    ]
    _import_structure["ltx2"] = [
        "LTX2AutoBlocks",
        "LTX2Blocks",
        "LTX2ModularPipeline",
    ]
    _import_structure["flux2"] = [
        "Flux2AutoBlocks",
        "Flux2KleinAutoBlocks",
@@ -103,6 +108,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
    from .components_manager import ComponentsManager
    from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
    from .ltx2 import LTX2AutoBlocks, LTX2Blocks, LTX2ModularPipeline
    from .flux2 import (
        Flux2AutoBlocks,
        Flux2KleinAutoBlocks,
src/diffusers/modular_pipelines/ltx2/__init__.py (new file, 52 lines)
@@ -0,0 +1,52 @@
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


_dummy_objects = {}
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modular_blocks_ltx2"] = ["LTX2Blocks", "LTX2AutoBlocks", "LTX2Stage1Blocks", "LTX2Stage2Blocks", "LTX2FullPipelineBlocks"]
    _import_structure["modular_blocks_ltx2_upsample"] = ["LTX2UpsampleBlocks", "LTX2UpsampleCoreBlocks"]
    _import_structure["modular_pipeline"] = [
        "LTX2ModularPipeline",
        "LTX2UpsampleModularPipeline",
    ]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .modular_blocks_ltx2 import LTX2AutoBlocks, LTX2Blocks, LTX2FullPipelineBlocks, LTX2Stage1Blocks, LTX2Stage2Blocks
        from .modular_blocks_ltx2_upsample import LTX2UpsampleBlocks, LTX2UpsampleCoreBlocks
        from .modular_pipeline import LTX2ModularPipeline, LTX2UpsampleModularPipeline
else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
src/diffusers/modular_pipelines/ltx2/_checkpoint_utils.py (new file, 27 lines)
@@ -0,0 +1,27 @@
"""Checkpoint utilities for parity debugging. No effect when _checkpoints is None."""
from dataclasses import dataclass, field

import torch


@dataclass
class Checkpoint:
    save: bool = False
    stop: bool = False
    load: bool = False
    data: dict = field(default_factory=dict)


def _maybe_checkpoint(checkpoints, name, data):
    if not checkpoints:
        return
    ckpt = checkpoints.get(name)
    if ckpt is None:
        return
    if ckpt.save:
        ckpt.data.update({
            k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
            for k, v in data.items()
        })
    if ckpt.stop:
        raise StopIteration(name)
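Usage, as a sketch: a pipeline calls `_maybe_checkpoint` at named locations, and the test harness catches `StopIteration` to grab the captured tensors. The two helpers are re-defined here so the example is self-contained, and `fake_pipeline` is a hypothetical stand-in for a real pipeline `__call__`:

```python
from dataclasses import dataclass, field

import torch


@dataclass
class Checkpoint:
    save: bool = False
    stop: bool = False
    load: bool = False
    data: dict = field(default_factory=dict)


def _maybe_checkpoint(checkpoints, name, data):
    if not checkpoints:
        return
    ckpt = checkpoints.get(name)
    if ckpt is None:
        return
    if ckpt.save:
        ckpt.data.update({k: v.cpu().clone() if isinstance(v, torch.Tensor) else v
                          for k, v in data.items()})
    if ckpt.stop:
        raise StopIteration(name)


def fake_pipeline(checkpoints=None):
    latents = torch.randn(1, 4, generator=torch.Generator().manual_seed(0))
    _maybe_checkpoint(checkpoints, "preloop", {"latents": latents})
    return latents * 2  # stand-in for the denoising loop + decode


# Capture-and-stop at "preloop" -- the harness side of an encode stage test:
ckpts = {"preloop": Checkpoint(save=True, stop=True)}
try:
    fake_pipeline(ckpts)
except StopIteration:
    pass
preloop_latents = ckpts["preloop"].data["latents"]
```

With `checkpoints=None` the pipeline runs untouched, which is what makes the mechanism safe to leave in during debugging.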
src/diffusers/modular_pipelines/ltx2/before_denoise.py (new file, 657 lines)
@@ -0,0 +1,657 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
|
||||
from ...models.transformers import LTX2VideoTransformer3DModel
|
||||
from ...pipelines.ltx2.connectors import LTX2TextConnectors
|
||||
from ...pipelines.ltx2.utils import STAGE_2_DISTILLED_SIGMA_VALUES
|
||||
from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: int | None = None,
|
||||
device: str | torch.device | None = None,
|
||||
timesteps: list[int] | None = None,
|
||||
sigmas: list[float] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed.")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = latents.shape
|
||||
post_patch_num_frames = num_frames // patch_size_t
|
||||
post_patch_height = height // patch_size
|
||||
post_patch_width = width // patch_size
|
||||
latents = latents.reshape(
|
||||
batch_size, -1, post_patch_num_frames, patch_size_t, post_patch_height, patch_size, post_patch_width, patch_size
|
||||
)
|
||||
latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
|
||||
return latents
|
||||
|
||||
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
|
||||
def _pack_audio_latents(
    latents: torch.Tensor, patch_size: int | None = None, patch_size_t: int | None = None
) -> torch.Tensor:
    if patch_size is not None and patch_size_t is not None:
        batch_size, num_channels, latent_length, latent_mel_bins = latents.shape
        # Integer division: reshape requires integer dimension sizes
        post_patch_latent_length = latent_length // patch_size_t
        post_patch_mel_bins = latent_mel_bins // patch_size
        latents = latents.reshape(
            batch_size, -1, post_patch_latent_length, patch_size_t, post_patch_mel_bins, patch_size
        )
        latents = latents.permute(0, 2, 4, 1, 3, 5).flatten(3, 5).flatten(1, 2)
    else:
        latents = latents.transpose(1, 2).flatten(2, 3)
    return latents

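The unpatched fallback of `_pack_audio_latents` is a pure transpose-and-flatten, inverted token-for-token by the matching branch of `_unpack_audio_latents` in `decoders.py`. A sketch with toy shapes:

```python
import torch

# [B, C, T, M] -> [B, T, C*M]: mel bins and channels fused into the feature axis
audio = torch.randn(2, 8, 10, 4)  # batch, channels, latent length, latent mel bins
packed = audio.transpose(1, 2).flatten(2, 3)

# Inverse (the no-patch branch of _unpack_audio_latents)
unpacked = packed.unflatten(2, (-1, 4)).transpose(1, 2)
```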
def _normalize_audio_latents(latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor):
    latents_mean = latents_mean.to(latents.device, latents.dtype)
    latents_std = latents_std.to(latents.device, latents.dtype)
    return (latents - latents_mean) / latents_std

class LTX2InputStep(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return (
            "Input processing step that determines batch_size and dtype, "
            "and expands embeddings for num_videos_per_prompt"
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", LTX2VideoTransformer3DModel),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("num_videos_per_prompt", default=1),
            InputParam("connector_prompt_embeds", required=True, type_hint=torch.Tensor),
            InputParam("connector_audio_prompt_embeds", required=True, type_hint=torch.Tensor),
            InputParam("connector_attention_mask", required=True, type_hint=torch.Tensor),
            InputParam("connector_negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam("connector_audio_negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam("connector_negative_attention_mask", type_hint=torch.Tensor),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("batch_size", type_hint=int),
            OutputParam("dtype", type_hint=torch.dtype),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        block_state.batch_size = block_state.connector_prompt_embeds.shape[0]
        block_state.dtype = components.transformer.dtype

        self.set_block_state(state, block_state)
        return components, state

class LTX2SetTimestepsStep(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Step that sets up the scheduler timesteps for both video and audio denoising"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("vae", AutoencoderKLLTX2Video),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("num_inference_steps", default=40),
            InputParam("timesteps_input"),
            InputParam("sigmas"),
            InputParam("height", default=512, type_hint=int),
            InputParam("width", default=768, type_hint=int),
            InputParam("num_frames", default=121, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("timesteps", type_hint=torch.Tensor),
            OutputParam("num_inference_steps", type_hint=int),
            OutputParam("audio_scheduler"),
        ]
    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        num_inference_steps = block_state.num_inference_steps
        sigmas = block_state.sigmas
        timesteps_input = block_state.timesteps_input

        vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
        vae_temporal_compression_ratio = components.vae.temporal_compression_ratio
        height = block_state.height
        width = block_state.width
        num_frames = block_state.num_frames
        latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
        latent_height = height // vae_spatial_compression_ratio
        latent_width = width // vae_spatial_compression_ratio
        video_sequence_length = latent_num_frames * latent_height * latent_width

        if sigmas is None:
            # Use torch.linspace (float32) to match reference scheduler precision.
            # np.linspace computes in float64 then casts to float32, which produces
            # values that differ by 1 ULP from torch's native float32 computation.
            sigmas = torch.linspace(1.0, 0.0, num_inference_steps + 1)[:-1].numpy()

        mu = calculate_shift(
            video_sequence_length,
            components.scheduler.config.get("base_image_seq_len", 1024),
            components.scheduler.config.get("max_image_seq_len", 4096),
            components.scheduler.config.get("base_shift", 0.95),
            components.scheduler.config.get("max_shift", 2.05),
        )

        audio_scheduler = copy.deepcopy(components.scheduler)
        _, _ = retrieve_timesteps(
            audio_scheduler, num_inference_steps, device, timesteps_input, sigmas=sigmas, mu=mu
        )
        timesteps, num_inference_steps = retrieve_timesteps(
            components.scheduler, num_inference_steps, device, timesteps_input, sigmas=sigmas, mu=mu
        )

        block_state.timesteps = timesteps
        block_state.num_inference_steps = num_inference_steps
        block_state.audio_scheduler = audio_scheduler

        self.set_block_state(state, block_state)
        return components, state

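For context, `calculate_shift` in diffusers linearly interpolates the flow-matching shift `mu` between `(base_image_seq_len, base_shift)` and `(max_image_seq_len, max_shift)`, so with this block's defaults a 1024-token sequence gets `mu = 0.95` and a 4096-token one gets `mu = 2.05`. A sketch of that interpolation (a local re-implementation for illustration; the real helper lives in diffusers and its exact signature may differ):

```python
def shift_mu(seq_len: int, base_seq_len: int = 1024, max_seq_len: int = 4096,
             base_shift: float = 0.95, max_shift: float = 2.05) -> float:
    # Line through (base_seq_len, base_shift) and (max_seq_len, max_shift)
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return seq_len * m + b


mu_small = shift_mu(1024)
mu_large = shift_mu(4096)
```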
class LTX2PrepareLatentsStep(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Prepare video latents, optionally applying conditioning mask for I2V/conditional generation"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLLTX2Video),
            ComponentSpec("transformer", LTX2VideoTransformer3DModel),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("height", default=512, type_hint=int),
            InputParam("width", default=768, type_hint=int),
            InputParam("num_frames", default=121, type_hint=int),
            InputParam("noise_scale", default=1.0, type_hint=float),
            InputParam("latents", type_hint=torch.Tensor),
            InputParam("generator"),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("num_videos_per_prompt", default=1, type_hint=int),
            InputParam("condition_latents", type_hint=list),
            InputParam("condition_strengths", type_hint=list),
            InputParam("condition_indices", type_hint=list),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("latents", type_hint=torch.Tensor),
            OutputParam("conditioning_mask", type_hint=torch.Tensor),
            OutputParam("clean_latents", type_hint=torch.Tensor),
            OutputParam("latent_num_frames", type_hint=int),
            OutputParam("latent_height", type_hint=int),
            OutputParam("latent_width", type_hint=int),
            OutputParam("video_sequence_length", type_hint=int),
            OutputParam("transformer_spatial_patch_size", type_hint=int),
            OutputParam("transformer_temporal_patch_size", type_hint=int),
        ]
    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        height = block_state.height
        width = block_state.width
        num_frames = block_state.num_frames
        noise_scale = block_state.noise_scale
        generator = block_state.generator
        batch_size = block_state.batch_size * block_state.num_videos_per_prompt

        vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
        vae_temporal_compression_ratio = components.vae.temporal_compression_ratio
        transformer_spatial_patch_size = components.transformer.config.patch_size
        transformer_temporal_patch_size = components.transformer.config.patch_size_t
        num_channels_latents = components.transformer.config.in_channels

        latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
        latent_height = height // vae_spatial_compression_ratio
        latent_width = width // vae_spatial_compression_ratio

        condition_latents = getattr(block_state, "condition_latents", None) or []
        condition_strengths = getattr(block_state, "condition_strengths", None) or []
        condition_indices = getattr(block_state, "condition_indices", None) or []
        has_conditions = len(condition_latents) > 0

        if block_state.latents is not None:
            latents = block_state.latents
            if latents.ndim == 5:
                latents = _normalize_latents(
                    latents, components.vae.latents_mean, components.vae.latents_std, components.vae.config.scaling_factor
                )
                _, _, latent_num_frames, latent_height, latent_width = latents.shape
                latents = _pack_latents(latents, transformer_spatial_patch_size, transformer_temporal_patch_size)
        else:
            # Reference: create zeros in [B,C,F,H,W] in model dtype, pack to [B,S,D],
            # then generate noise in packed shape with same dtype
            latent_dtype = components.transformer.dtype
            shape = (batch_size, num_channels_latents, latent_num_frames, latent_height, latent_width)
            latents = torch.zeros(shape, device=device, dtype=latent_dtype)
            latents = _pack_latents(latents, transformer_spatial_patch_size, transformer_temporal_patch_size)

        conditioning_mask = None
        clean_latents = None

        if has_conditions:
            mask_shape = (batch_size, 1, latent_num_frames, latent_height, latent_width)
            conditioning_mask = torch.zeros(mask_shape, device=device, dtype=torch.float32)
            conditioning_mask = _pack_latents(
                conditioning_mask, transformer_spatial_patch_size, transformer_temporal_patch_size
            )

            clean_latents = torch.zeros_like(latents)
            for cond, strength, latent_idx in zip(condition_latents, condition_strengths, condition_indices):
                num_cond_tokens = cond.size(1)
                start_token_idx = latent_idx * latent_height * latent_width
                end_token_idx = start_token_idx + num_cond_tokens

                latents[:, start_token_idx:end_token_idx] = cond
                conditioning_mask[:, start_token_idx:end_token_idx] = strength
                clean_latents[:, start_token_idx:end_token_idx] = cond

            if isinstance(generator, list):
                generator = generator[0]

            # Noise in packed [B,S,D] shape and same dtype as latent (matches reference GaussianNoiser)
            noise = randn_tensor(latents.shape, generator=generator, device=latents.device, dtype=latents.dtype)
            scaled_mask = (1.0 - conditioning_mask) * noise_scale
            latents = noise * scaled_mask + latents * (1 - scaled_mask)
        else:
            # T2V: noise in packed shape, same dtype as latent
            if isinstance(generator, list):
                generator = generator[0]
            noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
            scaled_mask = noise_scale
            latents = noise * scaled_mask + latents * (1 - scaled_mask)

        block_state.latents = latents
        block_state.conditioning_mask = conditioning_mask
        block_state.clean_latents = clean_latents
        block_state.latent_num_frames = latent_num_frames
        block_state.latent_height = latent_height
        block_state.latent_width = latent_width
        block_state.video_sequence_length = latent_num_frames * latent_height * latent_width
        block_state.transformer_spatial_patch_size = transformer_spatial_patch_size
        block_state.transformer_temporal_patch_size = transformer_temporal_patch_size

        self.set_block_state(state, block_state)
        return components, state

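The conditioning loop above relies on the packed token order being frame-major: with patch sizes of 1, the tokens of latent frame `f` occupy the contiguous slice `[f·H'·W', (f+1)·H'·W')`. A small index-arithmetic sketch with toy dimensions:

```python
import torch

latent_height, latent_width = 4, 6
latent_num_frames = 5
latent_idx = 2                                  # condition on latent frame 2
num_cond_tokens = latent_height * latent_width  # one full latent frame of tokens

start_token_idx = latent_idx * latent_height * latent_width
end_token_idx = start_token_idx + num_cond_tokens

latents = torch.zeros(1, latent_num_frames * latent_height * latent_width, 8)
cond = torch.ones(1, num_cond_tokens, 8)
latents[:, start_token_idx:end_token_idx] = cond  # only frame 2's tokens are written
```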
class LTX2PrepareAudioLatentsStep(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Prepare audio latents for the denoising process"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("audio_vae", AutoencoderKLLTX2Audio),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("num_frames", default=121, type_hint=int),
            InputParam("frame_rate", default=24.0, type_hint=float),
            InputParam("noise_scale", default=1.0, type_hint=float),
            InputParam("audio_latents", type_hint=torch.Tensor),
            InputParam("generator"),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("num_videos_per_prompt", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("audio_latents", type_hint=torch.Tensor),
            OutputParam("audio_num_frames", type_hint=int),
            OutputParam("latent_mel_bins", type_hint=int),
        ]
    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        num_frames = block_state.num_frames
        frame_rate = block_state.frame_rate
        noise_scale = block_state.noise_scale
        generator = block_state.generator
        batch_size = block_state.batch_size * block_state.num_videos_per_prompt

        audio_sampling_rate = components.audio_vae.config.sample_rate
        audio_hop_length = components.audio_vae.config.mel_hop_length
        audio_vae_temporal_compression_ratio = components.audio_vae.temporal_compression_ratio
        audio_vae_mel_compression_ratio = components.audio_vae.mel_compression_ratio

        duration_s = num_frames / frame_rate
        audio_latents_per_second = audio_sampling_rate / audio_hop_length / float(audio_vae_temporal_compression_ratio)
        audio_num_frames = round(duration_s * audio_latents_per_second)

        num_mel_bins = components.audio_vae.config.mel_bins
        latent_mel_bins = num_mel_bins // audio_vae_mel_compression_ratio
        num_channels_latents_audio = components.audio_vae.config.latent_channels

        if block_state.audio_latents is not None:
            audio_latents = block_state.audio_latents
            if audio_latents.ndim == 4:
                _, _, audio_num_frames, _ = audio_latents.shape
                audio_latents = _pack_audio_latents(audio_latents)
                audio_latents = _normalize_audio_latents(
                    audio_latents, components.audio_vae.latents_mean, components.audio_vae.latents_std
                )
                if noise_scale > 0.0:
                    noise = randn_tensor(
                        audio_latents.shape, generator=generator, device=audio_latents.device, dtype=audio_latents.dtype
                    )
                    audio_latents = noise_scale * noise + (1 - noise_scale) * audio_latents
            elif audio_latents.ndim == 3 and noise_scale > 0.0:
                noise = randn_tensor(
                    audio_latents.shape, generator=generator, device=audio_latents.device, dtype=audio_latents.dtype
                )
                audio_latents = noise_scale * noise + (1 - noise_scale) * audio_latents
            audio_latents = audio_latents.to(device=device, dtype=torch.float32)
        else:
            # Reference: create zeros in [B,C,T,M] in model dtype, pack, then noise in packed shape
            latent_dtype = components.audio_vae.dtype
            shape = (batch_size, num_channels_latents_audio, audio_num_frames, latent_mel_bins)
            audio_latents = torch.zeros(shape, device=device, dtype=latent_dtype)
            audio_latents = _pack_audio_latents(audio_latents)
            if isinstance(generator, list):
                generator = generator[0]
            noise = randn_tensor(audio_latents.shape, generator=generator, device=device, dtype=latent_dtype)
            audio_latents = noise * noise_scale + audio_latents * (1 - noise_scale)

        block_state.audio_latents = audio_latents
        block_state.audio_num_frames = audio_num_frames
        block_state.latent_mel_bins = latent_mel_bins

        self.set_block_state(state, block_state)
        return components, state

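The audio frame count above follows from the video duration and the audio VAE's mel/temporal geometry: `audio_num_frames = round((num_frames / fps) * sample_rate / hop_length / temporal_compression)`. A worked example with hypothetical config values (the real ones come from `audio_vae.config`):

```python
num_frames, frame_rate = 121, 24.0
sample_rate, mel_hop_length, temporal_compression = 44100, 512, 4  # illustrative values only

duration_s = num_frames / frame_rate  # ~5.04 s of video
audio_latents_per_second = sample_rate / mel_hop_length / float(temporal_compression)
audio_num_frames = round(duration_s * audio_latents_per_second)
```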
class LTX2PrepareCoordinatesStep(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Prepare video and audio RoPE coordinates for positional encoding"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", LTX2VideoTransformer3DModel),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
            InputParam("audio_latents", required=True, type_hint=torch.Tensor),
            InputParam("latent_num_frames", required=True, type_hint=int),
            InputParam("latent_height", required=True, type_hint=int),
            InputParam("latent_width", required=True, type_hint=int),
            InputParam("audio_num_frames", required=True, type_hint=int),
            InputParam("frame_rate", default=24.0, type_hint=float),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("video_coords", type_hint=torch.Tensor),
            OutputParam("audio_coords", type_hint=torch.Tensor),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        latents = block_state.latents
        audio_latents = block_state.audio_latents
        frame_rate = block_state.frame_rate

        video_coords = components.transformer.rope.prepare_video_coords(
            latents.shape[0],
            block_state.latent_num_frames,
            block_state.latent_height,
            block_state.latent_width,
            latents.device,
            fps=frame_rate,
        )
        # Cast to latent dtype to match reference (positions stored in model dtype)
        video_coords = video_coords.to(latents.dtype)
        audio_coords = components.transformer.audio_rope.prepare_audio_coords(
            audio_latents.shape[0], block_state.audio_num_frames, audio_latents.device
        )
        # Note: audio_coords already match reference dtype, no cast needed

        block_state.video_coords = video_coords
        block_state.audio_coords = audio_coords

        self.set_block_state(state, block_state)
        return components, state

class LTX2Stage2SetTimestepsStep(LTX2SetTimestepsStep):
    """SetTimesteps for Stage 2: fixed distilled sigmas, no dynamic shifting."""

    @property
    def description(self) -> str:
        return "Stage 2 timestep setup: uses fixed distilled sigmas with dynamic shifting disabled"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("num_inference_steps", default=3),
            InputParam("timesteps_input"),
            InputParam("sigmas", default=list(STAGE_2_DISTILLED_SIGMA_VALUES)),
            InputParam("height", default=512, type_hint=int),
            InputParam("width", default=768, type_hint=int),
            InputParam("num_frames", default=121, type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        components.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
            components.scheduler.config,
            use_dynamic_shifting=False,
            shift_terminal=None,
        )
        return super().__call__(components, state)

class LTX2Stage2PrepareLatentsStep(LTX2PrepareLatentsStep):
    """PrepareLatents for Stage 2: noise_scale defaults to first distilled sigma value."""

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("height", default=512, type_hint=int),
            InputParam("width", default=768, type_hint=int),
            InputParam("num_frames", default=121, type_hint=int),
            InputParam("noise_scale", default=STAGE_2_DISTILLED_SIGMA_VALUES[0], type_hint=float),
            InputParam("latents", type_hint=torch.Tensor),
            InputParam("generator"),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("num_videos_per_prompt", default=1, type_hint=int),
            InputParam("condition_latents", type_hint=list),
            InputParam("condition_strengths", type_hint=list),
            InputParam("condition_indices", type_hint=list),
        ]

class LTX2DisableAdapterStep(ModularPipelineBlocks):
    """Disables LoRA adapters on transformer and connectors. No-op if no adapters are loaded."""

    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Disable LoRA adapters before stage 1 (no-op if no adapters loaded)"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", LTX2VideoTransformer3DModel),
            ComponentSpec("connectors", LTX2TextConnectors),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return []

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        for model in [components.transformer, components.connectors]:
            if getattr(model, "_hf_peft_config_loaded", False):
                model.disable_adapters()
        self.set_block_state(state, block_state)
        return components, state

class LTX2EnableAdapterStep(ModularPipelineBlocks):
    """Enables LoRA adapters by name before stage 2. No-op if stage_2_adapter is not provided."""

    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Enable LoRA adapters before stage 2 (no-op if stage_2_adapter not provided)"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", LTX2VideoTransformer3DModel),
            ComponentSpec("connectors", LTX2TextConnectors),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("stage_2_adapter", type_hint=str, description="Name of the LoRA adapter to enable for stage 2"),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        adapter_name = block_state.stage_2_adapter
        if adapter_name is not None:
            for model in [components.transformer, components.connectors]:
                if getattr(model, "_hf_peft_config_loaded", False):
                    model.enable_adapters()
        self.set_block_state(state, block_state)
        return components, state
228	src/diffusers/modular_pipelines/ltx2/decoders.py	Normal file
@@ -0,0 +1,228 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ...configuration_utils import FrozenDict
from ...models.autoencoders import AutoencoderKLLTX2Audio, AutoencoderKLLTX2Video
from ...pipelines.ltx2.vocoder import LTX2Vocoder
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)

def _unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return latents


def _denormalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / scaling_factor + latents_mean
    return latents

def _unpack_audio_latents(
    latents: torch.Tensor,
    latent_length: int,
    num_mel_bins: int,
    patch_size: int | None = None,
    patch_size_t: int | None = None,
) -> torch.Tensor:
    if patch_size is not None and patch_size_t is not None:
        batch_size = latents.size(0)
        latents = latents.reshape(batch_size, latent_length, num_mel_bins, -1, patch_size_t, patch_size)
        latents = latents.permute(0, 3, 1, 4, 2, 5).flatten(4, 5).flatten(2, 3)
    else:
        latents = latents.unflatten(2, (-1, num_mel_bins)).transpose(1, 2)
    return latents


def _denormalize_audio_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor
) -> torch.Tensor:
    latents_mean = latents_mean.to(latents.device, latents.dtype)
    latents_std = latents_std.to(latents.device, latents.dtype)
    return (latents * latents_std) + latents_mean

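`_unpack_latents` is the exact inverse of the packing helper earlier in this diff, recovering every element of the original `[B, C, F, H, W]` tensor. A standalone round-trip check (local re-derivations of both sides, for illustration only):

```python
import torch


def pack(latents, p=1, pt=1):
    # [B, C, F, H, W] -> [B, S, D], as in _pack_latents
    b, c, f, h, w = latents.shape
    latents = latents.reshape(b, -1, f // pt, pt, h // p, p, w // p, p)
    return latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)


def unpack(latents, f, h, w, p=1, pt=1):
    # f, h, w are the post-patch grid sizes, as in _unpack_latents
    b = latents.size(0)
    latents = latents.reshape(b, f, h, w, -1, pt, p, p)
    return latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)


x = torch.randn(1, 16, 4, 6, 8)
roundtrip = unpack(pack(x), 4, 6, 8)                  # patch sizes of 1
roundtrip_p2 = unpack(pack(x, p=2), 4, 3, 4, p=2)     # spatial patch size 2
```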
class LTX2VideoDecoderStep(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Step that decodes the denoised video latents into video frames"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLLTX2Video),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor, description="Denoised video latents"),
            InputParam("output_type", default="np", type_hint=str, description="Output format: pil, np, pt, latent"),
            InputParam("decode_timestep", default=0.0, description="Timestep for VAE decode conditioning"),
            InputParam("decode_noise_scale", default=None, description="Noise scale for decode conditioning"),
            InputParam("generator", description="Random generator for reproducibility"),
            InputParam("latent_num_frames", required=True, type_hint=int),
            InputParam("latent_height", required=True, type_hint=int),
            InputParam("latent_width", required=True, type_hint=int),
            InputParam("batch_size", required=True, type_hint=int),
            InputParam("dtype", required=True, type_hint=torch.dtype),
            InputParam("transformer_spatial_patch_size", default=1, type_hint=int),
            InputParam("transformer_temporal_patch_size", default=1, type_hint=int),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("videos", description="The decoded video frames"),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        latents = block_state.latents

        # Unpack latents from [B, S, D] -> [B, C, F, H, W]
        # Uses the transformer's patchify sizes (not the VAE's internal patch_size)
        latents = _unpack_latents(
            latents,
            block_state.latent_num_frames,
            block_state.latent_height,
            block_state.latent_width,
            block_state.transformer_spatial_patch_size,
            block_state.transformer_temporal_patch_size,
        )
        # Denormalize
        latents = _denormalize_latents(
            latents, components.vae.latents_mean, components.vae.latents_std, components.vae.config.scaling_factor
        )

        if block_state.output_type == "latent":
            block_state.videos = latents
        else:
            latents = latents.to(block_state.dtype)
            device = latents.device

            if not components.vae.config.timestep_conditioning:
                timestep = None
            else:
                noise = randn_tensor(
                    latents.shape, generator=block_state.generator, device=device, dtype=latents.dtype
                )
                decode_timestep = block_state.decode_timestep
                decode_noise_scale = block_state.decode_noise_scale
                batch_size = block_state.batch_size

                if not isinstance(decode_timestep, list):
                    decode_timestep = [decode_timestep] * batch_size
                if decode_noise_scale is None:
                    decode_noise_scale = decode_timestep
                elif not isinstance(decode_noise_scale, list):
                    decode_noise_scale = [decode_noise_scale] * batch_size

                timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
                decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
                    :, None, None, None, None
                ]
                latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise

            latents = latents.to(components.vae.dtype)
            video = components.vae.decode(latents, timestep, return_dict=False)[0]
            block_state.videos = components.video_processor.postprocess_video(
                video, output_type=block_state.output_type
            )

        self.set_block_state(state, block_state)
        return components, state

class LTX2AudioDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Step that decodes the denoised audio latents into audio waveforms"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("audio_vae", AutoencoderKLLTX2Audio),
|
||||
ComponentSpec("vocoder", LTX2Vocoder),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("audio_latents", required=True, type_hint=torch.Tensor, description="Denoised audio latents"),
|
||||
InputParam("output_type", default="np", type_hint=str),
|
||||
InputParam("audio_num_frames", required=True, type_hint=int),
|
||||
InputParam("latent_mel_bins", required=True, type_hint=int),
|
||||
InputParam("dtype", required=True, type_hint=torch.dtype),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("audio", description="The decoded audio waveforms"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
audio_latents = block_state.audio_latents
|
||||
|
||||
# Denormalize audio latents
|
||||
audio_latents = _denormalize_audio_latents(
|
||||
audio_latents, components.audio_vae.latents_mean, components.audio_vae.latents_std
|
||||
)
|
||||
# Unpack audio latents
|
||||
audio_latents = _unpack_audio_latents(
|
||||
audio_latents, block_state.audio_num_frames, num_mel_bins=block_state.latent_mel_bins
|
||||
)
|
||||
|
||||
if block_state.output_type == "latent":
|
||||
block_state.audio = audio_latents
|
||||
else:
|
||||
audio_latents = audio_latents.to(components.audio_vae.dtype)
|
||||
generated_mel_spectrograms = components.audio_vae.decode(audio_latents, return_dict=False)[0]
|
||||
# Squeeze batch dim and cast to float32 to match reference's decode_audio output format
|
||||
block_state.audio = components.vocoder(generated_mel_spectrograms).squeeze(0).float()
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
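The decoder step above relies on `_denormalize_audio_latents`, a helper not shown in this excerpt. Assuming the usual diffusers convention (the inverse of `_normalize_latents`, with `scaling_factor=1`: `x = z * std + mean`, per channel), the math can be sketched in plain Python. The function name and per-channel list layout here are illustrative, not the actual tensor implementation:

```python
# Illustrative sketch only: the real helper operates on torch tensors with
# broadcasting; this per-channel version just shows the denormalization math.
def denormalize_latents(latents, latents_mean, latents_std):
    """Map normalized latents z back to VAE space: x = z * std + mean (per channel)."""
    return [
        [z * std + mean for z in channel]
        for channel, mean, std in zip(latents, latents_mean, latents_std)
    ]
```

For a single channel with mean 0.5 and std 2.0, a normalized value of 1.0 maps back to 2.5.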
490 src/diffusers/modular_pipelines/ltx2/denoise.py Normal file
@@ -0,0 +1,490 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any

import torch

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models.transformers import LTX2VideoTransformer3DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam


logger = logging.get_logger(__name__)


class LTX2LoopBeforeDenoiser(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that prepares the latent inputs for the denoiser, "
            "including timestep masking for conditioned frames."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
            InputParam("audio_latents", required=True, type_hint=torch.Tensor),
            InputParam("dtype", required=True, type_hint=torch.dtype),
            InputParam("conditioning_mask", type_hint=torch.Tensor),
        ]

    @torch.no_grad()
    def __call__(self, components, block_state: BlockState, i: int, t: torch.Tensor):
        block_state.latent_model_input = block_state.latents.to(block_state.dtype)
        block_state.audio_latent_model_input = block_state.audio_latents.to(block_state.dtype)

        batch_size = block_state.latent_model_input.shape[0]
        num_video_tokens = block_state.latent_model_input.shape[1]
        num_audio_tokens = block_state.audio_latent_model_input.shape[1]

        video_timestep = t.expand(batch_size, num_video_tokens)

        if block_state.conditioning_mask is not None:
            block_state.video_timestep = video_timestep * (
                1 - block_state.conditioning_mask.squeeze(-1)
            )
        else:
            block_state.video_timestep = video_timestep

        block_state.audio_timestep = t.expand(batch_size, num_audio_tokens)
        # Sigma for prompt_adaln: f32 to match reference's f32(sigma * scale_multiplier)
        block_state.sigma = torch.tensor([t.item()], dtype=torch.float32)

        return components, block_state


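The per-token timestep masking above zeroes the timestep wherever `conditioning_mask` is 1, so conditioned frames are presented to the model as already clean while free frames keep the full timestep. A minimal pure-Python sketch of that rule (the function name is illustrative):

```python
def mask_timesteps(t, conditioning_mask):
    """Per-token timestep: conditioned tokens (mask == 1) get timestep 0 (clean);
    unconditioned tokens (mask == 0) keep the full timestep t.
    Partial masks scale the timestep proportionally."""
    return [t * (1.0 - m) for m in conditioning_mask]
```

With `t = 0.8` and mask `[1.0, 0.0, 0.5]`, the per-token timesteps come out as `[0.0, 0.8, 0.4]`.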
class LTX2LoopDenoiser(ModularPipelineBlocks):
    model_name = "ltx2"

    def __init__(
        self,
        guider_input_fields: dict[str, Any] = None,
        guider_name: str = "guider",
        guider_config: FrozenDict = None,
    ):
        """Initialize a denoiser block for LTX2 that handles dual video+audio outputs.

        Args:
            guider_input_fields: Dictionary mapping transformer argument names to block_state field names.
                Values can be tuples (conditional, unconditional) or strings (same for both).
            guider_name: Name of the guider component to use (default: "guider").
            guider_config: Config for the guider component (default: guidance_scale=4.0).
        """
        self._guider_name = guider_name
        if guider_config is None:
            guider_config = FrozenDict({"guidance_scale": 4.0})
        self._guider_config = guider_config
        if guider_input_fields is None:
            guider_input_fields = {
                "encoder_hidden_states": ("connector_prompt_embeds", "connector_negative_prompt_embeds"),
                "audio_encoder_hidden_states": (
                    "connector_audio_prompt_embeds",
                    "connector_audio_negative_prompt_embeds",
                ),
                "encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
                "audio_encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
            }
        if not isinstance(guider_input_fields, dict):
            raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
        self._guider_input_fields = guider_input_fields
        super().__init__()

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec(
                self._guider_name,
                ClassifierFreeGuidance,
                config=self._guider_config,
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", LTX2VideoTransformer3DModel),
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that runs the transformer with guidance "
            "and handles dual video+audio output splitting. CFG is applied in x0 space "
            "to match the reference implementation."
        )

    @property
    def inputs(self) -> list[InputParam]:
        inputs = [
            InputParam("attention_kwargs"),
            InputParam("num_inference_steps", required=True, type_hint=int),
            InputParam("latent_num_frames", required=True, type_hint=int),
            InputParam("latent_height", required=True, type_hint=int),
            InputParam("latent_width", required=True, type_hint=int),
            InputParam("audio_num_frames", required=True, type_hint=int),
            InputParam("frame_rate", default=24.0, type_hint=float),
            InputParam("video_coords", required=True, type_hint=torch.Tensor),
            InputParam("audio_coords", required=True, type_hint=torch.Tensor),
            InputParam("guidance_rescale", default=0.0, type_hint=float),
            InputParam("sigma", type_hint=torch.Tensor),
        ]
        guider_input_names = []
        for value in self._guider_input_fields.values():
            if isinstance(value, tuple):
                guider_input_names.extend(value)
            else:
                guider_input_names.append(value)

        for name in set(guider_input_names):
            inputs.append(InputParam(name=name, type_hint=torch.Tensor))
        return inputs

    @staticmethod
    def _convert_velocity_to_x0(sample, velocity, sigma):
        return sample - velocity * sigma

    @staticmethod
    def _convert_x0_to_velocity(sample, x0, sigma):
        return (sample - x0) / sigma

    @torch.no_grad()
    def __call__(self, components, block_state: BlockState, i: int, t: torch.Tensor):
        guider = getattr(components, self._guider_name)
        guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        guider_state = guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)

        use_cross_timestep = getattr(components.transformer.config, "use_cross_timestep", False)
        sigma_val = components.scheduler.sigmas[i]

        # Pass raw sigma to wrapper if available (avoids timestep/1000 round-trip precision loss)
        if hasattr(components.transformer, "_raw_sigma"):
            components.transformer._raw_sigma = sigma_val

        for guider_state_batch in guider_state:
            guider.prepare_models(components.transformer)
            cond_kwargs = guider_state_batch.as_dict()
            cond_kwargs = {
                k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
                for k, v in cond_kwargs.items()
                if k in self._guider_input_fields.keys()
            }
            # Drop all-ones attention masks — they're functionally no-op but trigger
            # a different SDPA kernel path (masked vs unmasked) with different bf16 rounding.
            # Reference passes context_mask=None for unmasked attention.
            for mask_key in ["encoder_attention_mask", "audio_encoder_attention_mask"]:
                mask = cond_kwargs.get(mask_key)
                if mask is not None and mask.ndim <= 2 and (mask == 1).all():
                    cond_kwargs[mask_key] = None

            video_timestep = block_state.video_timestep
            audio_timestep = block_state.audio_timestep

            with components.transformer.cache_context("cond_uncond"):
                noise_pred_video, noise_pred_audio = components.transformer(
                    hidden_states=block_state.latent_model_input.to(block_state.dtype),
                    audio_hidden_states=block_state.audio_latent_model_input.to(block_state.dtype),
                    timestep=video_timestep,
                    audio_timestep=audio_timestep,
                    sigma=block_state.sigma,
                    num_frames=block_state.latent_num_frames,
                    height=block_state.latent_height,
                    width=block_state.latent_width,
                    fps=block_state.frame_rate,
                    audio_num_frames=block_state.audio_num_frames,
                    video_coords=block_state.video_coords,
                    audio_coords=block_state.audio_coords,
                    use_cross_timestep=use_cross_timestep,
                    attention_kwargs=block_state.attention_kwargs,
                    return_dict=False,
                    **cond_kwargs,
                )

            # Convert to x0 for guidance.
            prediction_type = getattr(components.transformer, "prediction_type", "velocity")
            if prediction_type == "x0":
                # Model already outputs x0 — no conversion needed
                x0_video = noise_pred_video
                x0_audio = noise_pred_audio
            else:
                # Model outputs velocity — convert to x0 matching reference's to_denoised:
                # (sample.f32 - velocity.f32 * sigma_f32).to(sample.dtype)
                # Reference uses f32 sigma (from denoise_mask * sigma, both f32).
                x0_video = self._convert_velocity_to_x0(
                    block_state.latents.float(), noise_pred_video.float(), sigma_val
                ).to(block_state.latents.dtype)
                x0_audio = self._convert_velocity_to_x0(
                    block_state.audio_latents.float(), noise_pred_audio.float(), sigma_val
                ).to(block_state.audio_latents.dtype)

            guider_state_batch.noise_pred = x0_video
            guider_state_batch.noise_pred_audio = x0_audio

            # Sub-step checkpoint: save/load x0 per condition
            _ckpts = getattr(block_state, "_checkpoints", None)
            if _ckpts:
                from diffusers.modular_pipelines.ltx2._checkpoint_utils import _maybe_checkpoint

                cond_label = "cond" if guider_state_batch is guider_state[0] else "uncond"
                _maybe_checkpoint(_ckpts, f"step_{i}_{cond_label}_x0", {
                    "video": x0_video, "audio": x0_audio,
                })
                # Load support: inject reference x0 for this condition
                ckpt = _ckpts.get(f"step_{i}_{cond_label}_x0")
                if ckpt is not None and ckpt.load:
                    x0_video = ckpt.data["video"].to(x0_video)
                    x0_audio = ckpt.data["audio"].to(x0_audio)
                    guider_state_batch.noise_pred = x0_video
                    guider_state_batch.noise_pred_audio = x0_audio

        guider.cleanup_models(components.transformer)

        # Apply guidance in x0 space using reference formula:
        #     cond + (scale - 1) * (cond - uncond)
        # This is mathematically equivalent to uncond + scale * (cond - uncond)
        # but produces different bf16 rounding.
        if len(guider_state) == 2:
            guidance_scale = guider.guidance_scale
            x0_video_cond = guider_state[0].noise_pred
            x0_video_uncond = guider_state[1].noise_pred
            guided_x0_video = x0_video_cond + (guidance_scale - 1) * (x0_video_cond - x0_video_uncond)

            x0_audio_cond = guider_state[0].noise_pred_audio
            x0_audio_uncond = guider_state[1].noise_pred_audio
            guided_x0_audio = x0_audio_cond + (guidance_scale - 1) * (x0_audio_cond - x0_audio_uncond)

            if block_state.guidance_rescale > 0:
                guided_x0_video = self._rescale_noise_cfg(
                    guided_x0_video,
                    x0_video_cond,
                    block_state.guidance_rescale,
                )
                guided_x0_audio = self._rescale_noise_cfg(
                    guided_x0_audio,
                    x0_audio_cond,
                    block_state.guidance_rescale,
                )
        else:
            guided_x0_video = guider_state[0].noise_pred
            guided_x0_audio = guider_state[0].noise_pred_audio

        # Sub-step checkpoint: save/load guided x0
        _ckpts = getattr(block_state, "_checkpoints", None)
        if _ckpts:
            from diffusers.modular_pipelines.ltx2._checkpoint_utils import _maybe_checkpoint

            _maybe_checkpoint(_ckpts, f"step_{i}_guided_x0", {
                "video": guided_x0_video, "audio": guided_x0_audio,
            })
            # Load support: inject reference guided x0
            ckpt = _ckpts.get(f"step_{i}_guided_x0")
            if ckpt is not None and ckpt.load:
                guided_x0_video = ckpt.data["video"].to(guided_x0_video)
                guided_x0_audio = ckpt.data["audio"].to(guided_x0_audio)

        # Convert guided x0 back to velocity for the scheduler.
        # Use sigma_val.item() (a Python float) to match reference's to_velocity, which
        # does sigma.to(float32).item() — dividing by a Python float vs a 0-dim tensor
        # uses different CUDA kernels and can produce different results at specific values.
        sigma_scalar = sigma_val.item()
        block_state.noise_pred_video = self._convert_x0_to_velocity(
            block_state.latents.float(), guided_x0_video, sigma_scalar
        ).to(block_state.latents.dtype)
        block_state.noise_pred_audio = self._convert_x0_to_velocity(
            block_state.audio_latents.float(), guided_x0_audio, sigma_scalar
        ).to(block_state.audio_latents.dtype)

        return components, block_state

    @staticmethod
    def _rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
        std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
        std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
        noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
        noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
        return noise_cfg


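The denoiser above converts velocity predictions to x0, applies CFG in x0 space with the reference's formulation, and converts back. The relations are simple enough to check with scalars; this sketch shows both the x0/velocity round-trip and why the two CFG formulations agree in exact arithmetic (they differ only in low-precision rounding, which is why the block keeps the reference's ordering):

```python
def velocity_to_x0(sample, velocity, sigma):
    # Flow-matching relation used by the block: x0 = sample - velocity * sigma
    return sample - velocity * sigma

def x0_to_velocity(sample, x0, sigma):
    # Exact inverse: velocity = (sample - x0) / sigma
    return (sample - x0) / sigma

def cfg_reference(cond, uncond, scale):
    # Reference formulation: cond + (scale - 1) * (cond - uncond)
    return cond + (scale - 1.0) * (cond - uncond)

def cfg_textbook(cond, uncond, scale):
    # Textbook formulation: uncond + scale * (cond - uncond)
    return uncond + scale * (cond - uncond)
```

Expanding `cfg_reference` gives `scale * cond - (scale - 1) * uncond`, identical to the textbook form; only the order of floating-point operations differs.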
class LTX2LoopAfterDenoiser(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates latents via scheduler step, "
            "with optional x0-space conditioning blending."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("conditioning_mask", type_hint=torch.Tensor),
            InputParam("clean_latents", type_hint=torch.Tensor),
            InputParam("audio_scheduler", required=True),
        ]

    @torch.no_grad()
    def __call__(self, components, block_state: BlockState, i: int, t: torch.Tensor):
        noise_pred_video = block_state.noise_pred_video
        noise_pred_audio = block_state.noise_pred_audio

        if block_state.conditioning_mask is not None:
            # x0 blending: convert velocity to x0, blend with clean latents, convert back
            sigma = components.scheduler.sigmas[i]
            denoised_sample = block_state.latents - noise_pred_video * sigma

            bsz = noise_pred_video.size(0)
            conditioning_mask = block_state.conditioning_mask[:bsz]
            clean_latents = block_state.clean_latents

            denoised_sample_cond = (
                denoised_sample * (1 - conditioning_mask) + clean_latents.float() * conditioning_mask
            ).to(noise_pred_video.dtype)

            denoised_latents_cond = ((block_state.latents - denoised_sample_cond) / sigma).to(
                noise_pred_video.dtype
            )
            block_state.latents = components.scheduler.step(
                denoised_latents_cond, t, block_state.latents, return_dict=False
            )[0]
        else:
            block_state.latents = components.scheduler.step(
                noise_pred_video, t, block_state.latents, return_dict=False
            )[0]

        block_state.audio_latents = block_state.audio_scheduler.step(
            noise_pred_audio, t, block_state.audio_latents, return_dict=False
        )[0]

        return components, block_state


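The x0 blending above can be verified with scalars: convert velocity to the denoised estimate, blend toward the clean latent by the conditioning mask, and convert back to a velocity for the scheduler. With mask 0 the original velocity comes back unchanged; with mask 1 the velocity steers the sample exactly onto the clean latent. A pure-Python sketch (the function name is illustrative):

```python
def blend_step(latent, velocity, sigma, clean, mask):
    """Blend the denoised estimate toward `clean` by `mask`, then return
    the equivalent velocity for the scheduler (mirrors the block above)."""
    denoised = latent - velocity * sigma               # velocity -> x0
    blended = denoised * (1.0 - mask) + clean * mask   # x0-space blend
    return (latent - blended) / sigma                  # x0 -> velocity
```

For `latent=1.0, velocity=0.5, sigma=0.8, clean=0.2`: mask 0 returns 0.5 (unchanged), mask 1 returns `(1.0 - 0.2) / 0.8 = 1.0`.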
class LTX2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Pipeline block that iteratively denoises the latents over timesteps for LTX2"

    @property
    def loop_expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def loop_inputs(self) -> list[InputParam]:
        return [
            InputParam("timesteps", required=True, type_hint=torch.Tensor),
            InputParam("num_inference_steps", required=True, type_hint=int),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        _checkpoints = state.get("_checkpoints")

        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        # Checkpoint: save/load preloop state
        if _checkpoints:
            from diffusers.modular_pipelines.ltx2._checkpoint_utils import _maybe_checkpoint

            _maybe_checkpoint(_checkpoints, "preloop", {
                "video_latent": block_state.latents, "audio_latent": block_state.audio_latents,
            })
            if "preloop" in _checkpoints and _checkpoints["preloop"].load:
                d = _checkpoints["preloop"].data
                block_state.latents = d["video_latent"].to(block_state.latents)
                block_state.audio_latents = d["audio_latent"].to(block_state.audio_latents)

        # Pass _checkpoints to sub-blocks via block_state
        if _checkpoints:
            block_state._checkpoints = _checkpoints

        try:
            with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
                for i, t in enumerate(block_state.timesteps):
                    components, block_state = self.loop_step(components, block_state, i=i, t=t)

                    # Checkpoint: save velocity (= guided prediction) after denoiser, before scheduler
                    if _checkpoints:
                        _maybe_checkpoint(_checkpoints, f"step_{i}_velocity", {
                            "video": block_state.noise_pred_video, "audio": block_state.noise_pred_audio,
                        })

                    # Checkpoint: save/load after each step
                    if _checkpoints:
                        _maybe_checkpoint(_checkpoints, f"after_step_{i}", {
                            "video_latent": block_state.latents, "audio_latent": block_state.audio_latents,
                        })
                        if f"after_step_{i}" in _checkpoints and _checkpoints[f"after_step_{i}"].load:
                            d = _checkpoints[f"after_step_{i}"].data
                            block_state.latents = d["video_latent"].to(block_state.latents)
                            block_state.audio_latents = d["audio_latent"].to(block_state.audio_latents)

                    if i == len(block_state.timesteps) - 1 or (
                        (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                    ):
                        progress_bar.update()
        except StopIteration:
            pass

        self.set_block_state(state, block_state)
        return components, state


class LTX2DenoiseStep(LTX2DenoiseLoopWrapper):
    block_classes = [
        LTX2LoopBeforeDenoiser,
        LTX2LoopDenoiser(
            guider_input_fields={
                "encoder_hidden_states": ("connector_prompt_embeds", "connector_negative_prompt_embeds"),
                "audio_encoder_hidden_states": (
                    "connector_audio_prompt_embeds",
                    "connector_audio_negative_prompt_embeds",
                ),
                "encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
                "audio_encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
            }
        ),
        LTX2LoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises video and audio latents.\n"
            "At each iteration, it runs:\n"
            " - LTX2LoopBeforeDenoiser (prepare inputs, timestep masking)\n"
            " - LTX2LoopDenoiser (transformer forward + guidance)\n"
            " - LTX2LoopAfterDenoiser (scheduler step + x0 blending)\n"
            "Supports T2V, I2V, and conditional generation."
        )
541 src/diffusers/modular_pipelines/ltx2/encoders.py Normal file
@@ -0,0 +1,541 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass

import numpy as np
import PIL.Image
import torch
from transformers import Gemma3ForConditionalGeneration, GemmaTokenizer, GemmaTokenizerFast

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models.autoencoders import AutoencoderKLLTX2Video
from ...pipelines.ltx2.connectors import LTX2TextConnectors
from ...utils import logging
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


@dataclass
class LTX2VideoCondition:
    """
    Defines a single frame-conditioning item for LTX-2 Video.

    Attributes:
        frames: The image (or video) to condition on.
        index: The latent index at which to insert the condition.
        strength: The strength of the conditioning effect (0-1).
    """

    frames: PIL.Image.Image | list[PIL.Image.Image] | np.ndarray | torch.Tensor
    index: int = 0
    strength: float = 1.0


def retrieve_latents(
    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def _pack_text_embeds(
    text_hidden_states: torch.Tensor,
    sequence_lengths: torch.Tensor,
    device: str | torch.device,
    padding_side: str = "left",
    scale_factor: int = 8,
    eps: float = 1e-6,
) -> torch.Tensor:
    batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
    original_dtype = text_hidden_states.dtype

    token_indices = torch.arange(seq_len, device=device).unsqueeze(0)
    if padding_side == "right":
        mask = token_indices < sequence_lengths[:, None]
    elif padding_side == "left":
        start_indices = seq_len - sequence_lengths[:, None]
        mask = token_indices >= start_indices
    else:
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
    mask = mask[:, :, None, None]

    masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
    num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
    masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)

    x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
    x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)

    normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
    normalized_hidden_states = normalized_hidden_states * scale_factor

    normalized_hidden_states = normalized_hidden_states.flatten(2)
    mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
    normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
    normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
    return normalized_hidden_states


def _normalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = (latents - latents_mean) * scaling_factor / latents_std
    return latents


def _pack_latents(latents: torch.Tensor, patch_size: int = 1, patch_size_t: int = 1) -> torch.Tensor:
    batch_size, num_channels, num_frames, height, width = latents.shape
    post_patch_num_frames = num_frames // patch_size_t
    post_patch_height = height // patch_size
    post_patch_width = width // patch_size
    latents = latents.reshape(
        batch_size, -1, post_patch_num_frames, patch_size_t, post_patch_height, patch_size, post_patch_width, patch_size
    )
    latents = latents.permute(0, 2, 4, 6, 1, 3, 5, 7).flatten(4, 7).flatten(1, 3)
    return latents


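`_pack_latents` above patchifies a `[B, C, F, H, W]` latent volume into a token sequence of shape `[B, (F/pt) * (H/p) * (W/p), C * pt * p * p]`. The shape arithmetic can be checked without tensors; this helper is illustrative only:

```python
def packed_shape(batch, channels, frames, height, width, patch_size=1, patch_size_t=1):
    """Shape produced by patchifying video latents into a token sequence,
    matching the reshape/permute/flatten in _pack_latents."""
    num_tokens = (frames // patch_size_t) * (height // patch_size) * (width // patch_size)
    token_dim = channels * patch_size_t * patch_size * patch_size
    return (batch, num_tokens, token_dim)
```

With the default patch sizes of 1 (as in LTX's 1x1x1 patchify), packing only flattens the spatiotemporal grid: 128 channels over an 8x16x16 grid give 2048 tokens of width 128, while `patch_size=2` trades a 4x shorter sequence for a 4x wider token.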
class LTX2TextEncoderStep(ModularPipelineBlocks):
    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Text encoder step that encodes prompts using Gemma3 for LTX2 video generation"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", Gemma3ForConditionalGeneration),
            ComponentSpec("tokenizer", GemmaTokenizerFast),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 4.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("negative_prompt"),
            InputParam("max_sequence_length", default=1024),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Packed text embeddings from Gemma3",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Packed negative text embeddings from Gemma3",
            ),
            OutputParam(
                "prompt_attention_mask",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Attention mask for prompt embeddings",
            ),
            OutputParam(
                "negative_prompt_attention_mask",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Attention mask for negative prompt embeddings",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        if block_state.prompt is not None and not isinstance(block_state.prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")

    @staticmethod
    def _get_gemma_prompt_embeds(
        text_encoder,
        tokenizer,
        prompt: str | list[str],
        max_sequence_length: int = 1024,
        scale_factor: int = 8,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        dtype = dtype or text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt

        tokenizer.padding_side = "left"
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        prompt = [p.strip() for p in prompt]
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        prompt_attention_mask = text_inputs.attention_mask.to(device)

        text_encoder_outputs = text_encoder(
            input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
        )
        text_encoder_hidden_states = text_encoder_outputs.hidden_states
        text_encoder_hidden_states = torch.stack(text_encoder_hidden_states, dim=-1)
        # Return raw stacked hidden states [B, T, D, L] — the connector handles normalization
        # (per_token_rms_norm + rescaling for LTX-2.3, or _pack_text_embeds for LTX-2.0)
        prompt_embeds = text_encoder_hidden_states.to(dtype=dtype)

        return prompt_embeds, prompt_attention_mask

    @staticmethod
    def encode_prompt(
        components,
        prompt: str | list[str],
        device: torch.device | None = None,
        prepare_unconditional_embeds: bool = True,
        negative_prompt: str | list[str] | None = None,
        max_sequence_length: int = 1024,
    ):
        device = device or components._execution_device

        if not isinstance(prompt, list):
            prompt = [prompt]
        batch_size = len(prompt)

        prompt_embeds, prompt_attention_mask = LTX2TextEncoderStep._get_gemma_prompt_embeds(
            text_encoder=components.text_encoder,
            tokenizer=components.tokenizer,
            prompt=prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )

        negative_prompt_embeds = None
        negative_prompt_attention_mask = None

        if prepare_unconditional_embeds:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, negative_prompt_attention_mask = LTX2TextEncoderStep._get_gemma_prompt_embeds(
                text_encoder=components.text_encoder,
                tokenizer=components.tokenizer,
                prompt=negative_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device

        (
            block_state.prompt_embeds,
            block_state.prompt_attention_mask,
            block_state.negative_prompt_embeds,
            block_state.negative_prompt_attention_mask,
        ) = self.encode_prompt(
            components=components,
            prompt=block_state.prompt,
            device=device,
            prepare_unconditional_embeds=components.requires_unconditional_embeds,
            negative_prompt=block_state.negative_prompt,
            max_sequence_length=block_state.max_sequence_length,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2ConnectorStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Connector step that transforms text embeddings into video and audio conditioning embeddings"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("connectors", LTX2TextConnectors),
|
||||
ComponentSpec(
|
||||
"guider",
|
||||
ClassifierFreeGuidance,
|
||||
config=FrozenDict({"guidance_scale": 4.0}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("prompt_embeds", required=True, type_hint=torch.Tensor),
|
||||
InputParam("prompt_attention_mask", required=True, type_hint=torch.Tensor),
|
||||
InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
|
||||
InputParam("negative_prompt_attention_mask", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
"connector_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Video text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_audio_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Audio text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_attention_mask",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Attention mask from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Negative video text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_audio_negative_prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Negative audio text embeddings from connector",
|
||||
),
|
||||
OutputParam(
|
||||
"connector_negative_attention_mask",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="Negative attention mask from connector",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
prompt_embeds = block_state.prompt_embeds
|
||||
prompt_attention_mask = block_state.prompt_attention_mask
|
||||
negative_prompt_embeds = block_state.negative_prompt_embeds
|
||||
negative_prompt_attention_mask = block_state.negative_prompt_attention_mask
|
||||
|
||||
do_cfg = negative_prompt_embeds is not None
|
||||
|
||||
if do_cfg:
|
||||
combined_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
combined_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
|
||||
else:
|
||||
combined_embeds = prompt_embeds
|
||||
combined_mask = prompt_attention_mask
|
||||
|
||||
connector_embeds, connector_audio_embeds, connector_mask = components.connectors(
|
||||
combined_embeds, combined_mask
|
||||
)
|
||||
|
||||
if do_cfg:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
block_state.connector_negative_prompt_embeds = connector_embeds[:batch_size]
|
||||
block_state.connector_prompt_embeds = connector_embeds[batch_size:]
|
||||
block_state.connector_audio_negative_prompt_embeds = connector_audio_embeds[:batch_size]
|
||||
block_state.connector_audio_prompt_embeds = connector_audio_embeds[batch_size:]
|
||||
block_state.connector_negative_attention_mask = connector_mask[:batch_size]
|
||||
block_state.connector_attention_mask = connector_mask[batch_size:]
|
||||
else:
|
||||
block_state.connector_prompt_embeds = connector_embeds
|
||||
block_state.connector_audio_prompt_embeds = connector_audio_embeds
|
||||
block_state.connector_attention_mask = connector_mask
|
||||
block_state.connector_negative_prompt_embeds = None
|
||||
block_state.connector_audio_negative_prompt_embeds = None
|
||||
block_state.connector_negative_attention_mask = None
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2ConditionEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Condition encoder step that VAE-encodes conditioning frames for I2V and conditional generation"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("conditions", type_hint=list, description="List of LTX2VideoCondition objects"),
|
||||
InputParam("image", type_hint=PIL.Image.Image, description="Sugar for I2V: image to condition on frame 0"),
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("condition_latents", type_hint=list, description="List of packed condition latent tensors"),
|
||||
OutputParam("condition_strengths", type_hint=list, description="List of conditioning strengths"),
|
||||
OutputParam("condition_indices", type_hint=list, description="List of latent frame indices"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
conditions = block_state.conditions
|
||||
image = block_state.image
|
||||
|
||||
# Convert image sugar to conditions list
|
||||
if image is not None and conditions is None:
|
||||
conditions = [LTX2VideoCondition(frames=image, index=0, strength=1.0)]
|
||||
|
||||
if conditions is None:
|
||||
block_state.condition_latents = []
|
||||
block_state.condition_strengths = []
|
||||
block_state.condition_indices = []
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
if isinstance(conditions, LTX2VideoCondition):
|
||||
conditions = [conditions]
|
||||
|
||||
height = block_state.height
|
||||
width = block_state.width
|
||||
num_frames = block_state.num_frames
|
||||
device = components._execution_device
|
||||
generator = block_state.generator
|
||||
|
||||
vae_temporal_compression_ratio = components.vae.temporal_compression_ratio
|
||||
vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
|
||||
transformer_spatial_patch_size = 1
|
||||
transformer_temporal_patch_size = 1
|
||||
|
||||
latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
|
||||
|
||||
conditioning_frames, conditioning_strengths, conditioning_indices = [], [], []
|
||||
|
||||
for i, condition in enumerate(conditions):
|
||||
if isinstance(condition.frames, PIL.Image.Image):
|
||||
video_like_cond = [condition.frames]
|
||||
elif isinstance(condition.frames, np.ndarray) and condition.frames.ndim == 3:
|
||||
video_like_cond = np.expand_dims(condition.frames, axis=0)
|
||||
elif isinstance(condition.frames, torch.Tensor) and condition.frames.ndim == 3:
|
||||
video_like_cond = condition.frames.unsqueeze(0)
|
||||
else:
|
||||
video_like_cond = condition.frames
|
||||
|
||||
condition_pixels = components.video_processor.preprocess_video(
|
||||
video_like_cond, height, width, resize_mode="crop"
|
||||
)
|
||||
|
||||
latent_start_idx = condition.index
|
||||
if latent_start_idx < 0:
|
||||
latent_start_idx = latent_start_idx % latent_num_frames
|
||||
if latent_start_idx >= latent_num_frames:
|
||||
logger.warning(
|
||||
f"The starting latent index {latent_start_idx} of condition {i} is too big for {latent_num_frames} "
|
||||
f"latent frames. This condition will be skipped."
|
||||
)
|
||||
continue
|
||||
|
||||
cond_num_frames = condition_pixels.size(2)
|
||||
start_idx = max((latent_start_idx - 1) * vae_temporal_compression_ratio + 1, 0)
|
||||
frame_num = min(cond_num_frames, num_frames - start_idx)
|
||||
frame_num = (frame_num - 1) // vae_temporal_compression_ratio * vae_temporal_compression_ratio + 1
|
||||
condition_pixels = condition_pixels[:, :, :frame_num]
|
||||
|
||||
conditioning_frames.append(condition_pixels.to(dtype=components.vae.dtype, device=device))
|
||||
conditioning_strengths.append(condition.strength)
|
||||
conditioning_indices.append(latent_start_idx)
|
||||
|
||||
condition_latents = []
|
||||
for condition_tensor in conditioning_frames:
|
||||
condition_latent = retrieve_latents(
|
||||
components.vae.encode(condition_tensor), generator=generator, sample_mode="argmax"
|
||||
)
|
||||
condition_latent = _normalize_latents(
|
||||
condition_latent, components.vae.latents_mean, components.vae.latents_std
|
||||
).to(device=device, dtype=torch.float32)
|
||||
condition_latent = _pack_latents(
|
||||
condition_latent, transformer_spatial_patch_size, transformer_temporal_patch_size
|
||||
)
|
||||
condition_latents.append(condition_latent)
|
||||
|
||||
block_state.condition_latents = condition_latents
|
||||
block_state.condition_strengths = conditioning_strengths
|
||||
block_state.condition_indices = conditioning_indices
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
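As a side note on the conditioning math in `LTX2ConditionEncoderStep.__call__`: the latent/pixel frame-index arithmetic can be sketched standalone in plain Python. The helper name `conditioning_window` and the default temporal compression ratio of 8 are illustrative assumptions, not part of the diff:

```python
# Hypothetical standalone sketch of the conditioning-frame math used in
# LTX2ConditionEncoderStep, assuming a VAE temporal compression ratio of 8.

def conditioning_window(num_frames, cond_num_frames, latent_start_idx, t_ratio=8):
    # Number of latent frames the VAE produces for `num_frames` pixel frames.
    latent_num_frames = (num_frames - 1) // t_ratio + 1
    # Negative indices wrap around, mirroring the pipeline code.
    if latent_start_idx < 0:
        latent_start_idx = latent_start_idx % latent_num_frames
    # First pixel frame covered by this latent index.
    start_idx = max((latent_start_idx - 1) * t_ratio + 1, 0)
    # Clip the condition so it fits and stays VAE-aligned (k * t_ratio + 1 frames).
    frame_num = min(cond_num_frames, num_frames - start_idx)
    frame_num = (frame_num - 1) // t_ratio * t_ratio + 1
    return latent_num_frames, latent_start_idx, start_idx, frame_num

# Single conditioning image at frame 0 of a 121-frame video:
print(conditioning_window(num_frames=121, cond_num_frames=1, latent_start_idx=0))  # (16, 0, 0, 1)
```

Note how a 9-frame clip anchored at latent index -1 collapses to a single usable frame near the end of the video, since only `121 - 113 = 8` pixel frames remain past its start index.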
src/diffusers/modular_pipelines/ltx2/modular_blocks_ltx2.py (new file, 436 lines)
@@ -0,0 +1,436 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import ComponentSpec, OutputParam
from .before_denoise import (
    LTX2DisableAdapterStep,
    LTX2EnableAdapterStep,
    LTX2InputStep,
    LTX2PrepareAudioLatentsStep,
    LTX2PrepareCoordinatesStep,
    LTX2PrepareLatentsStep,
    LTX2SetTimestepsStep,
    LTX2Stage2PrepareLatentsStep,
    LTX2Stage2SetTimestepsStep,
)
from .decoders import LTX2AudioDecoderStep, LTX2VideoDecoderStep
from .denoise import (
    LTX2DenoiseLoopWrapper,
    LTX2DenoiseStep,
    LTX2LoopAfterDenoiser,
    LTX2LoopBeforeDenoiser,
    LTX2LoopDenoiser,
)
from .encoders import LTX2ConditionEncoderStep, LTX2ConnectorStep, LTX2TextEncoderStep
from .modular_blocks_ltx2_upsample import LTX2UpsampleCoreBlocks


logger = logging.get_logger(__name__)


# ====================
# 1. AUTO CONDITION ENCODER (skip if no conditions)
# ====================


class LTX2AutoConditionEncoderStep(AutoPipelineBlocks):
    """Auto block that runs condition encoding when conditions or image inputs are provided.

    - When `conditions` is provided: runs condition encoder for arbitrary frame conditioning
    - When `image` is provided: runs condition encoder (converts image to condition at frame 0)
    - When neither is provided: step is skipped (T2V mode)
    """

    block_classes = [LTX2ConditionEncoderStep, LTX2ConditionEncoderStep]
    block_names = ["conditional_encoder", "image_encoder"]
    block_trigger_inputs = ["conditions", "image"]


# ====================
# 2. CORE DENOISE
# ====================


class LTX2CoreDenoiseStep(SequentialPipelineBlocks):
    """Core denoising block: input prep -> timesteps -> latents -> audio latents -> coordinates -> denoise loop."""

    model_name = "ltx2"
    block_classes = [
        LTX2InputStep,
        LTX2SetTimestepsStep,
        LTX2PrepareLatentsStep,
        LTX2PrepareAudioLatentsStep,
        LTX2PrepareCoordinatesStep,
        LTX2DenoiseStep,
    ]
    block_names = [
        "input",
        "set_timesteps",
        "prepare_latents",
        "prepare_audio_latents",
        "prepare_coordinates",
        "denoise",
    ]

    @property
    def description(self):
        return "Core denoise block that takes encoded conditions and runs the full denoising process."

    @property
    def outputs(self):
        return [
            OutputParam("latents"),
            OutputParam("audio_latents"),
        ]


# ====================
# 3. BLOCKS (T2V only)
# ====================


class LTX2Blocks(SequentialPipelineBlocks):
    """Modular pipeline blocks for LTX2 text-to-video generation."""

    model_name = "ltx2"
    block_classes = [
        LTX2TextEncoderStep,
        LTX2ConnectorStep,
        LTX2CoreDenoiseStep,
        LTX2VideoDecoderStep,
        LTX2AudioDecoderStep,
    ]
    block_names = ["text_encoder", "connector", "denoise", "video_decode", "audio_decode"]

    @property
    def description(self):
        return "Modular pipeline blocks for LTX2 text-to-video generation."

    @property
    def outputs(self):
        return [OutputParam("videos"), OutputParam("audio")]


# ====================
# 4. AUTO BLOCKS (T2V + I2V + Conditional)
# ====================


class LTX2AutoBlocks(SequentialPipelineBlocks):
    """Modular pipeline blocks for LTX2 with unified T2V, I2V, and conditional generation.

    Workflow map:
    - text2video: prompt only
    - image2video: image + prompt (auto-converts to condition at frame 0)
    - conditional: conditions + prompt (arbitrary frame conditioning)
    """

    model_name = "ltx2"
    block_classes = [
        LTX2TextEncoderStep,
        LTX2ConnectorStep,
        LTX2AutoConditionEncoderStep,
        LTX2CoreDenoiseStep,
        LTX2VideoDecoderStep,
        LTX2AudioDecoderStep,
    ]
    block_names = ["text_encoder", "connector", "condition_encoder", "denoise", "video_decode", "audio_decode"]

    _workflow_map = {
        "text2video": {"prompt": True},
        "image2video": {"image": True, "prompt": True},
        "conditional": {"conditions": True, "prompt": True},
    }

    @property
    def description(self):
        return (
            "Unified modular pipeline blocks for LTX2 supporting text-to-video, "
            "image-to-video, and conditional/FLF2V generation."
        )

    @property
    def outputs(self):
        return [OutputParam("videos"), OutputParam("audio")]


# ====================
# 5. STAGE 2 CORE DENOISE
# ====================


class LTX2Stage2CoreDenoiseStep(SequentialPipelineBlocks):
    """Core denoise for Stage 2: uses distilled sigmas with no dynamic shifting."""

    model_name = "ltx2"
    block_classes = [
        LTX2InputStep,
        LTX2Stage2SetTimestepsStep,
        LTX2Stage2PrepareLatentsStep,
        LTX2PrepareAudioLatentsStep,
        LTX2PrepareCoordinatesStep,
        LTX2DenoiseStep,
    ]
    block_names = [
        "input",
        "set_timesteps",
        "prepare_latents",
        "prepare_audio_latents",
        "prepare_coordinates",
        "denoise",
    ]

    @property
    def description(self):
        return "Stage 2 core denoise block using distilled sigmas and no dynamic shifting."

    @property
    def outputs(self):
        return [
            OutputParam("latents"),
            OutputParam("audio_latents"),
        ]


# ====================
# 6. STAGE 1 BLOCKS
# ====================


class LTX2Stage1Blocks(SequentialPipelineBlocks):
    """Stage 1 blocks: text encoding -> conditioning -> denoise -> latent output.

    Outputs latents and audio_latents for downstream processing (upsample + stage2).
    Supports T2V, I2V, and conditional generation modes.
    """

    model_name = "ltx2"
    block_classes = [
        LTX2TextEncoderStep,
        LTX2ConnectorStep,
        LTX2AutoConditionEncoderStep,
        LTX2CoreDenoiseStep,
    ]
    block_names = ["text_encoder", "connector", "condition_encoder", "denoise"]

    _workflow_map = {
        "text2video": {"prompt": True},
        "image2video": {"image": True, "prompt": True},
        "conditional": {"conditions": True, "prompt": True},
    }

    @property
    def description(self):
        return (
            "Stage 1 modular pipeline blocks for LTX2: text encoding, conditioning, "
            "and denoising. Outputs latents for upsample + stage2 workflow."
        )

    @property
    def outputs(self):
        return [OutputParam("latents"), OutputParam("audio_latents")]


# ====================
# 7. STAGE 2 BLOCKS
# ====================


class LTX2Stage2Blocks(SequentialPipelineBlocks):
    """Stage 2 blocks: text encoding -> denoise (distilled) -> decode video + audio.

    Expects pre-computed latents (from upsample) and audio_latents (from stage1).
    Uses fixed distilled sigmas with no dynamic shifting and guidance_scale=1.0.
    """

    model_name = "ltx2"
    block_classes = [
        LTX2TextEncoderStep,
        LTX2ConnectorStep,
        LTX2Stage2CoreDenoiseStep,
        LTX2VideoDecoderStep,
        LTX2AudioDecoderStep,
    ]
    block_names = ["text_encoder", "connector", "denoise", "video_decode", "audio_decode"]

    @property
    def description(self):
        return (
            "Stage 2 modular pipeline blocks for LTX2: re-encodes text, "
            "denoises with distilled sigmas, and decodes video + audio."
        )

    @property
    def expected_components(self):
        # Override guider default for stage2: guidance_scale=1.0 (no CFG)
        components = [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 1.0}),
                default_creation_method="from_config",
            ),
        ]
        for block in self.sub_blocks.values():
            for component in block.expected_components:
                if component not in components:
                    components.append(component)
        return components

    @property
    def outputs(self):
        return [OutputParam("videos"), OutputParam("audio")]


# ====================
# 8. STAGE 2 FULL DENOISE (uses stage2_guider)
# ====================


class LTX2Stage2FullDenoiseStep(LTX2DenoiseLoopWrapper):
    """Denoise step for Stage 2 within the full pipeline, using stage2_guider (guidance_scale=1.0)."""

    block_classes = [
        LTX2LoopBeforeDenoiser,
        LTX2LoopDenoiser(
            guider_name="stage2_guider",
            guider_config=FrozenDict({"guidance_scale": 1.0}),
            guider_input_fields={
                "encoder_hidden_states": ("connector_prompt_embeds", "connector_negative_prompt_embeds"),
                "audio_encoder_hidden_states": (
                    "connector_audio_prompt_embeds",
                    "connector_audio_negative_prompt_embeds",
                ),
                "encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
                "audio_encoder_attention_mask": ("connector_attention_mask", "connector_negative_attention_mask"),
            },
        ),
        LTX2LoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Stage 2 denoise step using stage2_guider (guidance_scale=1.0).\n"
            "Used within LTX2FullPipelineBlocks to avoid conflict with the Stage 1 guider."
        )


# ====================
# 9. STAGE 2 FULL CORE DENOISE
# ====================


class LTX2Stage2FullCoreDenoiseStep(SequentialPipelineBlocks):
    """Core denoise for Stage 2 within the full pipeline: distilled sigmas, no dynamic shifting, stage2_guider."""

    model_name = "ltx2"
    block_classes = [
        LTX2InputStep,
        LTX2Stage2SetTimestepsStep,
        LTX2Stage2PrepareLatentsStep,
        LTX2PrepareAudioLatentsStep,
        LTX2PrepareCoordinatesStep,
        LTX2Stage2FullDenoiseStep,
    ]
    block_names = [
        "input",
        "set_timesteps",
        "prepare_latents",
        "prepare_audio_latents",
        "prepare_coordinates",
        "denoise",
    ]

    @property
    def description(self):
        return "Stage 2 core denoise for full pipeline: distilled sigmas, no dynamic shifting, stage2_guider."

    @property
    def outputs(self):
        return [
            OutputParam("latents"),
            OutputParam("audio_latents"),
        ]


# ====================
# 10. STAGE 2 INTERNAL BLOCKS (no text encoder/connector)
# ====================


class LTX2Stage2InternalBlocks(SequentialPipelineBlocks):
    """Stage 2 blocks without text encoder/connector — reads connector_* embeddings from state.

    Used within LTX2FullPipelineBlocks where Stage 1 already encoded text.
    """

    model_name = "ltx2"
    block_classes = [
        LTX2Stage2FullCoreDenoiseStep,
        LTX2VideoDecoderStep,
        LTX2AudioDecoderStep,
    ]
    block_names = ["denoise", "video_decode", "audio_decode"]

    @property
    def description(self):
        return "Stage 2 internal blocks (no text encoding): denoise with stage2_guider + decode."

    @property
    def outputs(self):
        return [OutputParam("videos"), OutputParam("audio")]


# ====================
# 11. FULL PIPELINE BLOCKS (all-in-one)
# ====================


class LTX2FullPipelineBlocks(SequentialPipelineBlocks):
    """All-in-one mega-block: stage1 -> upsample -> stage2 in a single pipe() call.

    LoRA adapters are automatically disabled for stage1 and re-enabled for stage2.
    Uses two guiders: `guider` (guidance_scale=4.0) for stage1 and
    `stage2_guider` (guidance_scale=1.0) for stage2.

    Required components: text_encoder, tokenizer, transformer, connectors, vae, audio_vae,
    vocoder, scheduler, guider, stage2_guider, latent_upsampler.
    """

    model_name = "ltx2"
    block_classes = [
        LTX2DisableAdapterStep,
        LTX2Stage1Blocks,
        LTX2UpsampleCoreBlocks,
        LTX2EnableAdapterStep,
        LTX2Stage2InternalBlocks,
    ]
    block_names = ["disable_lora", "stage1", "upsample", "enable_lora", "stage2"]

    _workflow_map = {
        "text2video": {"prompt": True},
        "image2video": {"image": True, "prompt": True},
        "conditional": {"conditions": True, "prompt": True},
    }

    @property
    def description(self):
        return (
            "All-in-one LTX2 pipeline: stage1 (denoise) -> upsample -> stage2 (distilled denoise + decode). "
            "LoRA adapters toggled automatically via stage_2_adapter parameter."
        )

    @property
    def outputs(self):
        return [OutputParam("videos"), OutputParam("audio")]
@@ -0,0 +1,373 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ...configuration_utils import FrozenDict
from ...models.autoencoders import AutoencoderKLLTX2Video
from ...pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..modular_pipeline import ModularPipelineBlocks, PipelineState, SequentialPipelineBlocks
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam


logger = logging.get_logger(__name__)


def retrieve_latents(
    encoder_output: torch.Tensor, generator: torch.Generator | None = None, sample_mode: str = "sample"
):
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")


def _unpack_latents(
    latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
    batch_size = latents.size(0)
    latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
    latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
    return latents


def _denormalize_latents(
    latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
    latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / scaling_factor + latents_mean
    return latents


class LTX2UpsamplePrepareStep(ModularPipelineBlocks):
    """Prepare latents for upsampling: accepts either video frames or pre-computed latents."""

    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Prepare latents for the latent upsampler, from either video input or pre-computed latents"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLLTX2Video),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 32}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("video", description="Video frames to encode and upsample"),
            InputParam("latents", type_hint=torch.Tensor, description="Pre-computed latents to upsample"),
            InputParam("latents_normalized", default=False, type_hint=bool),
            InputParam("height", default=512, type_hint=int),
            InputParam("width", default=768, type_hint=int),
            InputParam("num_frames", default=121, type_hint=int),
            InputParam("spatial_patch_size", default=1, type_hint=int),
            InputParam("temporal_patch_size", default=1, type_hint=int),
            InputParam("generator"),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("latents", type_hint=torch.Tensor, description="Prepared latents for upsampling"),
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        video = block_state.video
        latents = block_state.latents
        height = block_state.height
        width = block_state.width
        num_frames = block_state.num_frames
        generator = block_state.generator

        vae_spatial_compression_ratio = components.vae.spatial_compression_ratio
        vae_temporal_compression_ratio = components.vae.temporal_compression_ratio

        if latents is not None:
            if latents.ndim == 3:
                latent_num_frames = (num_frames - 1) // vae_temporal_compression_ratio + 1
                latent_height = height // vae_spatial_compression_ratio
                latent_width = width // vae_spatial_compression_ratio
                latents = _unpack_latents(
                    latents,
                    latent_num_frames,
                    latent_height,
                    latent_width,
                    block_state.spatial_patch_size,
                    block_state.temporal_patch_size,
                )
            if block_state.latents_normalized:
                latents = _denormalize_latents(
                    latents,
                    components.vae.latents_mean,
                    components.vae.latents_std,
                    components.vae.config.scaling_factor,
                )
            block_state.latents = latents.to(device=device, dtype=torch.float32)
        elif video is not None:
            if isinstance(video, list):
                num_frames = len(video)
            if num_frames % vae_temporal_compression_ratio != 1:
                num_frames = num_frames // vae_temporal_compression_ratio * vae_temporal_compression_ratio + 1
                if isinstance(video, list):
                    video = video[:num_frames]

            video = components.video_processor.preprocess_video(video, height=height, width=width)
            video = video.to(device=device, dtype=torch.float32)

            init_latents = [
                retrieve_latents(components.vae.encode(vid.unsqueeze(0)), generator) for vid in video
            ]
            block_state.latents = torch.cat(init_latents, dim=0).to(torch.float32)
        else:
            raise ValueError("One of `video` or `latents` must be provided.")

        self.set_block_state(state, block_state)
        return components, state


class LTX2UpsampleStep(ModularPipelineBlocks):
    """Run the latent upsampler model with optional AdaIN and tone mapping."""

    model_name = "ltx2"

    @property
    def description(self) -> str:
        return "Run the latent upsampler model with optional AdaIN filtering and tone mapping"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("latent_upsampler", LTX2LatentUpsamplerModel),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", required=True, type_hint=torch.Tensor),
            InputParam("adain_factor", default=0.0, type_hint=float),
|
||||
InputParam("tone_map_compression_ratio", default=0.0, type_hint=float),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("latents", type_hint=torch.Tensor, description="Upsampled latents"),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def adain_filter_latent(latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
|
||||
result = latents.clone()
|
||||
for i in range(latents.size(0)):
|
||||
for c in range(latents.size(1)):
|
||||
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
|
||||
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
|
||||
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
|
||||
result = torch.lerp(latents, result, factor)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def tone_map_latents(latents: torch.Tensor, compression: float) -> torch.Tensor:
|
||||
scale_factor = compression * 0.75
|
||||
abs_latents = torch.abs(latents)
|
||||
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
|
||||
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
|
||||
return latents * scales
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents.to(components.latent_upsampler.dtype)
|
||||
reference_latents = latents
|
||||
|
||||
latents_upsampled = components.latent_upsampler(latents)
|
||||
|
||||
if block_state.adain_factor > 0.0:
|
||||
latents = self.adain_filter_latent(latents_upsampled, reference_latents, block_state.adain_factor)
|
||||
else:
|
||||
latents = latents_upsampled
|
||||
|
||||
if block_state.tone_map_compression_ratio > 0.0:
|
||||
latents = self.tone_map_latents(latents, block_state.tone_map_compression_ratio)
|
||||
|
||||
block_state.latents = latents
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class LTX2UpsamplePostprocessStep(ModularPipelineBlocks):
|
||||
"""Decode upsampled latents to video frames."""
|
||||
|
||||
model_name = "ltx2"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Decode upsampled latents into video frames"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> list[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLLTX2Video),
|
||||
ComponentSpec(
|
||||
"video_processor",
|
||||
VideoProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 32}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("output_type", default="pil", type_hint=str),
|
||||
InputParam("decode_timestep", default=0.0),
|
||||
InputParam("decode_noise_scale", default=None),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[OutputParam]:
|
||||
return [
|
||||
OutputParam("videos", description="Decoded video frames"),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
|
||||
if block_state.output_type == "latent":
|
||||
block_state.videos = latents
|
||||
else:
|
||||
batch_size = latents.shape[0]
|
||||
device = latents.device
|
||||
|
||||
if not components.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(
|
||||
latents.shape, generator=block_state.generator, device=device, dtype=latents.dtype
|
||||
)
|
||||
decode_timestep = block_state.decode_timestep
|
||||
decode_noise_scale = block_state.decode_noise_scale
|
||||
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = components.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
block_state.videos = components.video_processor.postprocess_video(
|
||||
video, output_type=block_state.output_type
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
# ====================
|
||||
# UPSAMPLE BLOCKS
|
||||
# ====================
|
||||
|
||||
|
||||
class LTX2UpsampleBlocks(SequentialPipelineBlocks):
|
||||
"""Modular pipeline blocks for LTX2 latent upsampling."""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2UpsamplePrepareStep,
|
||||
LTX2UpsampleStep,
|
||||
LTX2UpsamplePostprocessStep,
|
||||
]
|
||||
block_names = ["prepare", "upsample", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Modular pipeline blocks for LTX2 latent upsampling (stage1 -> upsample -> stage2)."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("videos")]
|
||||
|
||||
|
||||
class LTX2UpsampleCorePrepareStep(LTX2UpsamplePrepareStep):
|
||||
"""Upsample prepare step for the full pipeline: latents_normalized defaults to True."""
|
||||
|
||||
@property
|
||||
def inputs(self) -> list[InputParam]:
|
||||
return [
|
||||
InputParam("video", description="Video frames to encode and upsample"),
|
||||
InputParam("latents", type_hint=torch.Tensor, description="Pre-computed latents to upsample"),
|
||||
InputParam("latents_normalized", default=True, type_hint=bool),
|
||||
InputParam("height", default=512, type_hint=int),
|
||||
InputParam("width", default=768, type_hint=int),
|
||||
InputParam("num_frames", default=121, type_hint=int),
|
||||
InputParam("spatial_patch_size", default=1, type_hint=int),
|
||||
InputParam("temporal_patch_size", default=1, type_hint=int),
|
||||
InputParam("generator"),
|
||||
]
|
||||
|
||||
|
||||
class LTX2UpsampleCoreBlocks(SequentialPipelineBlocks):
|
||||
"""Upsample blocks for the full pipeline: prepare + upsample only (no decode).
|
||||
|
||||
Outputs 5D latents (not decoded video), suitable for chaining into Stage2.
|
||||
"""
|
||||
|
||||
model_name = "ltx2"
|
||||
block_classes = [
|
||||
LTX2UpsampleCorePrepareStep,
|
||||
LTX2UpsampleStep,
|
||||
]
|
||||
block_names = ["prepare", "upsample"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Latent upsample blocks (no decode) for use within the full pipeline."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam("latents")]
|
||||
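The two static helpers on `LTX2UpsampleStep` are pure tensor math and can be sanity-checked outside the pipeline. A minimal standalone sketch (the function bodies are copied from the class above; the random inputs and shapes are illustrative):

```python
import torch


def adain_filter_latent(latents, reference_latents, factor=1.0):
    # Match each (batch, channel) slice's mean/std to the reference latents,
    # then blend original vs. filtered by `factor` (factor=1.0 -> fully filtered).
    result = latents.clone()
    for i in range(latents.size(0)):
        for c in range(latents.size(1)):
            r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None)
            i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
            result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
    return torch.lerp(latents, result, factor)


def tone_map_latents(latents, compression):
    # Soft-compress large-magnitude latent values; identity when compression == 0.
    scale_factor = compression * 0.75
    sigmoid_term = torch.sigmoid(4.0 * scale_factor * (torch.abs(latents) - 1.0))
    return latents * (1.0 - 0.8 * scale_factor * sigmoid_term)


torch.manual_seed(0)
ref = torch.randn(1, 4, 3, 8, 8)
x = 2.0 * torch.randn(1, 4, 3, 8, 8) + 5.0

# With factor=1.0 each (batch, channel) slice takes on the reference statistics.
y = adain_filter_latent(x, ref, factor=1.0)
print(torch.allclose(y[0, 0].mean(), ref[0, 0].mean(), atol=1e-4))

# compression=0.0 must leave the latents untouched (sigmoid term is scaled by 0).
z = tone_map_latents(x, compression=0.0)
print(torch.allclose(z, x))
```

This is also a convenient component-parity probe: the same inputs can be fed through the reference repo's AdaIN/tone-map helpers and compared with max_diff < 1e-3.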
src/diffusers/modular_pipelines/ltx2/modular_pipeline.py (new file, 112 lines)
@@ -0,0 +1,112 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ...loaders import LTX2LoraLoaderMixin
from ...utils import logging
from ..modular_pipeline import ModularPipeline


logger = logging.get_logger(__name__)


class LTX2ModularPipeline(
    ModularPipeline,
    LTX2LoraLoaderMixin,
):
    """
    A ModularPipeline for LTX2 video generation (T2V, I2V, Conditional/FLF2V).

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "LTX2AutoBlocks"

    @property
    def default_height(self):
        return 512

    @property
    def default_width(self):
        return 768

    @property
    def default_num_frames(self):
        return 121

    @property
    def vae_scale_factor_spatial(self):
        vae_scale_factor = 32
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = self.vae.spatial_compression_ratio
        return vae_scale_factor

    @property
    def vae_scale_factor_temporal(self):
        vae_scale_factor = 8
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = self.vae.temporal_compression_ratio
        return vae_scale_factor

    @property
    def transformer_spatial_patch_size(self):
        patch_size = 1
        if hasattr(self, "transformer") and self.transformer is not None:
            patch_size = self.transformer.config.patch_size
        return patch_size

    @property
    def transformer_temporal_patch_size(self):
        patch_size = 1
        if hasattr(self, "transformer") and self.transformer is not None:
            patch_size = self.transformer.config.patch_size_t
        return patch_size

    @property
    def requires_unconditional_embeds(self):
        requires = False
        if hasattr(self, "guider") and self.guider is not None:
            requires = self.guider._enabled and self.guider.num_conditions > 1
        return requires


class LTX2UpsampleModularPipeline(ModularPipeline):
    """
    A ModularPipeline for LTX2 latent upsampling.

    > [!WARNING]
    > This is an experimental feature and is likely to change in the future.
    """

    default_blocks_name = "LTX2UpsampleBlocks"

    @property
    def default_height(self):
        return 512

    @property
    def default_width(self):
        return 768

    @property
    def vae_scale_factor_spatial(self):
        vae_scale_factor = 32
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = self.vae.spatial_compression_ratio
        return vae_scale_factor

    @property
    def vae_scale_factor_temporal(self):
        vae_scale_factor = 8
        if hasattr(self, "vae") and self.vae is not None:
            vae_scale_factor = self.vae.temporal_compression_ratio
        return vae_scale_factor
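For reference, the default geometry exposed by these properties implies the latent grid below (a quick standalone check mirroring the arithmetic in `LTX2UpsamplePrepareStep`; the variable names are illustrative):

```python
# Defaults from LTX2ModularPipeline: 512x768 video, 121 frames,
# VAE spatial compression 32, temporal compression 8.
height, width, num_frames = 512, 768, 121
vae_spatial, vae_temporal = 32, 8

# Same formulas used when unpacking 3D packed latents in the prepare step.
latent_num_frames = (num_frames - 1) // vae_temporal + 1
latent_height = height // vae_spatial
latent_width = width // vae_spatial

print(latent_num_frames, latent_height, latent_width)  # → 16 16 24
```

Note `num_frames = 121` satisfies `num_frames % vae_temporal == 1`, so the prepare step's frame-truncation branch is a no-op at the defaults.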
@@ -132,6 +132,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
         ("z-image", _create_default_map_fn("ZImageModularPipeline")),
         ("helios", _create_default_map_fn("HeliosModularPipeline")),
         ("helios-pyramid", _helios_pyramid_map_fn),
+        ("ltx2", _create_default_map_fn("LTX2ModularPipeline")),
     ]
 )
 

@@ -123,6 +123,7 @@ from .stable_diffusion_xl import (
     StableDiffusionXLInpaintPipeline,
     StableDiffusionXLPipeline,
 )
+from .ltx2 import LTX2Pipeline
 from .wan import WanImageToVideoPipeline, WanPipeline, WanVideoToVideoPipeline
 from .wuerstchen import WuerstchenCombinedPipeline, WuerstchenDecoderPipeline
 from .z_image import (

@@ -247,6 +248,7 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
 
 AUTO_TEXT2VIDEO_PIPELINES_MAPPING = OrderedDict(
     [
+        ("ltx2", LTX2Pipeline),
         ("wan", WanPipeline),
     ]
 )

@@ -145,7 +145,7 @@ def encode_video(
         # Pipeline output_type="np"
         is_denormalized = np.logical_and(np.zeros_like(video) <= video, video <= np.ones_like(video))
         if np.all(is_denormalized):
-            video = (video * 255).round().astype("uint8")
+            video = (video * 255).astype("uint8")
         else:
             logger.warning(
                 "Supplied `numpy.ndarray` does not have values in [0, 1]. The values will be assumed to be pixel "

@@ -274,10 +274,14 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             `torch.Tensor`:
                 A tensor of adjusted timesteps such that the final value equals `self.config.shift_terminal`.
         """
-        one_minus_z = 1 - t
-        scale_factor = one_minus_z[-1] / (1 - self.config.shift_terminal)
-        stretched_t = 1 - (one_minus_z / scale_factor)
-        return stretched_t
+        # Compute in float32 (matching reference ltx_core scheduler) to avoid
+        # float64 intermediates from numpy scalar / Python float promotion.
+        is_numpy = isinstance(t, np.ndarray)
+        t_tensor = torch.as_tensor(t, dtype=torch.float32)
+        one_minus_z = 1.0 - t_tensor
+        scale_factor = one_minus_z[-1] / (1.0 - self.config.shift_terminal)
+        stretched_t = 1.0 - (one_minus_z / scale_factor)
+        return stretched_t.numpy() if is_numpy else stretched_t
 
     def set_timesteps(
         self,

@@ -510,7 +514,7 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
             noise = torch.randn_like(sample)
             prev_sample = (1.0 - next_sigma) * x0 + next_sigma * noise
         else:
-            prev_sample = sample + dt * model_output
+            prev_sample = sample + model_output.to(sample.dtype) * dt
 
         # upon completion increase step index by one
         self._step_index += 1

@@ -646,7 +650,12 @@ class FlowMatchEulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         return sigmas
 
     def _time_shift_exponential(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
-        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
+        # Compute in float32 (matching reference ltx_core scheduler) to avoid
+        # float64 intermediate precision from math.exp() + numpy promotion.
+        t_tensor = torch.as_tensor(t, dtype=torch.float32)
+        exp_mu = math.exp(mu)
+        result = exp_mu / (exp_mu + (1 / t_tensor - 1) ** sigma)
+        return result.numpy() if isinstance(t, np.ndarray) else result
 
     def _time_shift_linear(self, mu: float, sigma: float, t: torch.Tensor) -> torch.Tensor:
         return mu / (mu + (1 / t - 1) ** sigma)
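The scheduler hunks exist because `math.exp()` returns a Python float and numpy promotes mixed scalar/array arithmetic to float64, producing sigmas that differ from the float32 reference scheduler. A standalone sketch of the patched `_time_shift_exponential` against the old one-liner shows the dtype it now preserves (function bodies adapted from the diff; the `mu`/`sigma`/`t` values are illustrative):

```python
import math

import numpy as np
import torch


def time_shift_exponential_f64(mu, sigma, t):
    # Old one-liner: a numpy array `t` promotes the whole expression to float64.
    return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)


def time_shift_exponential_f32(mu, sigma, t):
    # Patched version: route through a float32 tensor so intermediates stay
    # float32, matching the reference ltx_core scheduler's precision.
    t_tensor = torch.as_tensor(t, dtype=torch.float32)
    exp_mu = math.exp(mu)
    result = exp_mu / (exp_mu + (1 / t_tensor - 1) ** sigma)
    return result.numpy() if isinstance(t, np.ndarray) else result


t = np.linspace(0.1, 0.9, 5)  # np.linspace defaults to float64
print(time_shift_exponential_f64(2.0, 1.0, t).dtype)  # float64
print(time_shift_exponential_f32(2.0, 1.0, t).dtype)  # float32
```

The values agree to float32 precision; what changes is the rounding of intermediates, which is exactly the kind of sub-1e-3 drift that CPU/float32 component-parity tests are meant to pin down.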