Compare commits

...

27 Commits

Author SHA1 Message Date
DN6
2d5080d7c6 update 2026-01-10 00:28:24 +05:30
David El Malih
d36564f06a Improve docstrings and type hints in scheduling_consistency_models.py (#12931)
docs: improve docstring scheduling_consistency_models.py
2026-01-09 09:56:56 -08:00
Sayak Paul
441b69eabf [core] Handle progress bar and logging in distributed environments (#12806)
* disable progressbar in distributed.

* up

* up

* up

* up

* up

* up
2026-01-09 22:23:13 +05:30
Sayak Paul
d568c9773f [chore] remove controlnet implementations outside controlnet module. (#12152)
* remove controlnet implementations outside controlnet module.

* fix

* fix

* fix
2026-01-09 21:22:45 +05:30
Sayak Paul
3981c955ce [modular] Tests for custom blocks in modular diffusers (#12557)
* start custom block testing.

* simplify modular workflow ci.

* up

* style.

* up

* up

* up

* up

* up

* up

* Apply suggestions from code review

* up
2026-01-09 15:57:23 +05:30
YiYi Xu
1903383e94 [Modular] qwen refactor (#12872)
* 3 files

* add conditional pipeline

* refactor qwen modular

* add layered

* up up

* u p

* add to import

* more refactor, make layer work

* clean up a bit git add src

* more

* style

* style
2026-01-08 23:38:49 -10:00
Leo Jiang
08f8b7af9a Bugfix for dreambooth flux2 img2img2 (#12825)
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
2026-01-09 12:34:44 +05:30
Howard Zhang
2f66edc880 Torchao floatx version guard (#12923)
* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

* Adding torchao version guard for floatx usage

Summary: TorchAO removing floatx support, added version guard in quantization_config.py
Altered tests in test_torchao.py to version guard floatx
Created new test to verify version guard of floatx support

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-09 10:51:53 +05:30
TmacAaron
be38f41f9f [NPU] npu attention enable ulysses (#12919)
* npu attention enable ulysses

* clean the format

* register _native_npu_attention to _supports_context_parallel

Signed-off-by: yyt <yangyit139@gmail.com>

* change npu_fusion_attention's input_layout to BSND to eliminate redundant transpose

Signed-off-by: yyt <yangyit139@gmail.com>

* Update format

---------

Signed-off-by: yyt <yangyit139@gmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-09 10:11:49 +05:30
MSD
91e5134175 fix the warning torch_dtype is deprecated (#12841)
* fix the warning torch_dtype is deprecated

* Add transformers version check (>= 4.56.0) for dtype parameter

* Fix linting errors
2026-01-09 08:35:26 +05:30
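For context, the fix this commit describes boils down to choosing the dtype keyword based on the installed transformers version; a generic sketch of the pattern (not the literal diffusers code) is:

```python
import transformers
from packaging import version

# transformers >= 4.56.0 expects `dtype` instead of the deprecated `torch_dtype`.
_USE_DTYPE = version.parse(transformers.__version__) >= version.parse("4.56.0")


def dtype_kwargs(dtype):
    """Return the correctly named dtype keyword for the installed transformers version."""
    return {"dtype": dtype} if _USE_DTYPE else {"torch_dtype": dtype}
```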
Salman Chishti
a812c87465 Upgrade GitHub Actions for Node 24 compatibility (#12865)
Signed-off-by: Salman Muin Kayser Chishti <13schishti@gmail.com>
2026-01-09 08:28:58 +05:30
Aditya Borate
8b9f817ef5 Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning (#12814)
* Fix: Remove hardcoded CUDA autocast in Kandinsky 5 to fix import warning

* Apply style fixes

* Fix: Remove import-time autocast in Kandinsky to prevent warnings

- Removed @torch.autocast decorator from Kandinsky classes.
- Implemented manual F.linear casting to ensure numerical parity with FP32.
- Verified bit-exact output matches main branch.

Co-authored-by: hlky <hlky@hlky.ac>

* Used _keep_in_fp32_modules to align with standards

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: hlky <hlky@hlky.ac>
2026-01-08 09:00:58 -10:00
David El Malih
b1f06b780a Improve docstrings and type hints in scheduling_consistency_decoder.py (#12928)
docs: improve docstring scheduling_consistency_decoder.py
2026-01-08 09:45:38 -08:00
Pauline Bailly-Masson
8600b4c10d Add environment variables to checkout step (#12927) 2026-01-08 13:38:06 +05:30
dg845
c10bdd9b73 Add LTX 2.0 Video Pipelines (#12915)
* Initial LTX 2.0 transformer implementation

* Add tests for LTX 2 transformer model

* Get LTX 2 transformer tests working

* Rename LTX 2 compile test class to have LTX2

* Remove RoPE debug print statements

* Get LTX 2 transformer compile tests passing

* Fix LTX 2 transformer shape errors

* Initial script to convert LTX 2 transformer to diffusers

* Add more LTX 2 transformer audio arguments

* Allow LTX 2 transformer to be loaded from local path for conversion

* Improve dummy inputs and add test for LTX 2 transformer consistency

* Fix LTX 2 transformer bugs so consistency test passes

* Initial implementation of LTX 2.0 video VAE

* Explicitly specify temporal and spatial VAE scale factors when converting

* Add initial LTX 2.0 video VAE tests

* Add initial LTX 2.0 video VAE tests (part 2)

* Get diffusers implementation on par with official LTX 2.0 video VAE implementation

* Initial LTX 2.0 vocoder implementation

* Use RMSNorm implementation closer to original for LTX 2.0 video VAE

* start audio decoder.

* init registration.

* up

* simplify and clean up

* up

* Initial LTX 2.0 text encoder implementation

* Rough initial LTX 2.0 pipeline implementation

* up

* up

* up

* up

* Add imports for LTX 2.0 Audio VAE

* Conversion script for LTX 2.0 Audio VAE Decoder

* Add Audio VAE logic to T2V pipeline

* Duplicate scheduler for audio latents

* Support num_videos_per_prompt for prompt embeddings

* LTX 2.0 scheduler and full pipeline conversion

* Add script to test full LTX2Pipeline T2V inference

* Fix pipeline return bugs

* Add LTX 2 text encoder and vocoder to ltx2 subdirectory __init__

* Fix more bugs in LTX2Pipeline.__call__

* Improve CPU offload support

* Fix pipeline audio VAE decoding dtype bug

* Fix video shape error in full pipeline test script

* Get LTX 2 T2V pipeline to produce reasonable outputs

* Make LTX 2.0 scheduler more consistent with original code

* Fix typo when applying scheduler fix in T2V inference script

* Refactor Audio VAE to be simpler and remove helpers (#7)

* remove resolve causality axes stuff.

* remove a bunch of helpers.

* remove adjust output shape helper.

* remove the use of audiolatentshape.

* move normalization and patchify out of pipeline.

* fix

* up

* up

* Remove unpatchify and patchify ops before audio latents denormalization (#9)

---------

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Add support for I2V (#8)

* start i2v.

* up

* up

* up

* up

* up

* remove uniform strategy code.

* remove unneeded code.

* Denormalize audio latents in I2V pipeline (analogous to T2V change) (#11)

* test i2v.

* Move Video and Audio Text Encoder Connectors to Transformer (#12)

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* precompute run_connectors,.

* fixes

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* Make connectors a separate module (#18)

* remove text_encoder.py

* address yiyi's comments.

* up

* up

* up

* up

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>

* up (#19)

* address initial feedback from lightricks team (#16)

* cross_attn_timestep_scale_multiplier to 1000

* implement split rope type.

* up

* propagate rope_type to rope embed classes as well.

* up

* When using split RoPE, make sure that the output dtype is same as input dtype

* Fix apply split RoPE shape error when reshaping x to 4D

* Add export_utils file for exporting LTX 2.0 videos with audio

* Tests for T2V and I2V (#6)

* add ltx2 pipeline tests.

* up

* up

* up

* up

* remove content

* style

* Denormalize audio latents in I2V pipeline (analogous to T2V change)

* Initial refactor to put video and audio text encoder connectors in transformer

* Get LTX 2 transformer tests working after connector refactor

* up

* up

* i2v tests.

* up

* Address review comments

* Calculate RoPE double precisions freqs using torch instead of np

* Further simplify LTX 2 RoPE freq calc

* revert unneded changes.

* up

* up

* update to split style rope.

* up

---------

Co-authored-by: Daniel Gu <dgu8957@gmail.com>

* up

* use export util funcs.

* Point original checkpoint to LTX 2.0 official checkpoint

* Allow the I2V pipeline to accept image URLs

* make style and make quality

* remove function map.

* remove args.

* update docs.

* update doc entries.

* disable ltx2_consistency test

* Simplify LTX 2 RoPE forward by removing coords is None logic

* make style and make quality

* Support LTX 2.0 audio VAE encoder

* Apply suggestions from code review

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Remove print statement in audio VAE

* up

* Fix bug when calculating audio RoPE coords

* Ltx 2 latent upsample pipeline (#12922)

* Initial implementation of LTX 2.0 latent upsampling pipeline

* Add new LTX 2.0 spatial latent upsampler logic

* Add test script for LTX 2.0 latent upsampling

* Add option to enable VAE tiling in upsampling test script

* Get latent upsampler working with video latents

* Fix typo in BlurDownsample

* Add latent upsample pipeline docstring and example

* Remove deprecated pipeline VAE slicing/tiling methods

* make style and make quality

* When returning latents, return unpacked and denormalized latents for T2V and I2V

* Add model_cpu_offload_seq for latent upsampling pipeline

---------

Co-authored-by: Daniel Gu <dgu8957@gmail.com>

* Fix latent upsampler filename in LTX 2 conversion script

* Add latent upsample pipeline to LTX 2 docs

* Add dummy objects for LTX 2 latent upsample pipeline

* Set default FPS to official LTX 2 ckpt default of 24.0

* Set default CFG scale to official LTX 2 ckpt default of 4.0

* Update LTX 2 pipeline example docstrings

* make style and make quality

* Remove LTX 2 test scripts

* Fix LTX 2 upsample pipeline example docstring

* Add logic to convert and save a LTX 2 upsampling pipeline

* Document LTX2VideoTransformer3DModel forward pass

---------

Co-authored-by: sayakpaul <spsayakpaul@gmail.com>
2026-01-07 21:24:27 -08:00
Álvaro Somoza
dab000e88b [Modular] Video for Mellon (#12924)
num_frames and videos
2026-01-07 12:35:59 -10:00
David El Malih
9fb6b89d49 Improve docstrings and type hints in scheduling_edm_euler.py (#12871)
* docs: add comprehensive docstrings and refine type hints for EDM scheduler methods and config parameters.

* refactor: Add type hints to DPM-Solver scheduler methods.
2026-01-07 11:18:00 -08:00
Sayak Paul
6fb4c99f5a Update wan.md to remove unneeded hfoptions (#12890) 2026-01-07 09:47:19 -08:00
Sayak Paul
961b9b27d3 [docs] fix torchao typo. (#12883)
fix torchao typo.
2026-01-07 09:43:02 -08:00
Tolga Cangöz
8f30bfff1f Add transformer cache context for SkyReels-V2 pipelines & Update docs (#12837)
* feat: Add transformer cache context for conditional and unconditional predictions for skyreels-v2 pipes.

* docs: Remove SkyReels-V2 FLF2V model link and add contributor attribution.
2026-01-06 22:30:30 -10:00
Leo Jiang
b4be29bda2 Add FSDP option for Flux2 (#12860)
* Add FSDP option for Flux2

* Apply style fixes

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Add FSDP option for Flux2

* Update examples/dreambooth/README_flux2.md

* guard accelerate import.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-01-07 13:17:46 +05:30
Hu Yaoqi
98479a94c2 LTX Video 0.9.8 long multi prompt (#12614)
* LTX Video 0.9.8  long multi prompt

* Further align comfyui

- Added the “LTXEulerAncestralRFScheduler” scheduler, aligned with [sample_euler_ancestral_RF](7d6103325e/comfy/k_diffusion/sampling.py (L234))

- Updated the LTXI2VLongMultiPromptPipeline.from_pretrained() method:
  - Now uses LTXEulerAncestralRFScheduler by default, for better compatibility with the ComfyUI LTXV workflow.

- Changed the default value of cond_strength from 1.0 to 0.5, aligning with ComfyUI’s default.

- Optimized cross-window overlap blending: moved the latent-space guidance injection to before the UNet and after each step, aligned with [KSamplerX0Inpaint](https://github.com/comfyanonymous/ComfyUI/blob/master/comfy/samplers.py#L391)

- Adjusted the default value of skip_steps_sigma_threshold to 1.

* align with diffusers contribute rule

* Add new pipelines and update imports

* Enhance LTXI2VLongMultiPromptPipeline with noise rescaling

Refactor LTXI2VLongMultiPromptPipeline to improve documentation and add noise rescaling functionality.

* Clean up comments in scheduling_ltx_euler_ancestral_rf.py

Removed design notes and limitations from the implementation.

* Enhance video generation example with scheduler

Updated LTXI2VLongMultiPromptPipeline example to include LTXEulerAncestralRFScheduler for ComfyUI parity.

* clean up

* style

* copies

* import ltx scheduler

* copies

* fix

* fix more

* up up

* up up up

* up upup

* Apply suggestions from code review

* Update docs/source/en/api/pipelines/ltx_video.md

* Update docs/source/en/api/pipelines/ltx_video.md

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
2026-01-06 18:18:04 -10:00
zhangtao0408
ade1059ae2 [Flux.1] improve pos embed for ascend npu by computing on npu (#12897)
* [Flux.1] improve pos embed for ascend npu by setting it back to npu computation.

* [Flux.2] improve pos embed for ascend npu by setting it back to npu computation.

* [LongCat-Image] improve pos embed for ascend npu by setting it back to npu computation.

* [Ovis-Image] improve pos embed for ascend npu by setting it back to npu computation.

* Remove unused import of is_torch_npu_available

---------

Co-authored-by: zhangtao <zhangtao529@huawei.com>
2026-01-06 08:48:04 -10:00
dxqb
41a6e86faf Check for attention mask in backends that don't support it (#12892)
* check attention mask

* Apply style fixes

* bugfix

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-01-06 22:52:12 +05:30
Pauline Bailly-Masson
9b5a244653 CodeQL workflow for security analysis 2026-01-06 17:26:08 +01:00
Pauline Bailly-Masson
417f6b2d33 Delete .github/workflows/codeql.yml 2026-01-06 17:25:38 +01:00
Pauline Bailly-Masson
e46354d2d0 Add codeQL workflow (#12917)
Updated CodeQL workflow to use reusable workflow from Hugging Face and simplified language matrix.
2026-01-06 17:19:48 +01:00
121 changed files with 15908 additions and 3077 deletions

View File

@@ -28,7 +28,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -58,7 +58,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}

View File

@@ -28,7 +28,7 @@ jobs:
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Find Changed Dockerfiles
id: file_changes
@@ -99,7 +99,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub

View File

@@ -17,10 +17,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v4
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v6
with:
python-version: '3.10'

.github/workflows/codeql.yml (new vendored file, 22 lines)
View File

@@ -0,0 +1,22 @@
---
name: CodeQL Security Analysis For Github Actions
on:
push:
branches: ["main"]
workflow_dispatch:
# pull_request:
jobs:
codeql:
name: CodeQL Analysis
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
permissions:
security-events: write
packages: read
actions: read
contents: read
with:
languages: '["actions","python"]'
queries: 'security-extended,security-and-quality'
runner: 'ubuntu-latest' #optional if need custom runner

View File

@@ -38,6 +38,7 @@ jobs:
# If ref is 'refs/heads/main' => set 'main'
# Else it must be a tag => set {tag}
- name: Set checkout_ref and path_in_repo
env:
EVENT_NAME: ${{ github.event_name }}
EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
GITHUB_REF: ${{ github.ref }}
@@ -65,13 +66,13 @@ jobs:
run: |
echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
- uses: actions/checkout@v3
- uses: actions/checkout@v6
with:
ref: ${{ env.CHECKOUT_REF }}
# Setup + install dependencies
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies

View File

@@ -28,7 +28,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -97,7 +97,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -119,7 +119,7 @@ jobs:
module: [models, schedulers, lora, others, single_file, examples]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -167,7 +167,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_${{ matrix.module }}_cuda_test_reports
path: reports
@@ -184,7 +184,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_compile_test_reports
path: reports
@@ -228,7 +228,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -263,7 +263,7 @@ jobs:
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_cuda_big_gpu_test_reports
path: reports
@@ -280,7 +280,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -321,7 +321,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -355,7 +355,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -391,7 +391,7 @@ jobs:
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_cuda_${{ matrix.config.backend }}_reports
path: reports
@@ -408,7 +408,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -441,7 +441,7 @@ jobs:
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
@@ -466,7 +466,7 @@ jobs:
image: diffusers/diffusers-pytorch-cpu
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -474,7 +474,7 @@ jobs:
run: mkdir -p combined_reports
- name: Download all test reports
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
path: artifacts
@@ -500,7 +500,7 @@ jobs:
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
- name: Upload consolidated report
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: consolidated_test_report
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
@@ -514,7 +514,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v3
# uses: actions/checkout@v6
# with:
# fetch-depth: 2
#
@@ -554,7 +554,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v4
# uses: actions/upload-artifact@v6
# with:
# name: torch_mps_test_reports
# path: reports
@@ -570,7 +570,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v3
# uses: actions/checkout@v6
# with:
# fetch-depth: 2
#
@@ -610,7 +610,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v4
# uses: actions/upload-artifact@v6
# with:
# name: torch_mps_test_reports
# path: reports

View File

@@ -10,10 +10,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: '3.8'

View File

@@ -18,9 +18,9 @@ jobs:
check_dependencies:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies

View File

@@ -1,3 +1,4 @@
name: Fast PR tests for Modular
on:
@@ -35,9 +36,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
@@ -55,9 +56,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
@@ -77,23 +78,13 @@ jobs:
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
name: Fast PyTorch Modular Pipeline CPU tests
runs-on:
group: ${{ matrix.config.runner }}
group: aws-highmemory-32-plus
container:
image: ${{ matrix.config.image }}
image: diffusers/diffusers-pytorch-cpu
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
@@ -102,7 +93,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -118,22 +109,19 @@ jobs:
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-k "not Flax and not Onnx" \
--make-reports=tests_${{ matrix.config.report }} \
--make-reports=tests_torch_cpu_modular_pipelines \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
path: reports

View File

@@ -28,7 +28,7 @@ jobs:
test_map: ${{ steps.set_matrix.outputs.test_map }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Install dependencies
@@ -42,7 +42,7 @@ jobs:
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
- name: Report fetched tests
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v6
with:
name: test_fetched
path: test_preparation.txt
@@ -83,7 +83,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -109,7 +109,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v3
uses: actions/upload-artifact@v6
with:
name: ${{ matrix.modules }}_test_reports
path: reports
@@ -138,7 +138,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -164,7 +164,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -31,9 +31,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -51,9 +51,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -108,7 +108,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -153,7 +153,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
@@ -185,7 +185,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
@@ -236,7 +236,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -273,7 +273,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pr_main_test_reports
path: reports

View File

@@ -32,9 +32,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -52,9 +52,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies
@@ -83,7 +83,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -100,7 +100,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -120,7 +120,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -170,7 +170,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -193,7 +193,7 @@ jobs:
module: [models, schedulers, lora, others]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -239,7 +239,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -255,7 +255,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -287,7 +287,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: examples_test_reports
path: reports

View File

@@ -18,9 +18,9 @@ jobs:
check_torch_dependencies:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.8"
- name: Install dependencies

View File

@@ -29,7 +29,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -46,7 +46,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -66,7 +66,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -98,7 +98,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -120,7 +120,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -155,7 +155,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -172,7 +172,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -199,7 +199,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_compile_test_reports
path: reports
@@ -216,7 +216,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -240,7 +240,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_xformers_test_reports
path: reports
@@ -256,7 +256,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -286,7 +286,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: examples_test_reports
path: reports

View File

@@ -54,7 +54,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -88,7 +88,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -65,7 +65,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pr_torch_mps_test_reports
path: reports

View File

@@ -15,10 +15,10 @@ jobs:
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
steps:
- name: Checkout Repo
uses: actions/checkout@v3
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: '3.8'
@@ -40,12 +40,12 @@ jobs:
steps:
- name: Checkout Repo
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
- name: Setup Python
uses: actions/setup-python@v4
uses: actions/setup-python@v6
with:
python-version: "3.8"

View File

@@ -27,7 +27,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -94,7 +94,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -116,7 +116,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -149,7 +149,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_cuda_${{ matrix.module }}_test_reports
path: reports
@@ -166,7 +166,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -205,7 +205,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -222,7 +222,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -247,7 +247,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_compile_test_reports
path: reports
@@ -264,7 +264,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -288,7 +288,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: torch_xformers_test_reports
path: reports
@@ -305,7 +305,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2
@@ -336,7 +336,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v6
with:
name: examples_test_reports
path: reports

View File

@@ -57,7 +57,7 @@ jobs:
shell: bash -e {0}
- name: Checkout PR branch
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
ref: refs/pull/${{ inputs.pr_number }}/head

View File

@@ -27,7 +27,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2

View File

@@ -35,7 +35,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v3
uses: actions/checkout@v6
with:
fetch-depth: 2

View File

@@ -15,10 +15,10 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v2
- uses: actions/checkout@v6
- name: Setup Python
uses: actions/setup-python@v1
uses: actions/setup-python@v6
with:
python-version: 3.8

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Secret Scanning

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: typos-action
uses: crate-ci/typos@v1.12.4

View File

@@ -15,7 +15,7 @@ jobs:
shell: bash -l {0}
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v6
- name: Setup environment
run: |

View File

@@ -367,6 +367,8 @@
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -443,6 +445,10 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -678,6 +684,8 @@
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx2
title: LTX-2
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi

View File

@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Audio
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTX2Audio
vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Audio
[[autodoc]] AutoencoderKLLTX2Audio
- encode
- decode
- all

View File

@@ -0,0 +1,29 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Video
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTX2Video
vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Video
[[autodoc]] AutoencoderKLLTX2Video
- decode
- encode
- all

View File

@@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co
## FluxControlNetOutput
[[autodoc]] models.controlnet_flux.FluxControlNetOutput
[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput

View File

@@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr
## SparseControlNetOutput
[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput
[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput

View File

@@ -0,0 +1,26 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# LTX2VideoTransformer3DModel
A Diffusion Transformer model for 3D video data used in [LTX-2](https://huggingface.co/Lightricks/LTX-2), introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import LTX2VideoTransformer3DModel
transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTX2VideoTransformer3DModel
[[autodoc]] LTX2VideoTransformer3DModel

View File

@@ -0,0 +1,43 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX-2
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-2 checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
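For orientation only (not part of this diff), a text-to-video call would roughly follow the usual diffusers pattern; everything past `from_pretrained` below is an assumption about the call signature, so check the `LTX2Pipeline` docstrings added in this PR.

```python
import torch
from diffusers import LTX2Pipeline

# bfloat16 is the recommended dtype for LTX-family transformers (see the LTX-Video docs).
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Argument names and values below are assumptions; the commit history above notes
# a default guidance scale of 4.0 and a default frame rate of 24 fps for LTX-2.
output = pipe(
    prompt="A red fox trotting through fresh snow at dawn, cinematic lighting",
    guidance_scale=4.0,
    num_inference_steps=40,
)
```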
## LTX2Pipeline
[[autodoc]] LTX2Pipeline
- all
- __call__
## LTX2ImageToVideoPipeline
[[autodoc]] LTX2ImageToVideoPipeline
- all
- __call__
## LTX2LatentUpsamplePipeline
[[autodoc]] LTX2LatentUpsamplePipeline
- all
- __call__
## LTX2PipelineOutput
[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput

View File

@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.
- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)
<details>
<summary>Show example code</summary>
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
@@ -474,6 +474,12 @@ export_to_video(video, "output.mp4", fps=24)
</details>
## LTXI2VLongMultiPromptPipeline
[[autodoc]] LTXI2VLongMultiPromptPipeline
- all
- __call__
## LTXPipeline
[[autodoc]] LTXPipeline

View File

@@ -37,7 +37,8 @@ The following SkyReels-V2 models are supported in Diffusers:
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
> [!TIP]
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.

View File

@@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
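As a small illustration of that convention (not taken from the diff), masks can be built directly with PIL: black (0) marks regions or frames to keep as conditioning, white (255) marks regions the model should generate.

```python
from PIL import Image

width, height = 832, 480  # example resolution

# Black mask: keep this frame/region as conditioning; the model will not repaint it.
conditioning_mask = Image.new("L", (width, height), color=0)

# White mask: the model generates new content here.
generation_mask = Image.new("L", (width, height), color=255)
```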
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.

View File

@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantzation_config=pipeline_quant_config,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantzation_config=pipeline_quant_config,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantzation_config=pipeline_quant_config,
quantization_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
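For context (the hunks above elide how the config object is built), a `PipelineQuantizationConfig` is typically constructed along these lines; the backend and kwargs below are illustrative, not taken from this file:

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Illustrative: quantize only the transformer with bitsandbytes 4-bit.
pipeline_quant_config = PipelineQuantizationConfig(
    quant_backend="bitsandbytes_4bit",
    quant_kwargs={"load_in_4bit": True, "bnb_4bit_compute_dtype": torch.bfloat16},
    components_to_quantize=["transformer"],
)

pipeline = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    quantization_config=pipeline_quant_config,  # the corrected keyword from the diff above
    torch_dtype=torch.bfloat16,
    device_map="cuda",
)
```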

View File

@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> To enable remote text encoding, you must either be logged in to your Hugging Face account (`hf auth login`) or pass a token with `--hub_token`.
### FSDP Text Encoder
Flux.2 uses Mistral Small 3.1 as its text encoder, which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
This way, the memory cost is distributed across multiple nodes.
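Under the hood, this flag wraps the text encoder with FSDP using helpers from `diffusers.training_utils`; the sketch below mirrors the call added in the training-script diff further down (the function name here is illustrative).

```python
import torch.distributed as dist
from diffusers.training_utils import get_fsdp_kwargs_from_accelerator, wrap_with_fsdp


def shard_text_encoder(accelerator, text_encoding_pipeline, offload: bool = False):
    # Mirrors the call the training script makes when --fsdp_text_encoder is set.
    fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
    text_encoding_pipeline.text_encoder = wrap_with_fsdp(
        model=text_encoding_pipeline.text_encoder,
        device=accelerator.device,
        offload=offload,
        limit_all_gathers=True,
        use_orig_params=True,
        fsdp_kwargs=fsdp_kwargs,
    )
    dist.barrier()  # all ranks finish wrapping before prompts are encoded
    return text_encoding_pipeline
```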
### CPU Offloading
To offload parts of the model to CPU memory, you can use the `--offload` flag. This will offload the VAE and text encoder to CPU memory and only move them to the GPU when needed.
### Latent Caching
@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
### FSDP on the transformer
By configuring accelerate with FSDP, the transformer blocks are wrapped automatically. For example, set the configuration to:
```shell
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.

View File

@@ -44,6 +44,7 @@ import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -75,13 +76,16 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -722,6 +729,7 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1219,7 +1227,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1263,17 +1275,42 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_cls = type(unwrap_model(transformer))
# make sure to pop weight so that corresponding model is not saved again
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1285,13 +1322,20 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if not is_fsdp:
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1507,6 +1551,21 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1536,6 +1595,8 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1777,7 +1838,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1836,15 +1897,41 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(
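As context for the save/load hooks changed in this file: in these DreamBooth scripts the hooks are registered with Accelerate once and then fire inside `accelerator.save_state(...)` / `accelerator.load_state(...)` (the checkpointing branch shown above). A minimal sketch of that wiring, assuming the hook definitions from this script:

from accelerate import Accelerator

accelerator = Accelerator()
# save_model_hook(models, weights, output_dir) runs inside accelerator.save_state(...)
accelerator.register_save_state_pre_hook(save_model_hook)
# load_model_hook(models, input_dir) runs inside accelerator.load_state(...)
accelerator.register_load_state_pre_hook(load_model_hook)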

View File

@@ -43,6 +43,7 @@ import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -74,13 +75,16 @@ from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -339,7 +346,7 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
required=True,
required=False,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -691,6 +698,7 @@ def parse_args(input_args=None):
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -827,15 +835,28 @@ class DreamBoothDataset(Dataset):
dest_image = self.cond_images[i]
image_width, image_height = dest_image.size
if image_width * image_height > 1024 * 1024:
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
image_width, image_height = dest_image.size
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
dest_image = Flux2ImageProcessor.image_processor.preprocess(
image_processor = Flux2ImageProcessor()
dest_image = image_processor.preprocess(
dest_image, height=image_height, width=image_width, resize_mode="crop"
)
# Convert back to PIL
dest_image = dest_image.squeeze(0)
if dest_image.min() < 0:
dest_image = (dest_image + 1) / 2
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
if dest_image.shape[0] == 1:
# Grayscale image
dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
else:
# RGB image: (C, H, W) -> (H, W, C)
dest_image = TF.to_pil_image(dest_image)
dest_image = exif_transpose(dest_image)
if not dest_image.mode == "RGB":
@@ -1156,7 +1177,11 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
transformer.to(**transformer_to_kwargs)
is_fsdp = accelerator.state.fsdp_plugin is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1200,17 +1225,42 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_cls = type(unwrap_model(transformer))
# make sure to pop weight so that corresponding model is not saved again
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1222,13 +1272,20 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
while len(models) > 0:
model = models.pop()
if not is_fsdp:
while len(models) > 0:
model = models.pop()
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1419,9 +1476,9 @@ def main(args):
args.instance_prompt, text_encoding_pipeline
)
validation_image = load_image(args.validation_image_path).convert("RGB")
validation_kwargs = {"image": validation_image}
if args.validation_prompt is not None:
validation_image = load_image(args.validation_image_path).convert("RGB")
validation_kwargs = {"image": validation_image}
if args.remote_text_encoder:
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
else:
@@ -1430,6 +1487,21 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1461,6 +1533,8 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1700,7 +1774,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process:
if accelerator.is_main_process or is_fsdp:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1759,15 +1833,41 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(
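`_to_cpu_contiguous` is newly imported from `diffusers.training_utils` above, but its body is not part of this diff. Judging from how it is used (applied on the main process to the FSDP-gathered LoRA state dict before saving, mirroring the inline dict comprehension in the final-save path), a plausible sketch is:

import torch

def _to_cpu_contiguous(state_dict):
    # Assumed implementation: detach every tensor, move it to CPU and make it
    # contiguous so rank 0 can safely serialize the gathered FSDP state dict.
    return {
        k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
        for k, v in state_dict.items()
    }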

View File

@@ -0,0 +1,886 @@
import argparse
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audio_adaln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulation Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
if key.startswith("per_channel_statistics"):
new_key = ".".join(["decoder", key])
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 4,
"audio_out_channels": 4,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 2,
"audio_attention_head_dim": 4,
"audio_cross_attention_dim": 8,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 2,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 16,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"caption_channels": 16,
"text_proj_in_factor": 3,
"video_connector_num_attention_heads": 4,
"video_connector_attention_head_dim": 8,
"video_connector_num_layers": 1,
"video_connector_num_learnable_registers": None,
"audio_connector_num_attention_heads": 4,
"audio_connector_attention_head_dim": 8,
"audio_connector_num_layers": 1,
"audio_connector_num_learnable_registers": None,
"connector_rope_base_seq_len": 32,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_temporal_positioning": False,
},
}
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
},
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def get_ltx2_spatial_latent_upsampler_config(version: str):
if version == "2.0":
config = {
"in_channels": 128,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
def convert_ltx2_spatial_latent_upsampler(
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
with init_empty_weights():
latent_upsampler = LTX2LatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
if repo_id is None and filename is None:
raise ValueError("Please supply at least one of `repo_id` or `filename`")
if repo_id is not None:
if filename is None:
raise ValueError("If repo_id is specified, filename must also be specified.")
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
else:
ckpt_path = filename
_, ext = os.path.splitext(ckpt_path)
if ext in [".safetensors", ".sft"]:
state_dict = safetensors.torch.load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
return state_dict
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
# Ensure that the key prefix ends with a dot (.)
if not prefix.endswith("."):
prefix = prefix + "."
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--original_state_dict_repo_id",
default="Lightricks/LTX-2",
type=str,
help="HF Hub repo id with LTX 2.0 checkpoint",
)
parser.add_argument(
"--checkpoint_path",
default=None,
type=str,
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
)
parser.add_argument(
"--version",
type=str,
default="2.0",
choices=["test", "2.0"],
help="Version of the LTX 2.0 model",
)
parser.add_argument(
"--combined_filename",
default="ltx-2-19b-dev.safetensors",
type=str,
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
)
parser.add_argument("--vae_prefix", default="vae.", type=str)
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
parser.add_argument(
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
)
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
parser.add_argument(
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
)
parser.add_argument(
"--text_encoder_model_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 base text encoder model",
)
parser.add_argument(
"--tokenizer_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename",
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
type=str,
help="Latent upsampler filename",
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
parser.add_argument(
"--full_pipeline",
action="store_true",
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
)
parser.add_argument(
"--upsample_pipeline",
action="store_true",
help="Whether to save a latent upsampling pipeline",
)
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
def main(args):
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
combined_ckpt = None
load_combined_models = any(
[
args.vae,
args.audio_vae,
args.dit,
args.vocoder,
args.text_encoder,
args.full_pipeline,
args.upsample_pipeline,
]
)
if args.combined_filename is not None and load_combined_models:
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
if args.vae or args.full_pipeline or args.upsample_pipeline:
if args.vae_filename is not None:
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
if not args.full_pipeline and not args.upsample_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
if args.audio_vae or args.full_pipeline:
if args.audio_vae_filename is not None:
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
elif combined_ckpt is not None:
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
if not args.full_pipeline:
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
if args.dit or args.full_pipeline:
if args.dit_filename is not None:
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
if not args.full_pipeline:
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.connectors or args.full_pipeline:
if args.dit_filename is not None:
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
if not args.full_pipeline:
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
if args.vocoder or args.full_pipeline:
if args.vocoder_filename is not None:
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
elif combined_ckpt is not None:
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
if not args.full_pipeline:
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.text_encoder or args.full_pipeline:
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
if not args.full_pipeline:
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
original_latent_upsampler_ckpt,
latent_upsampler_config,
dtype=vae_dtype,
)
if not args.full_pipeline and not args.upsample_pipeline:
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
if args.full_pipeline:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTX2Pipeline(
scheduler=scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.upsample_pipeline:
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
pipe.save_pretrained(
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
)
if __name__ == "__main__":
args = get_args()
main(args)
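To make the conversion mechanics above concrete: each `convert_ltx2_*` function walks the checkpoint keys and applies its rename dict as plain substring replacements (via `update_state_dict_inplace`) before loading the result into the diffusers model with `assign=True`. A small self-contained illustration using two real renames from `LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT` and toy values:

rename_dict = {"patchify_proj": "proj_in", "q_norm": "norm_q"}
state_dict = {
    "patchify_proj.weight": "tensor_a",
    "transformer_blocks.0.attn1.q_norm.weight": "tensor_b",
}

# Same pattern as in the conversion loops: snapshot the keys, rewrite each one
# through every rename pair, then move the value to its new name.
for key in list(state_dict.keys()):
    new_key = key
    for old, new in rename_dict.items():
        new_key = new_key.replace(old, new)
    state_dict[new_key] = state_dict.pop(key)

# state_dict is now:
# {"proj_in.weight": "tensor_a", "transformer_blocks.0.attn1.norm_q.weight": "tensor_b"}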

View File

@@ -193,6 +193,8 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -236,6 +238,7 @@ else:
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
"LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -353,6 +356,7 @@ else:
"KDPM2AncestralDiscreteScheduler",
"KDPM2DiscreteScheduler",
"LCMScheduler",
"LTXEulerAncestralRFScheduler",
"PNDMScheduler",
"RePaintScheduler",
"SASolverScheduler",
@@ -417,6 +421,8 @@ else:
"QwenImageEditModularPipeline",
"QwenImageEditPlusAutoBlocks",
"QwenImageEditPlusModularPipeline",
"QwenImageLayeredAutoBlocks",
"QwenImageLayeredModularPipeline",
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
@@ -537,7 +543,11 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXI2VLongMultiPromptPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
@@ -937,6 +947,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -980,6 +992,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1088,6 +1101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LCMScheduler,
LTXEulerAncestralRFScheduler,
PNDMScheduler,
RePaintScheduler,
SASolverScheduler,
@@ -1135,6 +1149,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
QwenImageLayeredAutoBlocks,
QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
@@ -1251,7 +1267,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,

View File

@@ -41,6 +41,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -104,6 +106,7 @@ if is_torch_available():
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -153,6 +156,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -212,6 +217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,

View File

@@ -1106,6 +1106,51 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
return out
# Not implemented yet.
def _npu_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
# ===== Context parallel =====
@@ -1420,6 +1465,7 @@ def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1427,6 +1473,9 @@ def _flash_attention(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
if _parallel_config is None:
out = flash_attn_func(
q=query,
@@ -1469,6 +1518,7 @@ def _flash_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1476,6 +1526,9 @@ def _flash_attention_hub(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
@@ -1612,11 +1665,15 @@ def _flash_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
@@ -1636,6 +1693,7 @@ def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
@@ -1646,6 +1704,8 @@ def _flash_attention_3_hub(
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
@@ -1785,12 +1845,16 @@ def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for aiter attention")
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
@@ -2028,6 +2092,7 @@ def _native_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -2035,6 +2100,9 @@ def _native_flash_attention(
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for aiter attention")
lse = None
if _parallel_config is None and not return_lse:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2108,34 +2176,52 @@ def _native_math_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_NPU,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _native_npu_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
if _parallel_config is None:
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
None,
scale,
None,
return_lse,
forward_op=_npu_attention_forward_op,
backward_op=_npu_attention_backward_op,
_parallel_config=_parallel_config,
)
return out
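A note on the `_native_npu_attention` change above: the dispatcher hands tensors over in (batch, seq_len, num_heads, head_dim) order, which is what `input_layout="BSND"` describes. The previous code transposed to BNSD and back around `npu_fusion_attention`, while the new code stays in BSND and reads `num_heads` from dimension 2, removing the two redundant transposes. A minimal shape illustration (plain tensors, no NPU required):

import torch

batch, seq_len, num_heads, head_dim = 2, 16, 8, 64
query = torch.randn(batch, seq_len, num_heads, head_dim)  # BSND, as dispatched

# Old path: transpose to BNSD and read the head count from dim 1.
query_bnsd = query.transpose(1, 2).contiguous()
assert query_bnsd.shape == (batch, num_heads, seq_len, head_dim)
assert query_bnsd.size(1) == num_heads

# New path: stay in BSND; query.size(2) is the head count passed to npu_fusion_attention.
assert query.size(2) == num_heads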
@@ -2148,10 +2234,13 @@ def _native_xla_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for XLA attention")
if return_lse:
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2175,11 +2264,14 @@ def _sage_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
if _parallel_config is None:
out = sageattn(
@@ -2223,11 +2315,14 @@ def _sage_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
if _parallel_config is None:
@@ -2309,11 +2404,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
k=key,
@@ -2333,11 +2431,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
k=key,
@@ -2357,11 +2458,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
k=key,
@@ -2381,11 +2485,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,

View File

@@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

File diff suppressed because it is too large

View File

@@ -0,0 +1,804 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
LATENT_DOWNSAMPLE_FACTOR = 4
class LTX2AudioCausalConv2d(nn.Module):
"""
A causal 2D convolution that pads asymmetrically along the causal axis.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
if self.causality_axis == "none":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis in {"width", "width-compatibility"}:
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis == "height":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
else:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
self.padding = padding
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding)
return self.conv(x)
class LTX2AudioPixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
class LTX2AudioAttnBlock(nn.Module):
def __init__(
self,
in_channels: int,
norm_type: str = "group",
) -> None:
super().__init__()
self.in_channels = in_channels
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = self.norm(x)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
batch, channels, height, width = q.shape
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
k = k.reshape(batch, channels, height * width).contiguous()
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
attn = torch.nn.functional.softmax(attn, dim=2)
v = v.reshape(batch, channels, height * width)
attn = attn.permute(0, 2, 1).contiguous()
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
h_ = self.proj_out(h_)
return x + h_
class LTX2AudioResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
norm_type: str = "group",
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if norm_type == "group":
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.non_linearity = nn.SiLU()
if causality_axis is not None:
self.conv1 = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if norm_type == "group":
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.dropout = nn.Dropout(dropout)
if causality_axis is not None:
self.conv2 = LTX2AudioCausalConv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
if causality_axis is not None:
self.conv_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
if causality_axis is not None:
self.nin_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
h = self.norm1(x)
h = self.non_linearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
h = self.norm2(h)
h = self.non_linearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
return x + h
class LTX2AudioDownsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.with_conv:
# Padding tuple is in the order: (left, right, top, bottom).
if self.causality_axis == "none":
pad = (0, 1, 0, 1)
elif self.causality_axis == "width":
pad = (2, 0, 0, 1)
elif self.causality_axis == "height":
pad = (0, 1, 2, 0)
elif self.causality_axis == "width-compatibility":
pad = (1, 0, 0, 1)
else:
raise ValueError(
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
f" and `width-compatibility`."
)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
# with_conv=False implies that causality_axis is "none"
x = F.avg_pool2d(x, kernel_size=2, stride=2)
return x
class LTX2AudioUpsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
if causality_axis is not None:
self.conv = LTX2AudioCausalConv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
if self.causality_axis is None or self.causality_axis == "none":
pass
elif self.causality_axis == "height":
x = x[:, :, 1:, :]
elif self.causality_axis == "width":
x = x[:, :, :, 1:]
elif self.causality_axis == "width-compatibility":
pass
else:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
class LTX2AudioAudioPatchifier:
"""
Patchifier for spectrogram/audio latents.
"""
def __init__(
self,
patch_size: int,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
is_causal: bool = True,
):
self.hop_length = hop_length
self.sample_rate = sample_rate
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.is_causal = is_causal
self._patch_size = (1, patch_size, patch_size)
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
batch, channels, time, freq = audio_latents.shape
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
batch, time, _ = audio_latents.shape
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
@property
def patch_size(self) -> Tuple[int, int, int]:
return self._patch_size
class LTX2AudioEncoder(nn.Module):
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
):
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels
base_resolution = resolution
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.down = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution
for level in range(self.num_resolutions):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != self.num_resolutions - 1:
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
curr_res = curr_res // 2
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
final_block_channels = block_in
z_channels = 2 * latent_channels if double_z else latent_channels
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
self.non_linearity = nn.SiLU()
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
hidden_states = self.conv_in(hidden_states)
for level in range(self.num_resolutions):
stage = self.down[level]
for block_idx, block in enumerate(stage.block):
hidden_states = block(hidden_states, temb=None)
if stage.attn:
hidden_states = stage.attn[block_idx](hidden_states)
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
hidden_states = stage.downsample(hidden_states)
hidden_states = self.mid.block_1(hidden_states, temb=None)
hidden_states = self.mid.attn_1(hidden_states)
hidden_states = self.mid.block_2(hidden_states, temb=None)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class LTX2AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
convolutions.
"""
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.patchifier = LTX2AudioAudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
)
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.non_linearity = nn.SiLU()
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
self.up = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
for level in reversed(range(self.num_resolutions)):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks + 1):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != 0:
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
curr_res *= 2
self.up.insert(0, stage)
final_block_channels = block_in
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
def forward(
self,
sample: torch.Tensor,
) -> torch.Tensor:
_, _, frames, mel_bins = sample.shape
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis is not None:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_channels = self.out_ch
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
hidden_features = self.conv_in(sample)
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx, block in enumerate(stage.block):
hidden_features = block(hidden_features, temb=None)
if stage.attn:
hidden_features = stage.attn[block_idx](hidden_features)
if level != 0 and hasattr(stage, "upsample"):
hidden_features = stage.upsample(hidden_features)
if self.give_pre_end:
return hidden_features
hidden = self.norm_out(hidden_features)
hidden = self.non_linearity(hidden)
decoded_output = self.conv_out(hidden)
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
_, _, current_time, current_freq = decoded_output.shape
target_time = target_frames
target_freq = target_mel_bins
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
if time_padding_needed > 0 or freq_padding_needed > 0:
padding = (
0,
max(freq_padding_needed, 0),
0,
max(time_padding_needed, 0),
)
decoded_output = F.pad(decoded_output, padding)
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
LTX2 audio VAE for encoding and decoding audio latent representations.
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_channels: int = 128,
output_channels: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
norm_type: str = "pixel",
causality_axis: Optional[str] = "height",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
) -> None:
super().__init__()
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
if causality_axis not in supported_causality_axes:
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
self.encoder = LTX2AudioEncoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
double_z=double_z,
)
self.decoder = LTX2AudioDecoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
)
# Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are computed
# over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
latents_std = torch.zeros((base_channels,))
latents_mean = torch.ones((base_channels,))
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
# TODO: calculate programmatically instead of hardcoding
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
# TODO: confirm whether the mel compression ratio below is correct
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
return self.encoder(x)
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True):
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
if not return_dict:
return (dec.sample,)
return dec
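
The new audio VAE above is self-contained, so a short round-trip makes the encode/decode contract easier to check. The following is a minimal sketch (not part of this diff) that instantiates the model with its default config and pushes a random mel-spectrogram-shaped tensor through it; the input layout (batch, channels, time, num_mel_bins) follows the comment in LTX2AudioEncoder.forward, and all shapes below are illustrative only.

import torch

# Minimal sketch, not part of the PR: round-trip a dummy mel spectrogram through
# the audio VAE defined above using its default config (in_channels=2, mel_bins=64,
# ch_mult=(1, 2, 4), causality_axis="height").
vae = AutoencoderKLLTX2Audio()
vae.eval()

mel = torch.randn(1, 2, 32, 64)  # (batch, channels, time_frames, num_mel_bins)

with torch.no_grad():
    posterior = vae.encode(mel).latent_dist      # AutoencoderKLOutput.latent_dist
    latents = posterior.mode()                   # deterministic latent, 8 channels by default
    reconstruction = vae.decode(latents).sample  # DecoderOutput.sample

# The two downsampling stages (len(ch_mult) - 1) compress the time and mel axes by 4x;
# the decoder upsamples back and crops/pads to the causal target frame count.
print(latents.shape, reconstruction.shape)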


@@ -1,115 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from ..utils import deprecate
from .controlnets.controlnet import ( # noqa
ControlNetConditioningEmbedding,
ControlNetModel,
ControlNetOutput,
zero_module,
)
class ControlNetOutput(ControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class ControlNetModel(ControlNetModel):
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
super().__init__(
in_channels=in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
down_block_types=down_block_types,
mid_block_type=mid_block_type,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
class_embed_type=class_embed_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
num_class_embeds=num_class_embeds,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
global_pool_conditions=global_pool_conditions,
addition_embed_type_num_heads=addition_embed_type_num_heads,
)
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
super().__init__(*args, **kwargs)


@@ -1,70 +0,0 @@
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from ..utils import deprecate, logging
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxControlNetOutput(FluxControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class FluxControlNetModel(FluxControlNetModel):
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
super().__init__(
patch_size=patch_size,
in_channels=in_channels,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
pooled_projection_dim=pooled_projection_dim,
guidance_embeds=guidance_embeds,
axes_dims_rope=axes_dims_rope,
num_mode=num_mode,
conditioning_embedding_channels=conditioning_embedding_channels,
)
class FluxMultiControlNetModel(FluxMultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)


@@ -1,68 +0,0 @@
# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import deprecate, logging
from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SD3ControlNetOutput(SD3ControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class SD3ControlNetModel(SD3ControlNetModel):
def __init__(
self,
sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 18,
attention_head_dim: int = 64,
num_attention_heads: int = 18,
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0,
):
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
super().__init__(
sample_size=sample_size,
patch_size=patch_size,
in_channels=in_channels,
num_layers=num_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
caption_projection_dim=caption_projection_dim,
pooled_projection_dim=pooled_projection_dim,
out_channels=out_channels,
pos_embed_max_size=pos_embed_max_size,
extra_conditioning_channels=extra_conditioning_channels,
)
class SD3MultiControlNetModel(SD3MultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)


@@ -1,116 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from ..utils import deprecate, logging
from .controlnets.controlnet_sparsectrl import ( # noqa
SparseControlNetConditioningEmbedding,
SparseControlNetModel,
SparseControlNetOutput,
zero_module,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SparseControlNetOutput(SparseControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
deprecate(
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
)
super().__init__(*args, **kwargs)
class SparseControlNetModel(SparseControlNetModel):
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockMotion",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 768,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
controlnet_conditioning_channel_order: str = "rgb",
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
concat_conditioning_mask: bool = True,
use_simplified_condition_embedding: bool = True,
):
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
super().__init__(
in_channels=in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
down_block_types=down_block_types,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
global_pool_conditions=global_pool_conditions,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
motion_max_seq_length=motion_max_seq_length,
motion_num_attention_heads=motion_num_attention_heads,
concat_conditioning_mask=concat_conditioning_mask,
use_simplified_condition_embedding=use_simplified_condition_embedding,
)


@@ -47,6 +47,7 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
logger = logging.get_logger(__name__)
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
low_cpu_mem_usage=low_cpu_mem_usage,
)
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
if not is_torch_dist_rank_zero():
tqdm_kwargs["disable"] = True
with ThreadPoolExecutor(max_workers=num_workers) as executor:
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
with logging.tqdm(**tqdm_kwargs) as pbar:
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
for future in as_completed(futures):
result = future.result()
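
For context, the guard imported here only needs to answer one question: is this the process that should own the progress bar? A typical rank-zero check (a sketch, assuming the real helper in diffusers.utils.distributed_utils behaves along these lines) looks like:

import torch.distributed as dist

def is_torch_dist_rank_zero() -> bool:
    # Sketch only; the actual helper lives in diffusers.utils.distributed_utils.
    # Non-distributed runs count as rank zero so single-process usage keeps its
    # progress bars.
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank() == 0
    return True

With this guard, only rank 0 renders the shard-loading tqdm bar while the other ranks pass disable=True, which keeps multi-process logs from interleaving duplicate progress output.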


@@ -59,11 +59,8 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.hub_utils import (
PushToHubMixin,
load_or_create_model_card,
populate_model_card,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
from .model_loading_utils import (
@@ -1672,7 +1669,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
else:
shard_files = resolved_model_file
if len(resolved_model_file) > 1:
shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
if not is_torch_dist_rank_zero():
shard_tqdm_kwargs["disable"] = True
shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
for shard_file in shard_files:
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)


@@ -35,6 +35,7 @@ if is_torch_available():
from .transformer_kandinsky import Kandinsky5Transformer3DModel
from .transformer_longcat_image import LongCatImageTransformer2DModel
from .transformer_ltx import LTXVideoTransformer3DModel
from .transformer_ltx2 import LTX2VideoTransformer3DModel
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel


@@ -22,7 +22,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
@@ -717,11 +717,7 @@ class FluxTransformer2DModel(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
image_rotary_emb = self.pos_embed(ids)
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")


@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
@@ -835,14 +835,8 @@ class Flux2Transformer2DModel(
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]
if is_torch_npu_available():
freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
else:
image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
image_rotary_emb = self.pos_embed(img_ids)
text_rotary_emb = self.pos_embed(txt_ids)
concat_rotary_emb = (
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),


@@ -165,9 +165,8 @@ class Kandinsky5TimeEmbeddings(nn.Module):
self.activation = nn.SiLU()
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, time):
args = torch.outer(time, self.freqs.to(device=time.device))
args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
return time_embed
@@ -269,7 +268,6 @@ class Kandinsky5Modulation(nn.Module):
self.out_layer.weight.data.zero_()
self.out_layer.bias.data.zero_()
@torch.autocast(device_type="cuda", dtype=torch.float32)
def forward(self, x):
return self.out_layer(self.activation(x))
@@ -525,6 +523,7 @@ class Kandinsky5Transformer3DModel(
"Kandinsky5TransformerEncoderBlock",
"Kandinsky5TransformerDecoderBlock",
]
_keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
_supports_gradient_checkpointing = True
@register_to_config


@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import is_torch_npu_available, logging
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -499,11 +499,7 @@ class LongCatImageTransformer2DModel(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
ids = torch.cat((txt_ids, img_ids), dim=0)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:

File diff suppressed because it is too large.


@@ -21,7 +21,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import is_torch_npu_available, logging
from ...utils import logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -530,11 +530,7 @@ class OvisImageTransformer2DModel(
img_ids = img_ids[0]
ids = torch.cat((txt_ids, img_ids), dim=0)
if is_torch_npu_available():
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
else:
image_rotary_emb = self.pos_embed(ids)
image_rotary_emb = self.pos_embed(ids)
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:


@@ -63,6 +63,8 @@ else:
"QwenImageEditAutoBlocks",
"QwenImageEditPlusModularPipeline",
"QwenImageEditPlusAutoBlocks",
"QwenImageLayeredModularPipeline",
"QwenImageLayeredAutoBlocks",
]
_import_structure["z_image"] = [
"ZImageAutoBlocks",
@@ -96,6 +98,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
QwenImageLayeredAutoBlocks,
QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
)
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline


@@ -121,7 +121,7 @@ class FluxTextInputStep(ModularPipelineBlocks):
return components, state
# Adapted from `QwenImageInputsDynamicStep`
# Adapted from `QwenImageAdditionalInputsStep`
class FluxInputsDynamicStep(ModularPipelineBlocks):
model_name = "flux"


@@ -168,6 +168,14 @@ class MellonParam:
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
)
@classmethod
def num_frames(cls, default: int = 81) -> "MellonParam":
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
@classmethod
def videos(cls) -> "MellonParam":
return cls(name="videos", label="Videos", type="video", display="output")
@classmethod
def vae(cls) -> "MellonParam":
"""


@@ -62,6 +62,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
("qwenimage", "QwenImageModularPipeline"),
("qwenimage-edit", "QwenImageEditModularPipeline"),
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
("z-image", "ZImageModularPipeline"),
]
)
@@ -231,7 +232,7 @@ class BlockState:
class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"""
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks,
LoopSequentialPipelineBlocks
[`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
@@ -527,9 +528,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
)
class AutoPipelineBlocks(ModularPipelineBlocks):
class ConditionalPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the inputs.
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
`select_block` method to define the logic for selecting the block.
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -539,12 +541,13 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
Attributes:
block_classes: List of block classes to be used
block_names: List of prefixes for each block
block_trigger_inputs: List of input names that trigger specific blocks, with None for default
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
"""
block_classes = []
block_names = []
block_trigger_inputs = []
default_block_name = None # name of the default block to run when no trigger input is provided; if None, the block is skipped in that case
def __init__(self):
sub_blocks = InsertableDict()
@@ -554,26 +557,15 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
else:
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
if not (len(self.block_classes) == len(self.block_names)):
raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same."
)
default_blocks = [t for t in self.block_trigger_inputs if t is None]
# can only have 1 or 0 default block, and has to put in the last
# the order of blocks matters here because the first block with matching trigger will be dispatched
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
# as long as mask is provided, it is inpaint; if only image is provided, it is img2img
if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
if self.default_block_name is not None and self.default_block_name not in self.block_names:
raise ValueError(
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
"in block_trigger_inputs."
f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}"
)
# Map trigger inputs to block objects
self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
@property
def model_name(self):
return next(iter(self.sub_blocks.values())).model_name
@@ -602,8 +594,10 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
@property
def required_inputs(self) -> List[str]:
if None not in self.block_trigger_inputs:
# no default block means this conditional block can be skipped entirely
if self.default_block_name is None:
return []
first_block = next(iter(self.sub_blocks.values()))
required_by_all = set(getattr(first_block, "required_inputs", set()))
@@ -614,7 +608,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
return list(required_by_all)
# YiYi TODO: add test for this
@property
def inputs(self) -> List[Tuple[str, Any]]:
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
@@ -639,36 +632,9 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
combined_outputs = self.combine_outputs(*named_outputs)
return combined_outputs
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
# Find default block first (if any)
block = self.trigger_to_block_map.get(None)
for input_name in self.block_trigger_inputs:
if input_name is not None and state.get(input_name) is not None:
block = self.trigger_to_block_map[input_name]
break
if block is None:
logger.info(f"skipping auto block: {self.__class__.__name__}")
return pipeline, state
try:
logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
return block(pipeline, state)
except Exception as e:
error_msg = (
f"\nError in block: {block.__class__.__name__}\n"
f"Error details: {str(e)}\n"
f"Traceback:\n{traceback.format_exc()}"
)
logger.error(error_msg)
raise
def _get_trigger_inputs(self):
def _get_trigger_inputs(self) -> set:
"""
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
block_trigger_inputs values
Returns a set of all unique trigger input values found in this block and nested blocks.
"""
def fn_recursive_get_trigger(blocks):
@@ -676,9 +642,8 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
if blocks is not None:
for name, block in blocks.items():
# Check if current block has trigger inputs(i.e. auto block)
# Check if current block has block_trigger_inputs
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
# Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has sub_blocks, recursively check them
@@ -688,15 +653,57 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
return trigger_values
trigger_inputs = set(self.block_trigger_inputs)
trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
# Start with this block's block_trigger_inputs
all_triggers = {t for t in self.block_trigger_inputs if t is not None}
# Add nested triggers
all_triggers.update(fn_recursive_get_trigger(self.sub_blocks))
return trigger_inputs
return all_triggers
@property
def trigger_inputs(self):
"""All trigger inputs including from nested blocks."""
return self._get_trigger_inputs()
def select_block(self, **kwargs) -> Optional[str]:
"""
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
for selecting the block.
Args:
**kwargs: Trigger input names and their values from the state.
Returns:
Optional[str]: The name of the block to run, or None to use default/skip.
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.")
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None}
block_name = self.select_block(**trigger_kwargs)
if block_name is None:
block_name = self.default_block_name
if block_name is None:
logger.info(f"skipping conditional block: {self.__class__.__name__}")
return pipeline, state
block = self.sub_blocks[block_name]
try:
logger.info(f"Running block: {block.__class__.__name__}")
return block(pipeline, state)
except Exception as e:
error_msg = (
f"\nError in block: {block.__class__.__name__}\n"
f"Error details: {str(e)}\n"
f"Traceback:\n{traceback.format_exc()}"
)
logger.error(error_msg)
raise
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
@@ -708,7 +715,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -729,31 +736,20 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
expected_configs = getattr(self, "expected_configs", [])
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
# Blocks section - moved to the end with simplified format
# Blocks section
blocks_str = " Sub-Blocks:\n"
for i, (name, block) in enumerate(self.sub_blocks.items()):
# Get trigger input for this block
trigger = None
if hasattr(self, "block_to_trigger_map"):
trigger = self.block_to_trigger_map.get(name)
# Format the trigger info
if trigger is None:
trigger_str = "[default]"
elif isinstance(trigger, (list, tuple)):
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
else:
trigger_str = f"[trigger: {trigger}]"
# For AutoPipelineBlocks, add bullet points
blocks_str += f"{name} {trigger_str} ({block.__class__.__name__})\n"
if name == self.default_block_name:
addtional_str = " [default]"
else:
# For SequentialPipelineBlocks, show execution order
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
addtional_str = ""
blocks_str += f" {name}{addtional_str} ({block.__class__.__name__})\n"
# Add block description
desc_lines = block.description.split("\n")
indented_desc = desc_lines[0]
if len(desc_lines) > 1:
indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
block_desc_lines = block.description.split("\n")
indented_desc = block_desc_lines[0]
if len(block_desc_lines) > 1:
indented_desc += "\n" + "\n".join(" " + line for line in block_desc_lines[1:])
blocks_str += f" Description: {indented_desc}\n\n"
# Build the representation with conditional sections
@@ -784,6 +780,35 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
)
class AutoPipelineBlocks(ConditionalPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
"""
def __init__(self):
super().__init__()
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
)
@property
def default_block_name(self) -> Optional[str]:
"""Derive default_block_name from block_trigger_inputs (None entry)."""
if None in self.block_trigger_inputs:
idx = self.block_trigger_inputs.index(None)
return self.block_names[idx]
return None
def select_block(self, **kwargs) -> Optional[str]:
"""Select block based on which trigger input is present (not None)."""
for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names):
if trigger_input is not None and kwargs.get(trigger_input) is not None:
return block_name
return None
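# Illustrative (hypothetical) AutoPipelineBlocks declaration showing how the trigger-based
# selection above resolves; the block classes and names are placeholders, not library API:
#
#     class MyAutoEncodeStep(AutoPipelineBlocks):
#         block_classes = [InpaintEncodeStep, Img2ImgEncodeStep, Text2ImgEncodeStep]
#         block_names = ["inpaint", "img2img", "text2img"]
#         block_trigger_inputs = ["mask_image", "image", None]  # None entry -> "text2img" is the default
#
#     # select_block(mask_image=..., image=...) returns "inpaint" when mask_image is not None,
#     # else "img2img" when image is not None, else None and __call__ falls back to the default block.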
class SequentialPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
@@ -885,7 +910,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# Only add outputs if the block cannot be skipped
should_add_outputs = True
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None:
# ConditionalPipelineBlocks without default can be skipped
should_add_outputs = False
if should_add_outputs:
@@ -948,8 +974,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
def _get_trigger_inputs(self):
"""
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
block_trigger_inputs values
Returns a set of all unique trigger input values found in the blocks.
"""
def fn_recursive_get_trigger(blocks):
@@ -957,9 +982,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
if blocks is not None:
for name, block in blocks.items():
# Check if current block has trigger inputs(i.e. auto block)
# Check if current block has block_trigger_inputs (ConditionalPipelineBlocks)
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
# Add all non-None values from the trigger inputs list
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
# If block has sub_blocks, recursively check them
@@ -975,82 +999,84 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
def trigger_inputs(self):
return self._get_trigger_inputs()
def _traverse_trigger_blocks(self, trigger_inputs):
# Convert trigger_inputs to a set for easier manipulation
active_triggers = set(trigger_inputs)
def _traverse_trigger_blocks(self, active_inputs):
"""
Traverse blocks and select which ones would run given the active inputs.
def fn_recursive_traverse(block, block_name, active_triggers):
Args:
active_inputs: Dict of input names to values that are "present"
Returns:
OrderedDict of block_name -> block that would execute
"""
def fn_recursive_traverse(block, block_name, active_inputs):
result_blocks = OrderedDict()
# sequential(include loopsequential) or PipelineBlock
if not hasattr(block, "block_trigger_inputs"):
if block.sub_blocks:
# sequential or LoopSequentialPipelineBlocks (keep traversing)
for sub_block_name, sub_block in block.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
result_blocks.update(blocks_to_update)
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
if isinstance(block, ConditionalPipelineBlocks):
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
selected_block_name = block.select_block(**trigger_kwargs)
if selected_block_name is None:
selected_block_name = block.default_block_name
if selected_block_name is None:
return result_blocks
selected_block = block.sub_blocks[selected_block_name]
if selected_block.sub_blocks:
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
else:
# PipelineBlock
result_blocks[block_name] = block
# Add this block's output names to active triggers if defined
if hasattr(block, "outputs"):
active_triggers.update(out.name for out in block.outputs)
result_blocks[block_name] = selected_block
if hasattr(selected_block, "outputs"):
for out in selected_block.outputs:
active_inputs[out.name] = True
return result_blocks
# auto
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
if block.sub_blocks:
for sub_block_name, sub_block in block.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
result_blocks.update(blocks_to_update)
else:
# Find first block_trigger_input that matches any value in our active_triggers
this_block = None
for trigger_input in block.block_trigger_inputs:
if trigger_input is not None and trigger_input in active_triggers:
this_block = block.trigger_to_block_map[trigger_input]
break
# If no matches found, try to get the default (None) block
if this_block is None and None in block.block_trigger_inputs:
this_block = block.trigger_to_block_map[None]
if this_block is not None:
# sequential/auto (keep traversing)
if this_block.sub_blocks:
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
else:
# PipelineBlock
result_blocks[block_name] = this_block
# Add this block's output names to active triggers if defined
# YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
if hasattr(this_block, "outputs"):
active_triggers.update(out.name for out in this_block.outputs)
result_blocks[block_name] = block
if hasattr(block, "outputs"):
for out in block.outputs:
active_inputs[out.name] = True
return result_blocks
all_blocks = OrderedDict()
for block_name, block in self.sub_blocks.items():
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(blocks_to_update)
return all_blocks
def get_execution_blocks(self, *trigger_inputs):
trigger_inputs_all = self.trigger_inputs
def get_execution_blocks(self, **kwargs):
"""
Get the blocks that would execute given the specified inputs.
if trigger_inputs is not None:
if not isinstance(trigger_inputs, (list, tuple, set)):
trigger_inputs = [trigger_inputs]
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
if invalid_inputs:
logger.warning(
f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
)
trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Pass any inputs that would be non-None at runtime.
if trigger_inputs is None:
if None in trigger_inputs_all:
trigger_inputs = [None]
else:
trigger_inputs = [trigger_inputs_all[0]]
blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
Returns:
SequentialPipelineBlocks containing only the blocks that would execute
Example:
# Get blocks for inpainting workflow
blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)
# Get blocks for text2image workflow
blocks = pipeline.get_execution_blocks(prompt="a cat")
"""
# Filter out None values
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
def __repr__(self):
@@ -1067,7 +1093,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
# Get first trigger input as example
example_input = next(t for t in self.trigger_inputs if t is not None)
header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -1091,22 +1117,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# Blocks section - moved to the end with simplified format
blocks_str = " Sub-Blocks:\n"
for i, (name, block) in enumerate(self.sub_blocks.items()):
# Get trigger input for this block
trigger = None
if hasattr(self, "block_to_trigger_map"):
trigger = self.block_to_trigger_map.get(name)
# Format the trigger info
if trigger is None:
trigger_str = "[default]"
elif isinstance(trigger, (list, tuple)):
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
else:
trigger_str = f"[trigger: {trigger}]"
# For AutoPipelineBlocks, add bullet points
blocks_str += f"{name} {trigger_str} ({block.__class__.__name__})\n"
else:
# For SequentialPipelineBlocks, show execution order
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# show execution order
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
# Add block description
desc_lines = block.description.split("\n")
@@ -1230,15 +1242,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
if inp.name not in outputs and inp not in inputs:
inputs.append(inp)
# Only add outputs if the block cannot be skipped
should_add_outputs = True
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
should_add_outputs = False
if should_add_outputs:
# Add this block's outputs
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs)
# Add this block's outputs
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
outputs.update(block_intermediate_outputs)
for input_param in inputs:
if input_param.name in self.required_inputs:
@@ -1295,6 +1301,14 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
sub_blocks[block_name] = block
self.sub_blocks = sub_blocks
# Validate that sub_blocks are only leaf blocks
for block_name, block in self.sub_blocks.items():
if block.sub_blocks:
raise ValueError(
f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). "
f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks."
)
@classmethod
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
"""

View File

@@ -1,661 +0,0 @@
import json
import logging
import os
from pathlib import Path
from typing import List, Optional, Tuple, Union
import numpy as np
import PIL
import torch
from ..configuration_utils import ConfigMixin
from ..image_processor import PipelineImageInput
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
from .modular_pipeline_utils import InputParam
logger = logging.getLogger(__name__)
# YiYi Notes: this is actually for SDXL, put it here for now
SDXL_INPUTS_SCHEMA = {
"prompt": InputParam(
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
),
"prompt_2": InputParam(
"prompt_2",
type_hint=Union[str, List[str]],
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
),
"negative_prompt": InputParam(
"negative_prompt",
type_hint=Union[str, List[str]],
description="The prompt or prompts not to guide the image generation",
),
"negative_prompt_2": InputParam(
"negative_prompt_2",
type_hint=Union[str, List[str]],
description="The negative prompt or prompts for text_encoder_2",
),
"cross_attention_kwargs": InputParam(
"cross_attention_kwargs",
type_hint=Optional[dict],
description="Kwargs dictionary passed to the AttentionProcessor",
),
"clip_skip": InputParam(
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
),
"image": InputParam(
"image",
type_hint=PipelineImageInput,
required=True,
description="The image(s) to modify for img2img or inpainting",
),
"mask_image": InputParam(
"mask_image",
type_hint=PipelineImageInput,
required=True,
description="Mask image for inpainting, white pixels will be repainted",
),
"generator": InputParam(
"generator",
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
description="Generator(s) for deterministic generation",
),
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
"num_images_per_prompt": InputParam(
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
),
"timesteps": InputParam(
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
),
"sigmas": InputParam(
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
),
"denoising_end": InputParam(
"denoising_end",
type_hint=Optional[float],
description="Fraction of denoising process to complete before termination",
),
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
"strength": InputParam(
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
),
"denoising_start": InputParam(
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
),
"latents": InputParam(
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
),
"padding_mask_crop": InputParam(
"padding_mask_crop",
type_hint=Optional[Tuple[int, int]],
description="Size of margin in crop for image and mask",
),
"original_size": InputParam(
"original_size",
type_hint=Optional[Tuple[int, int]],
description="Original size of the image for SDXL's micro-conditioning",
),
"target_size": InputParam(
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
),
"negative_original_size": InputParam(
"negative_original_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on image resolution",
),
"negative_target_size": InputParam(
"negative_target_size",
type_hint=Optional[Tuple[int, int]],
description="Negative conditioning based on target resolution",
),
"crops_coords_top_left": InputParam(
"crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Top-left coordinates for SDXL's micro-conditioning",
),
"negative_crops_coords_top_left": InputParam(
"negative_crops_coords_top_left",
type_hint=Tuple[int, int],
default=(0, 0),
description="Negative conditioning crop coordinates",
),
"aesthetic_score": InputParam(
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
),
"negative_aesthetic_score": InputParam(
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
),
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
"output_type": InputParam(
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
),
"ip_adapter_image": InputParam(
"ip_adapter_image",
type_hint=PipelineImageInput,
required=True,
description="Image(s) to be used as IP adapter",
),
"control_image": InputParam(
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
),
"control_guidance_start": InputParam(
"control_guidance_start",
type_hint=Union[float, List[float]],
default=0.0,
description="When ControlNet starts applying",
),
"control_guidance_end": InputParam(
"control_guidance_end",
type_hint=Union[float, List[float]],
default=1.0,
description="When ControlNet stops applying",
),
"controlnet_conditioning_scale": InputParam(
"controlnet_conditioning_scale",
type_hint=Union[float, List[float]],
default=1.0,
description="Scale factor for ControlNet outputs",
),
"guess_mode": InputParam(
"guess_mode",
type_hint=bool,
default=False,
description="Enables ControlNet encoder to recognize input without prompts",
),
"control_mode": InputParam(
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
),
}
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
"prompt_embeds": InputParam(
"prompt_embeds",
type_hint=torch.Tensor,
required=True,
description="Text embeddings used to guide image generation",
),
"negative_prompt_embeds": InputParam(
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
),
"pooled_prompt_embeds": InputParam(
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
),
"negative_pooled_prompt_embeds": InputParam(
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
),
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
"preprocess_kwargs": InputParam(
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
),
"latents": InputParam(
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
),
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
"num_inference_steps": InputParam(
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
),
"latent_timestep": InputParam(
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
),
"image_latents": InputParam(
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
),
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
"masked_image_latents": InputParam(
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
),
"add_time_ids": InputParam(
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
),
"negative_add_time_ids": InputParam(
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
),
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
"ip_adapter_embeds": InputParam(
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
),
"negative_ip_adapter_embeds": InputParam(
"negative_ip_adapter_embeds",
type_hint=List[torch.Tensor],
description="Negative image embeddings for IP-Adapter",
),
"images": InputParam(
"images",
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
required=True,
description="Generated images",
),
}
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
DEFAULT_PARAM_MAPS = {
"prompt": {
"label": "Prompt",
"type": "string",
"default": "a bear sitting in a chair drinking a milkshake",
"display": "textarea",
},
"negative_prompt": {
"label": "Negative Prompt",
"type": "string",
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
"display": "textarea",
},
"num_inference_steps": {
"label": "Steps",
"type": "int",
"default": 25,
"min": 1,
"max": 1000,
},
"seed": {
"label": "Seed",
"type": "int",
"default": 0,
"min": 0,
"display": "random",
},
"width": {
"label": "Width",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"height": {
"label": "Height",
"type": "int",
"display": "text",
"default": 1024,
"min": 8,
"max": 8192,
"step": 8,
"group": "dimensions",
},
"images": {
"label": "Images",
"type": "image",
"display": "output",
},
"image": {
"label": "Image",
"type": "image",
"display": "input",
},
}
DEFAULT_TYPE_MAPS = {
"int": {
"type": "int",
"default": 0,
"min": 0,
},
"float": {
"type": "float",
"default": 0.0,
"min": 0.0,
},
"str": {
"type": "string",
"default": "",
},
"bool": {
"type": "boolean",
"default": False,
},
"image": {
"type": "image",
},
}
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
DEFAULT_CATEGORY = "Modular Diffusers"
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
DEFAULT_PARAMS_GROUPS_KEYS = {
"text_encoders": ["text_encoder", "tokenizer"],
"ip_adapter_embeds": ["ip_adapter_embeds"],
"prompt_embeddings": ["prompt_embeds"],
}
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name; if the name is not part of a group, return None. e.g.
"prompt_embeds" -> "prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None
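# With the DEFAULT_PARAMS_GROUPS_KEYS above, the lookup behaves roughly as follows
# (illustrative calls, matching on substrings of the parameter name):
#     get_group_name("prompt_embeds")   # -> "prompt_embeddings"
#     get_group_name("text_encoder_2")  # -> "text_encoders"
#     get_group_name("prompt")          # -> None (not part of any group)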
class ModularNode(ConfigMixin):
"""
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.
> [!WARNING] > This is an experimental feature and is likely to change in the future.
"""
config_name = "node_config.json"
@classmethod
def from_pretrained(
cls,
pretrained_model_name_or_path: str,
trust_remote_code: Optional[bool] = None,
**kwargs,
):
blocks = ModularPipelineBlocks.from_pretrained(
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
)
return cls(blocks, **kwargs)
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
self.blocks = blocks
if label is None:
label = self.blocks.__class__.__name__
# blocks param name -> mellon param name
self.name_mapping = {}
input_params = {}
# pass or create a default param dict for each input
# e.g. for prompt,
# prompt = {
# "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
# "label": "Prompt",
# "type": "string",
# "default": "a bear sitting in a chair drinking a milkshake",
# "display": "textarea"}
# if type is not specified, it'll be a "custom" param of its own type
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
for inp in inputs:
param = kwargs.pop(inp.name, None)
if param:
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
input_params[inp.name] = param
mellon_name = param.pop("name", inp.name)
if mellon_name != inp.name:
self.name_mapping[inp.name] = mellon_name
continue
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
continue
if inp.name in DEFAULT_PARAM_MAPS:
# first check if it's in the default param map, if so, directly use that
param = DEFAULT_PARAM_MAPS[inp.name].copy()
elif get_group_name(inp.name):
param = get_group_name(inp.name)
if inp.name not in self.name_mapping:
self.name_mapping[inp.name] = param
else:
# if not, check if it's in the SDXL input schema, if so,
# 1. use the type hint to determine the type
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
if inp.type_hint is not None:
type_str = str(inp.type_hint).lower()
else:
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
if type_key in type_str:
param = type_param.copy()
param["label"] = inp.name
param["display"] = "input"
break
else:
param = inp.name
# add the param dict to the inp_params dict
input_params[inp.name] = param
component_params = {}
for comp in self.blocks.expected_components:
param = kwargs.pop(comp.name, None)
if param:
component_params[comp.name] = param
mellon_name = param.pop("name", comp.name)
if mellon_name != comp.name:
self.name_mapping[comp.name] = mellon_name
continue
to_exclude = False
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
if exclude_key in comp.name:
to_exclude = True
break
if to_exclude:
continue
if get_group_name(comp.name):
param = get_group_name(comp.name)
if comp.name not in self.name_mapping:
self.name_mapping[comp.name] = param
elif comp.name in DEFAULT_MODEL_KEYS:
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
else:
param = comp.name
# add the param dict to the model_params dict
component_params[comp.name] = param
output_params = {}
if isinstance(self.blocks, SequentialPipelineBlocks):
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
else:
outputs = self.blocks.intermediate_outputs
for out in outputs:
param = kwargs.pop(out.name, None)
if param:
output_params[out.name] = param
mellon_name = param.pop("name", out.name)
if mellon_name != out.name:
self.name_mapping[out.name] = mellon_name
continue
if out.name in DEFAULT_PARAM_MAPS:
param = DEFAULT_PARAM_MAPS[out.name].copy()
param["display"] = "output"
else:
group_name = get_group_name(out.name)
if group_name:
param = group_name
if out.name not in self.name_mapping:
self.name_mapping[out.name] = param
else:
param = out.name
# add the param dict to the outputs dict
output_params[out.name] = param
if len(kwargs) > 0:
logger.warning(f"Unused kwargs: {kwargs}")
register_dict = {
"category": category,
"label": label,
"input_params": input_params,
"component_params": component_params,
"output_params": output_params,
"name_mapping": self.name_mapping,
}
self.register_to_config(**register_dict)
def setup(self, components_manager, collection=None):
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
self._components_manager = components_manager
@property
def mellon_config(self):
return self._convert_to_mellon_config()
def _convert_to_mellon_config(self):
node = {}
node["label"] = self.config.label
node["category"] = self.config.category
node_param = {}
for inp_name, inp_param in self.config.input_params.items():
if inp_name in self.name_mapping:
mellon_name = self.name_mapping[inp_name]
else:
mellon_name = inp_name
if isinstance(inp_param, str):
param = {
"label": inp_param,
"type": inp_param,
"display": "input",
}
else:
param = inp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
for comp_name, comp_param in self.config.component_params.items():
if comp_name in self.name_mapping:
mellon_name = self.name_mapping[comp_name]
else:
mellon_name = comp_name
if isinstance(comp_param, str):
param = {
"label": comp_param,
"type": comp_param,
"display": "input",
}
else:
param = comp_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
for out_name, out_param in self.config.output_params.items():
if out_name in self.name_mapping:
mellon_name = self.name_mapping[out_name]
else:
mellon_name = out_name
if isinstance(out_param, str):
param = {
"label": out_param,
"type": out_param,
"display": "output",
}
else:
param = out_param
if mellon_name not in node_param:
node_param[mellon_name] = param
else:
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
node["params"] = node_param
return node
def save_mellon_config(self, file_path):
"""
Save the Mellon configuration to a JSON file.
Args:
file_path (str or Path): Path where the JSON file will be saved
Returns:
Path: Path to the saved config file
"""
file_path = Path(file_path)
# Create directory if it doesn't exist
os.makedirs(file_path.parent, exist_ok=True)
# Create a combined dictionary with module definition and name mapping
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
# Save the config to file
with open(file_path, "w", encoding="utf-8") as f:
json.dump(config, f, indent=2)
logger.info(f"Mellon config and name mapping saved to {file_path}")
return file_path
@classmethod
def load_mellon_config(cls, file_path):
"""
Load a Mellon configuration from a JSON file.
Args:
file_path (str or Path): Path to the JSON file containing Mellon config
Returns:
dict: The loaded combined configuration containing 'module' and 'name_mapping'
"""
file_path = Path(file_path)
if not file_path.exists():
raise FileNotFoundError(f"Config file not found: {file_path}")
with open(file_path, "r", encoding="utf-8") as f:
config = json.load(f)
logger.info(f"Mellon config loaded from {file_path}")
return config
def process_inputs(self, **kwargs):
params_components = {}
for comp_name, comp_param in self.config.component_params.items():
logger.debug(f"component: {comp_name}")
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
if mellon_comp_name in kwargs:
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
comp = kwargs[mellon_comp_name].pop(comp_name)
else:
comp = kwargs.pop(mellon_comp_name)
if comp:
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
params_run = {}
for inp_name, inp_param in self.config.input_params.items():
logger.debug(f"input: {inp_name}")
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
if mellon_inp_name in kwargs:
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
inp = kwargs[mellon_inp_name].pop(inp_name)
else:
inp = kwargs.pop(mellon_inp_name)
if inp is not None:
params_run[inp_name] = inp
return_output_names = list(self.config.output_params.keys())
return params_components, params_run, return_output_names
def execute(self, **kwargs):
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
self.pipeline.update_components(**params_components)
output = self.pipeline(**params_run, output=return_output_names)
return output

View File

@@ -21,27 +21,27 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["QwenImageTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
_import_structure["modular_blocks_qwenimage"] = [
"AUTO_BLOCKS",
"CONTROLNET_BLOCKS",
"EDIT_AUTO_BLOCKS",
"EDIT_BLOCKS",
"EDIT_INPAINT_BLOCKS",
"EDIT_PLUS_AUTO_BLOCKS",
"EDIT_PLUS_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"QwenImageAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_edit"] = [
"EDIT_AUTO_BLOCKS",
"QwenImageEditAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_edit_plus"] = [
"EDIT_PLUS_AUTO_BLOCKS",
"QwenImageEditPlusAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_layered"] = [
"LAYERED_AUTO_BLOCKS",
"QwenImageLayeredAutoBlocks",
]
_import_structure["modular_pipeline"] = [
"QwenImageEditModularPipeline",
"QwenImageEditPlusModularPipeline",
"QwenImageModularPipeline",
"QwenImageLayeredModularPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -51,28 +51,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import (
QwenImageTextEncoderStep,
)
from .modular_blocks import (
ALL_BLOCKS,
from .modular_blocks_qwenimage import (
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
EDIT_AUTO_BLOCKS,
EDIT_BLOCKS,
EDIT_INPAINT_BLOCKS,
EDIT_PLUS_AUTO_BLOCKS,
EDIT_PLUS_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
TEXT2IMAGE_BLOCKS,
QwenImageAutoBlocks,
)
from .modular_blocks_qwenimage_edit import (
EDIT_AUTO_BLOCKS,
QwenImageEditAutoBlocks,
)
from .modular_blocks_qwenimage_edit_plus import (
EDIT_PLUS_AUTO_BLOCKS,
QwenImageEditPlusAutoBlocks,
)
from .modular_blocks_qwenimage_layered import (
LAYERED_AUTO_BLOCKS,
QwenImageLayeredAutoBlocks,
)
from .modular_pipeline import (
QwenImageEditModularPipeline,
QwenImageEditPlusModularPipeline,
QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
)
else:

View File

@@ -23,7 +23,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils.torch_utils import randn_tensor, unwrap_module
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
@@ -113,7 +113,9 @@ def get_timesteps(scheduler, num_inference_steps, strength):
return timesteps, num_inference_steps - t_start
# Prepare Latents steps
# ====================
# 1. PREPARE LATENTS
# ====================
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
@@ -207,6 +209,98 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
return components, state
class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
model_name = "qwenimage-layered"
@property
def description(self) -> str:
return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process"
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="layers", default=4),
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="generator"),
InputParam(
name="batch_size",
required=True,
type_hint=int,
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
),
InputParam(
name="dtype",
required=True,
type_hint=torch.dtype,
description="The dtype of the model inputs, can be generated in input step.",
),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process",
),
]
@staticmethod
def check_inputs(height, width, vae_scale_factor):
if height is not None and height % (vae_scale_factor * 2) != 0:
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
if width is not None and width % (vae_scale_factor * 2) != 0:
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
self.check_inputs(
height=block_state.height,
width=block_state.width,
vae_scale_factor=components.vae_scale_factor,
)
device = components._execution_device
batch_size = block_state.batch_size * block_state.num_images_per_prompt
# we can update the height and width here since they're used to generate the initial latents
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
# VAE applies 8x compression on images but we must also account for packing which requires
# latent height and width to be divisible by 2.
latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width)
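# (worked example, assuming the usual 8x VAE compression: height = width = 1024 gives
#  latent_height = latent_width = 2 * (1024 // 16) = 128, i.e. shape is
#  (batch_size, layers + 1, num_channels_latents, 128, 128) before packing)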
if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
if block_state.latents is None:
block_state.latents = randn_tensor(
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
)
block_state.latents = components.pachifier.pack_latents(block_state.latents)
self.set_block_state(state, block_state)
return components, state
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -351,7 +445,9 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
return components, state
# Set Timesteps steps
# ====================
# 2. SET TIMESTEPS
# ====================
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
@@ -420,6 +516,64 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
return components, state
class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
model_name = "qwenimage-layered"
@property
def description(self) -> str:
return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("num_inference_steps", default=50, type_hint=int),
InputParam("sigmas", type_hint=List[float]),
InputParam("image_latents", required=True, type_hint=torch.Tensor),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="timesteps", type_hint=torch.Tensor),
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
# Layered-specific mu calculation
base_seqlen = 256 * 256 / 16 / 16 # = 256
mu = (block_state.image_latents.shape[1] / base_seqlen) ** 0.5
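# (worked example: base_seqlen = 256, so image latents with a packed sequence length of 4096
#  give mu = (4096 / 256) ** 0.5 = 4.0)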
# Default sigmas if not provided
sigmas = (
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
if block_state.sigmas is None
else block_state.sigmas
)
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
components.scheduler,
block_state.num_inference_steps,
device,
sigmas=sigmas,
mu=mu,
)
components.scheduler.set_begin_index(0)
self.set_block_state(state, block_state)
return components, state
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -493,7 +647,9 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
return components, state
# other inputs for denoiser
# ====================
# 3. OTHER INPUTS FOR DENOISER
# ====================
## RoPE inputs for denoiser
@@ -522,6 +678,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
return [
OutputParam(
name="img_shapes",
kwargs_type="denoiser_input_fields",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
@@ -589,6 +746,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
return [
OutputParam(
name="img_shapes",
kwargs_type="denoiser_input_fields",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the images latents, used for RoPE calculation",
),
@@ -639,19 +797,64 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
model_name = "qwenimage-edit-plus"
@property
def description(self) -> str:
return (
"Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n"
"Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n"
"Should be placed after prepare_latents step."
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="image_height", required=True, type_hint=List[int]),
InputParam(name="image_width", required=True, type_hint=List[int]),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="img_shapes",
kwargs_type="denoiser_input_fields",
type_hint=List[List[Tuple[int, int, int]]],
description="The shapes of the image latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
vae_scale_factor = components.vae_scale_factor
# Edit Plus: image_height and image_width are lists
block_state.img_shapes = [
[
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
*[
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
(1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2)
for img_height, img_width in zip(block_state.image_height, block_state.image_width)
],
]
] * block_state.batch_size
@@ -670,6 +873,87 @@ class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
return components, state
class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
model_name = "qwenimage-layered"
@property
def description(self) -> str:
return (
"Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
)
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="batch_size", required=True),
InputParam(name="layers", required=True),
InputParam(name="height", required=True),
InputParam(name="width", required=True),
InputParam(name="prompt_embeds_mask"),
InputParam(name="negative_prompt_embeds_mask"),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="img_shapes",
type_hint=List[List[Tuple[int, int, int]]],
kwargs_type="denoiser_input_fields",
description="The shapes of the image latents, used for RoPE calculation",
),
OutputParam(
name="txt_seq_lens",
type_hint=List[int],
kwargs_type="denoiser_input_fields",
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
),
OutputParam(
name="negative_txt_seq_lens",
type_hint=List[int],
kwargs_type="denoiser_input_fields",
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
),
OutputParam(
name="additional_t_cond",
type_hint=torch.Tensor,
kwargs_type="denoiser_input_fields",
description="The additional t cond, used for RoPE calculation",
),
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
device = components._execution_device
# All shapes are the same for Layered
shape = (
1,
block_state.height // components.vae_scale_factor // 2,
block_state.width // components.vae_scale_factor // 2,
)
# layers+1 output shapes + 1 condition shape (all same)
block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size
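# (worked example: with layers=4 and batch_size=2, img_shapes has 2 entries, each repeating the
#  same (1, H', W') tuple 6 times: 5 output frames plus 1 condition frame)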
# txt_seq_lens
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)
block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long)
self.set_block_state(state, block_state)
return components, state
## ControlNet inputs for denoiser
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
model_name = "qwenimage"

View File

@@ -24,12 +24,13 @@ from ...models import AutoencoderKLQwenImage
from ...utils import logging
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
logger = logging.get_logger(__name__)
# after denoising loop (unpack latents)
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -71,6 +72,46 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
return components, state
class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
model_name = "qwenimage-layered"
@property
def description(self) -> str:
return "Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("height", required=True, type_hint=int),
InputParam("width", required=True, type_hint=int),
InputParam("layers", required=True, type_hint=int),
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W)
block_state.latents = components.pachifier.unpack_latents(
block_state.latents,
block_state.height,
block_state.width,
block_state.layers,
components.vae_scale_factor,
)
self.set_block_state(state, block_state)
return components, state
# decode step
class QwenImageDecoderStep(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -135,6 +176,81 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
return components, state
class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
model_name = "qwenimage-layered"
@property
def description(self) -> str:
return "Decode unpacked latents (B, C, layers+1, H, W) into layer images."
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("vae", AutoencoderKLQwenImage),
ComponentSpec(
"image_processor",
VaeImageProcessor,
config=FrozenDict({"vae_scale_factor": 16}),
default_creation_method="from_config",
),
]
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents", required=True, type_hint=torch.Tensor),
InputParam("output_type", default="pil", type_hint=str),
]
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
]
@torch.no_grad()
def __call__(self, components, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
latents = block_state.latents
# 1. VAE normalization
latents = latents.to(components.vae.dtype)
latents_mean = (
torch.tensor(components.vae.config.latents_mean)
.view(1, components.vae.config.z_dim, 1, 1, 1)
.to(latents.device, latents.dtype)
)
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
1, components.vae.config.z_dim, 1, 1, 1
).to(latents.device, latents.dtype)
latents = latents / latents_std + latents_mean
# 2. Reshape for batch decoding: (B, C, layers+1, H, W) -> (B*layers, C, 1, H, W)
b, c, f, h, w = latents.shape
# 3. Remove first frame (composite), keep layers frames
latents = latents[:, :, 1:]
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
# 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W)
image = components.vae.decode(latents, return_dict=False)[0]
image = image.squeeze(2)
# 5. Postprocess - returns flat list of B*layers images
image = components.image_processor.postprocess(image, output_type=block_state.output_type)
# 6. Chunk into a list per batch item (each batch item has f - 1 = layers images,
# since the composite frame was dropped above)
num_frames = f - 1
images = []
for bidx in range(b):
images.append(image[bidx * num_frames : (bidx + 1) * num_frames])
block_state.images = images
self.set_block_state(state, block_state)
return components, state
# postprocess the decoded images
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
model_name = "qwenimage"

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Tuple
import torch
@@ -28,7 +29,12 @@ from .modular_pipeline import QwenImageModularPipeline
logger = logging.get_logger(__name__)
# ====================
# 1. LOOP STEPS (run at each denoising step)
# ====================
# loop step:before denoiser
class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -60,7 +66,7 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
model_name = "qwenimage-edit"
@property
def description(self) -> str:
@@ -185,6 +191,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
return components, block_state
# loop step:denoiser
class QwenImageLoopDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -253,6 +260,13 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
),
}
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
additional_cond_kwargs = {}
for field_name, field_value in block_state.denoiser_input_fields.items():
if field_name in transformer_args and field_name not in guider_inputs:
additional_cond_kwargs[field_name] = field_value
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
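# Note: with `denoiser_input_fields`, extra conditioning such as img_shapes reaches the
# transformer through `additional_cond_kwargs` (when its forward signature accepts the field),
# rather than being passed as an explicit keyword argument in the call below.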
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
@@ -264,7 +278,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=block_state.timestep / 1000,
img_shapes=block_state.img_shapes,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
@@ -284,7 +297,7 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
model_name = "qwenimage-edit"
@property
def description(self) -> str:
@@ -351,6 +364,13 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
),
}
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
additional_cond_kwargs = {}
for field_name, field_value in block_state.denoiser_input_fields.items():
if field_name in transformer_args and field_name not in guider_inputs:
additional_cond_kwargs[field_name] = field_value
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
guider_state = components.guider.prepare_inputs(guider_inputs)
@@ -362,7 +382,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
guider_state_batch.noise_pred = components.transformer(
hidden_states=block_state.latent_model_input,
timestep=block_state.timestep / 1000,
img_shapes=block_state.img_shapes,
attention_kwargs=block_state.attention_kwargs,
return_dict=False,
**cond_kwargs,
@@ -384,6 +403,7 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
return components, block_state
# loop step:after denoiser
class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "qwenimage"
@@ -481,6 +501,9 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
return components, block_state
# ====================
# 2. DENOISE LOOP WRAPPER: define the denoising loop logic
# ====================
class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
model_name = "qwenimage"
@@ -537,8 +560,15 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
return components, state
# composing the denoising loops
# ====================
# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps
# ====================
# Qwen Image (text2image, image2image)
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopDenoiser,
@@ -559,8 +589,9 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
)
# composing the inpainting denoising loops
# Qwen Image (inpainting)
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopDenoiser,
@@ -583,8 +614,9 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
)
# composing the controlnet denoising loops
# Qwen Image (text2image, image2image) with controlnet
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopBeforeDenoiserControlNet,
@@ -607,8 +639,9 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
)
# composing the controlnet denoising loops
# Qwen Image (inpainting) with controlnet
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
model_name = "qwenimage"
block_classes = [
QwenImageLoopBeforeDenoiser,
QwenImageLoopBeforeDenoiserControlNet,
@@ -639,8 +672,9 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
)
# composing the denoising loops
# Qwen Image Edit (image2image)
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
QwenImageEditLoopDenoiser,
@@ -661,7 +695,9 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
)
# Qwen Image Edit (inpainting)
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
QwenImageEditLoopDenoiser,
@@ -682,3 +718,26 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
" - `QwenImageLoopAfterDenoiserInpaint`\n"
"This block supports inpainting tasks for QwenImage Edit."
)
# Qwen Image Layered (image2image)
class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
model_name = "qwenimage-layered"
block_classes = [
QwenImageEditLoopBeforeDenoiser,
QwenImageEditLoopDenoiser,
QwenImageLoopAfterDenoiser,
]
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
" - `QwenImageEditLoopBeforeDenoiser`\n"
" - `QwenImageEditLoopDenoiser`\n"
" - `QwenImageLoopAfterDenoiser`\n"
"This block supports QwenImage Layered."
)

File diff suppressed because it is too large

View File

@@ -19,7 +19,7 @@ import torch
from ...models import QwenImageMultiControlNetModel
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
def repeat_tensor_to_batch_size(
@@ -221,37 +221,16 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
return components, state
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
"""Input step for QwenImage: update height/width, expand batch, patchify."""
model_name = "qwenimage"
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
This step handles multiple common tasks to prepare inputs for the denoising step:
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
This is a dynamic block that allows you to configure which inputs to process.
Args:
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
additional_batch_inputs (List[str], optional):
Names of additional conditional input tensors to expand batch size. These tensors will only have their
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
Defaults to []. Examples: ["processed_mask_image"]
Examples:
# Configure to process image_latents (default behavior)
QwenImageInputsDynamicStep()
# Configure to process multiple image latent inputs
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
# Configure to process image latents and additional batch inputs
QwenImageInputsDynamicStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
)
"""
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
@@ -263,14 +242,12 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
@property
def description(self) -> str:
# Functionality section
summary_section = (
"Input processing step that:\n"
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
" 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
# Inputs info
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
@@ -279,11 +256,16 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
# Placement guidance
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
inputs = [
@@ -293,11 +275,9 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
InputParam(name="width"),
]
# Add image latent inputs
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
# Add additional batch inputs
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
@@ -306,26 +286,28 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
]
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
OutputParam(
name="image_height",
type_hint=int,
description="The image height calculated from the image latents dimension",
),
OutputParam(
name="image_width",
type_hint=int,
description="The image width calculated from the image latents dimension",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate height/width from latents
# 1. Calculate height/width from latents and update if not provided
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
block_state.height = block_state.height or height
block_state.width = block_state.width or width
@@ -335,7 +317,7 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if not hasattr(block_state, "image_width"):
block_state.image_width = width
# 2. Patchify the image latent tensor
# 2. Patchify
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
# 3. Expand batch size
@@ -354,7 +336,6 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
@@ -368,63 +349,270 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
return components, state
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
"""Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
model_name = "qwenimage-edit-plus"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
summary_section = (
"Input processing step for Edit Plus that:\n"
" 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size\n"
" Height/width defaults to last image in the list."
)
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
InputParam(name="height"),
InputParam(name="width"),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
OutputParam(
name="image_height",
type_hint=List[int],
description="The image heights calculated from the image latents dimension",
),
OutputParam(
name="image_width",
type_hint=List[int],
description="The image widths calculated from the image latents dimension",
),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# Each image latent can have a different size in QwenImage Edit Plus.
is_list = isinstance(image_latent_tensor, list)
if not is_list:
image_latent_tensor = [image_latent_tensor]
image_heights = []
image_widths = []
packed_image_latent_tensors = []
for img_latent_tensor in image_latent_tensor:
for i, img_latent_tensor in enumerate(image_latent_tensor):
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
image_heights.append(height)
image_widths.append(width)
# 2. Patchify the image latent tensor
# 2. Patchify
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
# 3. Expand batch size
img_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_name=f"{image_latent_input_name}[{i}]",
input_tensor=img_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
packed_image_latent_tensors.append(img_latent_tensor)
# Concatenate all packed latents along dim=1
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
# Output lists of heights/widths
block_state.image_height = image_heights
block_state.image_width = image_widths
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
# Default height/width from last image
block_state.height = block_state.height or image_heights[-1]
block_state.width = block_state.width or image_widths[-1]
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, input_name, input_tensor)
self.set_block_state(state, block_state)
return components, state
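The shape bookkeeping in this step is easiest to see with concrete numbers. Below is a hedged sketch (a standalone `pack_latents` helper with patch size 2 and illustrative channel/spatial sizes, not the pipeline's pachifier component) showing why latents of different sizes can be concatenated along dim=1 once packed:

import torch


def pack_latents(latents: torch.Tensor, patch_size: int = 2) -> torch.Tensor:
    """Rearrange (B, C, H, W) latents into (B, H/p * W/p, C * p * p) tokens."""
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // patch_size) * (w // patch_size), c * patch_size * patch_size)


latents_a = torch.randn(1, 16, 64, 64)  # e.g. a 512x512 image with VAE scale factor 8
latents_b = torch.randn(1, 16, 96, 64)  # e.g. a 768x512 image

packed = [pack_latents(x) for x in (latents_a, latents_b)]
print([p.shape for p in packed])  # [torch.Size([1, 1024, 64]), torch.Size([1, 1536, 64])]

# Different spatial sizes are fine after packing: only the sequence length differs,
# so the per-image token sequences can be concatenated along dim=1.
image_latents = torch.cat(packed, dim=1)
print(image_latents.shape)  # torch.Size([1, 2560, 64])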
# YiYi TODO: support defining a default config component at the ModularPipeline level.
# Same as QwenImageAdditionalInputsStep, but with a layered pachifier.
class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
"""Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
model_name = "qwenimage-layered"
def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
if not isinstance(image_latent_inputs, list):
image_latent_inputs = [image_latent_inputs]
if not isinstance(additional_batch_inputs, list):
additional_batch_inputs = [additional_batch_inputs]
self._image_latent_inputs = image_latent_inputs
self._additional_batch_inputs = additional_batch_inputs
super().__init__()
@property
def description(self) -> str:
summary_section = (
"Input processing step for Layered that:\n"
" 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n"
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
)
inputs_info = ""
if self._image_latent_inputs or self._additional_batch_inputs:
inputs_info = "\n\nConfigured inputs:"
if self._image_latent_inputs:
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
if self._additional_batch_inputs:
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
return summary_section + inputs_info + placement_section
@property
def expected_components(self) -> List[ComponentSpec]:
return [
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
]
@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(name="num_images_per_prompt", default=1),
InputParam(name="batch_size", required=True),
]
for image_latent_input_name in self._image_latent_inputs:
inputs.append(InputParam(name=image_latent_input_name))
for input_name in self._additional_batch_inputs:
inputs.append(InputParam(name=input_name))
return inputs
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="image_height",
type_hint=int,
description="The image height calculated from the image latents dimension",
),
OutputParam(
name="image_width",
type_hint=int,
description="The image width calculated from the image latents dimension",
),
OutputParam(name="height", type_hint=int, description="The height of the image output"),
OutputParam(name="width", type_hint=int, description="The width of the image output"),
]
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)
# Process image latent inputs
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue
# 1. Calculate height/width from latents and update if not provided
# Layered latents are (B, layers, C, H, W)
height = image_latent_tensor.shape[3] * components.vae_scale_factor
width = image_latent_tensor.shape[4] * components.vae_scale_factor
block_state.height = height
block_state.width = width
if not hasattr(block_state, "image_height"):
block_state.image_height = height
if not hasattr(block_state, "image_width"):
block_state.image_width = width
# 2. Patchify with layered pachifier
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
# 3. Expand batch size
image_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=image_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
setattr(block_state, image_latent_input_name, image_latent_tensor)
# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue
# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,

File diff suppressed because it is too large

View File

@@ -0,0 +1,469 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageControlNetBeforeDenoiserStep,
QwenImageCreateMaskLatentsStep,
QwenImagePrepareLatentsStep,
QwenImagePrepareLatentsWithStrengthStep,
QwenImageRoPEInputsStep,
QwenImageSetTimestepsStep,
QwenImageSetTimestepsWithStrengthStep,
)
from .decoders import (
QwenImageAfterDenoiseStep,
QwenImageDecoderStep,
QwenImageInpaintProcessImagesOutputStep,
QwenImageProcessImagesOutputStep,
)
from .denoise import (
QwenImageControlNetDenoiseStep,
QwenImageDenoiseStep,
QwenImageInpaintControlNetDenoiseStep,
QwenImageInpaintDenoiseStep,
)
from .encoders import (
QwenImageControlNetVaeEncoderStep,
QwenImageInpaintProcessImagesInputStep,
QwenImageProcessImagesInputStep,
QwenImageTextEncoderStep,
QwenImageVaeEncoderStep,
)
from .inputs import (
QwenImageAdditionalInputsStep,
QwenImageControlNetInputsStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
# ====================
# 1. VAE ENCODER
# ====================
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
block_names = ["preprocess", "encode"]
@property
def description(self) -> str:
return (
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
" - Resizes the image to the target size, based on `height` and `width`.\n"
" - Processes and updates `image` and `mask_image`.\n"
" - Creates `image_latents`."
)
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
block_names = ["preprocess", "encode"]
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
# Auto VAE encoder
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
+ " - if `mask_image` or `image` is not provided, step will be skipped."
)
# optional controlnet vae encoder
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageControlNetVaeEncoderStep]
block_names = ["controlnet"]
block_trigger_inputs = ["control_image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block.\n"
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
+ " - if `control_image` is not provided, step will be skipped."
)
# ====================
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# ====================
# assemble input steps
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
# assemble prepare latents steps
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
block_names = ["add_noise_to_latents", "create_mask_latents"]
@property
def description(self) -> str:
return (
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
" - Add noise to the image latents to create the latents input for the denoiser.\n"
" - Create the pachified latents `mask` based on the processedmask image.\n"
)
# assemble denoising steps
# Qwen Image (text2image)
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageRoPEInputsStep(),
QwenImageDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
# Qwen Image (inpainting)
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageInpaintInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImageInpaintPrepareLatentsStep(),
QwenImageRoPEInputsStep(),
QwenImageInpaintDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_inpaint_latents",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
# Qwen Image (image2image)
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageImg2ImgInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImagePrepareLatentsWithStrengthStep(),
QwenImageRoPEInputsStep(),
QwenImageDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_img2img_latents",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
# Qwen Image (text2image) with controlnet
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageTextInputsStep(),
QwenImageControlNetInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageRoPEInputsStep(),
QwenImageControlNetBeforeDenoiserStep(),
QwenImageControlNetDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"controlnet_input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"controlnet_before_denoise",
"controlnet_denoise",
"after_denoise",
]
@property
def description(self):
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
# Qwen Image (inpainting) with controlnet
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageInpaintInputStep(),
QwenImageControlNetInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImageInpaintPrepareLatentsStep(),
QwenImageRoPEInputsStep(),
QwenImageControlNetBeforeDenoiserStep(),
QwenImageInpaintControlNetDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"controlnet_input",
"prepare_latents",
"set_timesteps",
"prepare_inpaint_latents",
"prepare_rope_inputs",
"controlnet_before_denoise",
"controlnet_denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
# Qwen Image (image2image) with controlnet
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [
QwenImageImg2ImgInputStep(),
QwenImageControlNetInputsStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImagePrepareLatentsWithStrengthStep(),
QwenImageRoPEInputsStep(),
QwenImageControlNetBeforeDenoiserStep(),
QwenImageControlNetDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"controlnet_input",
"prepare_latents",
"set_timesteps",
"prepare_img2img_latents",
"prepare_rope_inputs",
"controlnet_before_denoise",
"controlnet_denoise",
"after_denoise",
]
@property
def description(self):
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
# Auto denoise step for QwenImage
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
block_classes = [
QwenImageCoreDenoiseStep,
QwenImageInpaintCoreDenoiseStep,
QwenImageImg2ImgCoreDenoiseStep,
QwenImageControlNetCoreDenoiseStep,
QwenImageControlNetInpaintCoreDenoiseStep,
QwenImageControlNetImg2ImgCoreDenoiseStep,
]
block_names = [
"text2image",
"inpaint",
"img2img",
"controlnet_text2image",
"controlnet_inpaint",
"controlnet_img2img",
]
block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"]
default_block_name = "text2image"
def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None):
if control_image_latents is not None:
if processed_mask_image is not None:
return "controlnet_inpaint"
elif image_latents is not None:
return "controlnet_img2img"
else:
return "controlnet_text2image"
else:
if processed_mask_image is not None:
return "inpaint"
elif image_latents is not None:
return "img2img"
else:
return "text2image"
@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
+ " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n"
+ " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n"
+ " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n"
+ " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n"
+ " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n"
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
+ " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
)
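For reference, the trigger-input mapping implemented by `select_block` can be restated as a standalone function and sanity-checked in isolation; the helper name below is hypothetical:

def select_qwenimage_workflow(control_image_latents=None, processed_mask_image=None, image_latents=None) -> str:
    # Same decision table as QwenImageAutoCoreDenoiseStep.select_block above.
    if control_image_latents is not None:
        if processed_mask_image is not None:
            return "controlnet_inpaint"
        if image_latents is not None:
            return "controlnet_img2img"
        return "controlnet_text2image"
    if processed_mask_image is not None:
        return "inpaint"
    if image_latents is not None:
        return "img2img"
    return "text2image"


assert select_qwenimage_workflow() == "text2image"
assert select_qwenimage_workflow(image_latents="x") == "img2img"
assert select_qwenimage_workflow(processed_mask_image="m", image_latents="x") == "inpaint"
assert select_qwenimage_workflow(control_image_latents="c") == "controlnet_text2image"
assert select_qwenimage_workflow(control_image_latents="c", image_latents="x") == "controlnet_img2img"
assert select_qwenimage_workflow(control_image_latents="c", processed_mask_image="m") == "controlnet_inpaint"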
# ====================
# 3. DECODE
# ====================
# standard decode step works for most tasks except for inpaint
class QwenImageDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image."
# Inpaint decode step
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
# Auto decode step for QwenImage
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
block_names = ["inpaint_decode", "decode"]
block_trigger_inputs = ["mask", None]
@property
def description(self):
return (
"Decode step that decode the latents into images. \n"
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
)
# ====================
# 4. AUTO BLOCKS & PRESETS
# ====================
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageTextEncoderStep()),
("vae_encoder", QwenImageAutoVaeEncoderStep()),
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
("denoise", QwenImageAutoCoreDenoiseStep()),
("decode", QwenImageAutoDecodeStep()),
]
)
class QwenImageAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ "- for image-to-image generation, you need to provide `image`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
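A hedged usage sketch for these auto blocks, assuming the modular Diffusers API (`init_pipeline`, `load_components`) and a placeholder repo id; adjust names and dtype/device handling to your environment:

import torch

from diffusers.modular_pipelines import QwenImageAutoBlocks  # assumed import path for this version

blocks = QwenImageAutoBlocks()
pipe = blocks.init_pipeline("Qwen/Qwen-Image")  # placeholder modular repo id
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")  # device placement; adjust as needed

# text2image: only `prompt` is needed; passing `image`, `mask_image`, or
# `control_image` instead triggers the img2img / inpaint / controlnet branches.
image = pipe(prompt="A cat holding a sign that says hello", output="images")[0]
image.save("qwenimage_t2i.png")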

View File

@@ -0,0 +1,336 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageCreateMaskLatentsStep,
QwenImageEditRoPEInputsStep,
QwenImagePrepareLatentsStep,
QwenImagePrepareLatentsWithStrengthStep,
QwenImageSetTimestepsStep,
QwenImageSetTimestepsWithStrengthStep,
)
from .decoders import (
QwenImageAfterDenoiseStep,
QwenImageDecoderStep,
QwenImageInpaintProcessImagesOutputStep,
QwenImageProcessImagesOutputStep,
)
from .denoise import (
QwenImageEditDenoiseStep,
QwenImageEditInpaintDenoiseStep,
)
from .encoders import (
QwenImageEditInpaintProcessImagesInputStep,
QwenImageEditProcessImagesInputStep,
QwenImageEditResizeStep,
QwenImageEditTextEncoderStep,
QwenImageVaeEncoderStep,
)
from .inputs import (
QwenImageAdditionalInputsStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
# ====================
# 1. TEXT ENCODER
# ====================
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
"""VL encoder that takes both image and text prompts."""
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeStep(),
QwenImageEditTextEncoderStep(),
]
block_names = ["resize", "encode"]
@property
def description(self) -> str:
return "QwenImage-Edit VL encoder step that encode the image and text prompts together."
# ====================
# 2. VAE ENCODER
# ====================
# Edit VAE encoder
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeStep(),
QwenImageEditProcessImagesInputStep(),
QwenImageVaeEncoderStep(),
]
block_names = ["resize", "preprocess", "encode"]
@property
def description(self) -> str:
return "Vae encoder step that encode the image inputs into their latent representations."
# Edit Inpaint VAE encoder
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditResizeStep(),
QwenImageEditInpaintProcessImagesInputStep(),
QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
]
block_names = ["resize", "preprocess", "encode"]
@property
def description(self) -> str:
return (
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
" - process the resized image and mask image.\n"
" - create image latents."
)
# Auto VAE encoder
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep]
block_names = ["edit_inpaint", "edit"]
block_trigger_inputs = ["mask_image", "image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
"This is an auto pipeline block.\n"
" - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
" - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
" - if `mask_image` or `image` is not provided, step will be skipped."
)
# ====================
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# ====================
# assemble input steps
class QwenImageEditInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the edit denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageTextInputsStep(),
QwenImageAdditionalInputsStep(
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the edit inpaint denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
# assemble prepare latents steps
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
block_names = ["add_noise_to_latents", "create_mask_latents"]
@property
def description(self) -> str:
return (
"This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
" - Add noise to the image latents to create the latents input for the denoiser.\n"
" - Create the patchified latents `mask` based on the processed mask image.\n"
)
# Qwen Image Edit (image2image) core denoise step
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageEditRoPEInputsStep(),
QwenImageEditDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
# Qwen Image Edit (inpainting) core denoise step
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInpaintInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsWithStrengthStep(),
QwenImageEditInpaintPrepareLatentsStep(),
QwenImageEditRoPEInputsStep(),
QwenImageEditInpaintDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_inpaint_latents",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Core denoising workflow for QwenImage-Edit edit inpaint task."
# Auto core denoise step for QwenImage Edit
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [
QwenImageEditInpaintCoreDenoiseStep,
QwenImageEditCoreDenoiseStep,
]
block_names = ["edit_inpaint", "edit"]
block_trigger_inputs = ["processed_mask_image", "image_latents"]
default_block_name = "edit"
def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]:
if processed_mask_image is not None:
return "edit_inpaint"
elif image_latents is not None:
return "edit"
return None
@property
def description(self):
return (
"Auto core denoising step that selects the appropriate workflow based on inputs.\n"
" - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n"
" - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n"
"Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
)
# ====================
# 4. DECODE
# ====================
# Decode step (standard)
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image."
# Inpaint decode step
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image."
# Auto decode step
class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
block_names = ["inpaint_decode", "decode"]
block_trigger_inputs = ["mask", None]
@property
def description(self):
return (
"Decode step that decode the latents into images.\n"
"This is an auto pipeline block.\n"
" - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
" - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
)
# ====================
# 5. AUTO BLOCKS & PRESETS
# ====================
EDIT_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditVLEncoderStep()),
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
("denoise", QwenImageEditAutoCoreDenoiseStep()),
("decode", QwenImageEditAutoDecodeStep()),
]
)
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage-edit"
block_classes = EDIT_AUTO_BLOCKS.values()
block_names = EDIT_AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
"- for edit (img2img) generation, you need to provide `image`\n"
"- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
)
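And a similar hedged sketch for the edit blocks, where providing `image` (plus optionally `mask_image`) selects the edit or edit-inpaint branch; repo id and image URL are placeholders and the modular API calls are assumptions:

import torch

from diffusers.modular_pipelines import QwenImageEditAutoBlocks  # assumed import path
from diffusers.utils import load_image

edit_blocks = QwenImageEditAutoBlocks()
edit_pipe = edit_blocks.init_pipeline("Qwen/Qwen-Image-Edit")  # placeholder repo id
edit_pipe.load_components(torch_dtype=torch.bfloat16)
edit_pipe.to("cuda")

source = load_image("https://example.com/input.png")  # placeholder URL
# Passing `mask_image` as well would select the edit-inpaint branch instead.
edited = edit_pipe(prompt="Turn the jacket red", image=source, output="images")[0]
edited.save("qwenimage_edit.png")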

View File

@@ -0,0 +1,181 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageEditPlusRoPEInputsStep,
QwenImagePrepareLatentsStep,
QwenImageSetTimestepsStep,
)
from .decoders import (
QwenImageAfterDenoiseStep,
QwenImageDecoderStep,
QwenImageProcessImagesOutputStep,
)
from .denoise import (
QwenImageEditDenoiseStep,
)
from .encoders import (
QwenImageEditPlusProcessImagesInputStep,
QwenImageEditPlusResizeStep,
QwenImageEditPlusTextEncoderStep,
QwenImageVaeEncoderStep,
)
from .inputs import (
QwenImageEditPlusAdditionalInputsStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
# ====================
# 1. TEXT ENCODER
# ====================
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
"""VL encoder that takes both image and text prompts. Uses 384x384 target area."""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
QwenImageEditPlusTextEncoderStep(),
]
block_names = ["resize", "encode"]
@property
def description(self) -> str:
return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
# ====================
# 2. VAE ENCODER
# ====================
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
"""VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
QwenImageEditPlusProcessImagesInputStep(),
QwenImageVaeEncoderStep(),
]
block_names = ["resize", "preprocess", "encode"]
@property
def description(self) -> str:
return (
"VAE encoder step that encodes image inputs into latent representations.\n"
"Each image is resized independently based on its own aspect ratio to 1024x1024 target area."
)
# ====================
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# ====================
# assemble input steps
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageTextInputsStep(),
QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the Edit Plus denoising step. It:\n"
" - Standardizes text embeddings batch size.\n"
" - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n"
" - Outputs lists of image_height/image_width for RoPE calculation.\n"
" - Defaults height/width from last image in the list."
)
# Qwen Image Edit Plus (image2image) core denoise step
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [
QwenImageEditPlusInputStep(),
QwenImagePrepareLatentsStep(),
QwenImageSetTimestepsStep(),
QwenImageEditPlusRoPEInputsStep(),
QwenImageEditDenoiseStep(),
QwenImageAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."
# ====================
# 4. DECODE
# ====================
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
block_names = ["decode", "postprocess"]
@property
def description(self):
return "Decode step that decodes the latents to images and postprocesses the generated image."
# ====================
# 5. AUTO BLOCKS & PRESETS
# ====================
EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageEditPlusVLEncoderStep()),
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
("denoise", QwenImageEditPlusCoreDenoiseStep()),
("decode", QwenImageEditPlusDecodeStep()),
]
)
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage-edit-plus"
block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
"- `image` is required input (can be single image or list of images).\n"
"- Each image is resized independently based on its own aspect ratio.\n"
"- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
)

View File

@@ -0,0 +1,159 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
QwenImageLayeredPrepareLatentsStep,
QwenImageLayeredRoPEInputsStep,
QwenImageLayeredSetTimestepsStep,
)
from .decoders import (
QwenImageLayeredAfterDenoiseStep,
QwenImageLayeredDecoderStep,
)
from .denoise import (
QwenImageLayeredDenoiseStep,
)
from .encoders import (
QwenImageEditProcessImagesInputStep,
QwenImageLayeredGetImagePromptStep,
QwenImageLayeredPermuteLatentsStep,
QwenImageLayeredResizeStep,
QwenImageTextEncoderStep,
QwenImageVaeEncoderStep,
)
from .inputs import (
QwenImageLayeredAdditionalInputsStep,
QwenImageTextInputsStep,
)
logger = logging.get_logger(__name__)
# ====================
# 1. TEXT ENCODER
# ====================
class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
"""Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
model_name = "qwenimage-layered"
block_classes = [
QwenImageLayeredResizeStep(),
QwenImageLayeredGetImagePromptStep(),
QwenImageTextEncoderStep(),
]
block_names = ["resize", "get_image_prompt", "encode"]
@property
def description(self) -> str:
return "QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided."
# ====================
# 2. VAE ENCODER
# ====================
# Edit VAE encoder
class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
model_name = "qwenimage-layered"
block_classes = [
QwenImageLayeredResizeStep(),
QwenImageEditProcessImagesInputStep(),
QwenImageVaeEncoderStep(),
QwenImageLayeredPermuteLatentsStep(),
]
block_names = ["resize", "preprocess", "encode", "permute"]
@property
def description(self) -> str:
return "Vae encoder step that encode the image inputs into their latent representations."
# ====================
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
# ====================
# assemble input steps
class QwenImageLayeredInputStep(SequentialPipelineBlocks):
model_name = "qwenimage-layered"
block_classes = [
QwenImageTextInputsStep(),
QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the layered denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
# Qwen Image Layered (image2image) core denoise step
class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "qwenimage-layered"
block_classes = [
QwenImageLayeredInputStep(),
QwenImageLayeredPrepareLatentsStep(),
QwenImageLayeredSetTimestepsStep(),
QwenImageLayeredRoPEInputsStep(),
QwenImageLayeredDenoiseStep(),
QwenImageLayeredAfterDenoiseStep(),
]
block_names = [
"input",
"prepare_latents",
"set_timesteps",
"prepare_rope_inputs",
"denoise",
"after_denoise",
]
@property
def description(self):
return "Core denoising workflow for QwenImage-Layered img2img task."
# ====================
# 4. AUTO BLOCKS & PRESETS
# ====================
LAYERED_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", QwenImageLayeredTextEncoderStep()),
("vae_encoder", QwenImageLayeredVaeEncoderStep()),
("denoise", QwenImageLayeredCoreDenoiseStep()),
("decode", QwenImageLayeredDecoderStep()),
]
)
class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
model_name = "qwenimage-layered"
block_classes = LAYERED_AUTO_BLOCKS.values()
block_names = LAYERED_AUTO_BLOCKS.keys()
@property
def description(self):
return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered."

View File

@@ -90,6 +90,88 @@ class QwenImagePachifier(ConfigMixin):
return latents
class QwenImageLayeredPachifier(ConfigMixin):
"""
A class to pack and unpack latents for QwenImage Layered.
Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W).
"""
config_name = "config.json"
@register_to_config
def __init__(self, patch_size: int = 2):
super().__init__()
def pack_latents(self, latents):
"""
Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4).
"""
if latents.ndim != 5:
raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}")
batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape
patch_size = self.config.patch_size
if latent_height % patch_size != 0 or latent_width % patch_size != 0:
raise ValueError(
f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
)
latents = latents.view(
batch_size,
layers,
num_channels_latents,
latent_height // patch_size,
patch_size,
latent_width // patch_size,
patch_size,
)
latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
latents = latents.reshape(
batch_size,
layers * (latent_height // patch_size) * (latent_width // patch_size),
num_channels_latents * patch_size * patch_size,
)
return latents
def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8):
"""
Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W).
"""
if latents.ndim != 3:
raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
batch_size, _, channels = latents.shape
patch_size = self.config.patch_size
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
latents = latents.view(
batch_size,
layers + 1,
height // patch_size,
width // patch_size,
channels // (patch_size * patch_size),
patch_size,
patch_size,
)
latents = latents.permute(0, 1, 4, 2, 5, 3, 6)
latents = latents.reshape(
batch_size,
layers + 1,
channels // (patch_size * patch_size),
height,
width,
)
latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w)
return latents
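A quick shape check of the layered pachifier defined above, with illustrative dimensions (2 samples, 3 layers plus the composite slot, 16 latent channels, 64x64 latents, VAE scale factor 8):

import torch

pachifier = QwenImageLayeredPachifier(patch_size=2)

# 2 samples, 4 latent slots (3 layers + 1 composite), 16 channels, 64x64 latents
latents = torch.randn(2, 4, 16, 64, 64)

packed = pachifier.pack_latents(latents)
print(packed.shape)  # torch.Size([2, 4096, 64]) == (B, 4 * 32 * 32, 16 * 2 * 2)

# unpack_latents takes pixel-space height/width, the layer count *excluding* the
# extra composite slot (it adds +1 internally), and the VAE scale factor.
unpacked = pachifier.unpack_latents(packed, height=512, width=512, layers=3, vae_scale_factor=8)
print(unpacked.shape)  # torch.Size([2, 16, 4, 64, 64]) -> (B, C, layers + 1, H, W)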
class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
"""
A ModularPipeline for QwenImage.
@@ -203,3 +285,13 @@ class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline):
"""
default_blocks_name = "QwenImageEditPlusAutoBlocks"
class QwenImageLayeredModularPipeline(QwenImageModularPipeline):
"""
A ModularPipeline for QwenImage-Layered.
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""
default_blocks_name = "QwenImageLayeredAutoBlocks"

View File

@@ -0,0 +1,121 @@
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Prompt templates for QwenImage pipelines.
This module centralizes all prompt templates used across different QwenImage pipeline variants:
- QwenImage (base): Text-only encoding for text-to-image generation
- QwenImage Edit: VL encoding with single image for image editing
- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing
- QwenImage Layered: Auto-captioning for image decomposition
"""
# ============================================
# QwenImage Base (text-only encoding)
# ============================================
# Used for text-to-image generation where only text prompt is encoded
QWENIMAGE_PROMPT_TEMPLATE = (
"<|im_start|>system\n"
"Describe the image by detailing the color, shape, size, texture, quantity, text, "
"spatial relationships of the objects and background:<|im_end|>\n"
"<|im_start|>user\n{}<|im_end|>\n"
"<|im_start|>assistant\n"
)
QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34
# ============================================
# QwenImage Edit (VL encoding with single image)
# ============================================
# Used for single-image editing where both image and text are encoded together
QWENIMAGE_EDIT_PROMPT_TEMPLATE = (
"<|im_start|>system\n"
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
"then explain how the user's text instruction should alter or modify the image. "
"Generate a new image that meets the user's requirements while maintaining consistency "
"with the original input where appropriate.<|im_end|>\n"
"<|im_start|>user\n"
"<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
"<|im_start|>assistant\n"
)
QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64
# ============================================
# QwenImage Edit Plus (VL encoding with multiple images)
# ============================================
# Used for multi-reference editing where multiple images and text are encoded together
# The img_template is used to format each image in the prompt
QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = (
"<|im_start|>system\n"
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
"then explain how the user's text instruction should alter or modify the image. "
"Generate a new image that meets the user's requirements while maintaining consistency "
"with the original input where appropriate.<|im_end|>\n"
"<|im_start|>user\n{}<|im_end|>\n"
"<|im_start|>assistant\n"
)
QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64
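The Edit Plus templates compose as follows: one `Picture {i}:` block per reference image is prepended to the user instruction, and the result is substituted into the main template. A hedged example (placeholder image list; the drop-index comment is an assumption based on the base QwenImage encoders, not shown in this diff):

images = ["img1.png", "img2.png"]  # placeholders; only the count matters here
instruction = "Replace the background of Picture 1 with the sky from Picture 2."

img_prompts = "".join(QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE.format(i + 1) for i in range(len(images)))
full_prompt = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE.format(img_prompts + instruction)

# The user turn now reads:
#   Picture 1: <|vision_start|><|image_pad|><|vision_end|>Picture 2: <|vision_start|><|image_pad|><|vision_end|>Replace the background ...
# QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX (64) is presumably the number of
# leading template tokens to drop from the encoder hidden states (assumption).
print(full_prompt)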
# ============================================
# QwenImage Layered (auto-captioning)
# ============================================
# Used for image decomposition where the VL model generates a caption from the input image
# if no prompt is provided. These prompts instruct the model to describe the image in detail.
QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
"# Image Annotator\n"
"You are a professional image annotator. Please write an image caption based on the input image:\n"
"1. Write the caption using natural, descriptive language without structured formats or rich text.\n"
"2. Enrich caption details by including:\n"
" - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n"
" - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, "
"attachment relations, action relations, comparative relations, causal relations, and so on\n"
" - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n"
" - Identify the text clearly visible in the image, without translation or explanation, "
"and highlight it in the caption with quotation marks\n"
"3. Maintain authenticity and accuracy:\n"
" - Avoid generalizations\n"
" - Describe all visible information in the image, while do not add information not explicitly shown in the image\n"
"<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
"<|im_start|>assistant\n"
)
QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = (
"<|im_start|>system\n"
"You are a helpful assistant.<|im_end|>\n"
"<|im_start|>user\n"
"# 图像标注器\n"
"你是一个专业的图像标注器。请基于输入图像,撰写图注:\n"
"1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n"
"2. 通过加入以下内容,丰富图注细节:\n"
" - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n"
" - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n"
" - 环境细节:例如天气、光照、颜色、纹理、气氛等\n"
" - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n"
"3. 保持真实性与准确性:\n"
" - 不要使用笼统的描述\n"
" - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n"
"<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
"<|im_start|>assistant\n"
)
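Not part of the diff: a brief, hypothetical sketch of how these constants are typically consumed. The Edit Plus image template is presumably formatted once per reference image and prepended to the user instruction before the chat template wraps it, and the `*_START_IDX` constants presumably mark how many leading system-prompt tokens get dropped from the encoder hidden states; the exact wiring lives in the pipeline code not shown here.

```py
# Hypothetical usage sketch (assumption; the real wiring is in the QwenImage pipelines,
# which are not part of this file).
user_instruction = "Replace the background of Picture 1 with the sky from Picture 2."
num_images = 2

# One numbered vision placeholder per reference image, prepended to the instruction.
image_tokens = "".join(QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE.format(i + 1) for i in range(num_images))
full_prompt = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE.format(image_tokens + user_instruction)

# After encoding `full_prompt`, the first QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX
# tokens (the system prompt) would be dropped from the hidden states before conditioning.
```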

View File

@@ -288,7 +288,9 @@ else:
"LTXImageToVideoPipeline",
"LTXConditionPipeline",
"LTXLatentUpsamplePipeline",
"LTXI2VLongMultiPromptPipeline",
]
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
_import_structure["lucy"] = ["LucyEditPipeline"]
@@ -729,7 +731,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
)
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
from .ltx import (
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline

View File

@@ -99,6 +99,7 @@ from .qwenimage import (
QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImageLayeredPipeline,
QwenImagePipeline,
)
from .sana import SanaPipeline
@@ -202,6 +203,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
("qwenimage", QwenImageImg2ImgPipeline),
("qwenimage-edit", QwenImageEditPipeline),
("qwenimage-edit-plus", QwenImageEditPlusPipeline),
("qwenimage-layered", QwenImageLayeredPipeline),
("z-image", ZImageImg2ImgPipeline),
]
)

View File

@@ -25,6 +25,7 @@ else:
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
_import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"]
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
from .pipeline_ltx import LTXPipeline
from .pipeline_ltx_condition import LTXConditionPipeline
from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline

File diff suppressed because it is too large.

View File

@@ -0,0 +1,58 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .connectors import LTX2TextConnectors
from .latent_upsampler import LTX2LatentUpsamplerModel
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,325 @@
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.attention import FeedForward
from ...models.modeling_utils import ModelMixin
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
class LTX2RotaryPosEmbed1d(nn.Module):
"""
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
"""
def __init__(
self,
dim: int,
base_seq_len: int = 4096,
theta: float = 10000.0,
double_precision: bool = True,
rope_type: str = "interleaved",
num_attention_heads: int = 32,
):
super().__init__()
if rope_type not in ["interleaved", "split"]:
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
self.dim = dim
self.base_seq_len = base_seq_len
self.theta = theta
self.double_precision = double_precision
self.rope_type = rope_type
self.num_attention_heads = num_attention_heads
def forward(
self,
batch_size: int,
pos: int,
device: Union[str, torch.device],
) -> Tuple[torch.Tensor, torch.Tensor]:
# 1. Get 1D position ids
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
# Get fractional indices relative to self.base_seq_len
grid_1d = grid_1d / self.base_seq_len
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
# 2. Calculate 1D RoPE frequencies
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
freqs_dtype = torch.float64 if self.double_precision else torch.float32
pow_indices = torch.pow(
self.theta,
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
)
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
# (self.dim // 2,).
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
if self.rope_type == "interleaved":
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
if self.dim % num_rope_elems != 0:
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
elif self.rope_type == "split":
expected_freqs = self.dim // 2
current_freqs = freqs.shape[-1]
pad_size = expected_freqs - current_freqs
cos_freq = freqs.cos()
sin_freq = freqs.sin()
if pad_size != 0:
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
# Reshape freqs to be compatible with multi-head attention
b = cos_freq.shape[0]
t = cos_freq.shape[1]
cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
return cos_freqs, sin_freqs
class LTX2TransformerBlock1d(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
activation_fn: str = "gelu-approximate",
eps: float = 1e-6,
rope_type: str = "interleaved",
):
super().__init__()
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.attn1 = LTX2Attention(
query_dim=dim,
heads=num_attention_heads,
kv_heads=num_attention_heads,
dim_head=attention_head_dim,
processor=LTX2AudioVideoAttnProcessor(),
rope_type=rope_type,
)
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
self.ff = FeedForward(dim, activation_fn=activation_fn)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> torch.Tensor:
norm_hidden_states = self.norm1(hidden_states)
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
hidden_states = hidden_states + attn_hidden_states
norm_hidden_states = self.norm2(hidden_states)
ff_hidden_states = self.ff(norm_hidden_states)
hidden_states = hidden_states + ff_hidden_states
return hidden_states
class LTX2ConnectorTransformer1d(nn.Module):
"""
A 1D sequence transformer for modalities such as text.
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
"""
_supports_gradient_checkpointing = True
def __init__(
self,
num_attention_heads: int = 30,
attention_head_dim: int = 128,
num_layers: int = 2,
num_learnable_registers: int | None = 128,
rope_base_seq_len: int = 4096,
rope_theta: float = 10000.0,
rope_double_precision: bool = True,
eps: float = 1e-6,
causal_temporal_positioning: bool = False,
rope_type: str = "interleaved",
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.inner_dim = num_attention_heads * attention_head_dim
self.causal_temporal_positioning = causal_temporal_positioning
self.num_learnable_registers = num_learnable_registers
self.learnable_registers = None
if num_learnable_registers is not None:
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
self.learnable_registers = torch.nn.Parameter(init_registers)
self.rope = LTX2RotaryPosEmbed1d(
self.inner_dim,
base_seq_len=rope_base_seq_len,
theta=rope_theta,
double_precision=rope_double_precision,
rope_type=rope_type,
num_attention_heads=num_attention_heads,
)
self.transformer_blocks = torch.nn.ModuleList(
[
LTX2TransformerBlock1d(
dim=self.inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
rope_type=rope_type,
)
for _ in range(num_layers)
]
)
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attn_mask_binarize_threshold: float = -9000.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
# hidden_states shape: [batch_size, seq_len, hidden_dim]
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
batch_size, seq_len, _ = hidden_states.shape
# 1. Replace padding with learned registers, if using
if self.learnable_registers is not None:
if seq_len % self.num_learnable_registers != 0:
raise ValueError(
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
f" of learnable registers {self.num_learnable_registers}"
)
num_register_repeats = seq_len // self.num_learnable_registers
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
if binary_attn_mask.ndim == 4:
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
padded_hidden_states = [
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
]
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
# Overwrite attention_mask with an all-zeros mask if using registers.
attention_mask = torch.zeros_like(attention_mask)
# 2. Calculate 1D RoPE positional embeddings
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
# 3. Run 1D transformer blocks
for block in self.transformer_blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
else:
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
hidden_states = self.norm_out(hidden_states)
return hidden_states, attention_mask
class LTX2TextConnectors(ModelMixin, ConfigMixin):
"""
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
streams.
"""
@register_to_config
def __init__(
self,
caption_channels: int,
text_proj_in_factor: int,
video_connector_num_attention_heads: int,
video_connector_attention_head_dim: int,
video_connector_num_layers: int,
video_connector_num_learnable_registers: int | None,
audio_connector_num_attention_heads: int,
audio_connector_attention_head_dim: int,
audio_connector_num_layers: int,
audio_connector_num_learnable_registers: int | None,
connector_rope_base_seq_len: int,
rope_theta: float,
rope_double_precision: bool,
causal_temporal_positioning: bool,
rope_type: str = "interleaved",
):
super().__init__()
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
self.video_connector = LTX2ConnectorTransformer1d(
num_attention_heads=video_connector_num_attention_heads,
attention_head_dim=video_connector_attention_head_dim,
num_layers=video_connector_num_layers,
num_learnable_registers=video_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
)
self.audio_connector = LTX2ConnectorTransformer1d(
num_attention_heads=audio_connector_num_attention_heads,
attention_head_dim=audio_connector_attention_head_dim,
num_layers=audio_connector_num_layers,
num_learnable_registers=audio_connector_num_learnable_registers,
rope_base_seq_len=connector_rope_base_seq_len,
rope_theta=rope_theta,
rope_double_precision=rope_double_precision,
causal_temporal_positioning=causal_temporal_positioning,
rope_type=rope_type,
)
def forward(
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
):
# Convert to additive attention mask, if necessary
if not additive_mask:
text_dtype = text_encoder_hidden_states.dtype
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
video_text_embedding = video_text_embedding * attn_mask
new_attn_mask = attn_mask.squeeze(-1)
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
return video_text_embedding, audio_text_embedding, new_attn_mask
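For orientation, here is a small standalone sketch (an assumption that mirrors the conversion at the top of `LTX2TextConnectors.forward`) of how a binary padding mask becomes the additive attention mask the connectors consume:

```py
import torch

# 1 marks a valid token, 0 marks padding (shape [batch, seq_len]).
mask = torch.tensor([[1, 1, 1, 0, 0]])
text_dtype = torch.bfloat16

# Valid positions map to 0.0, padded positions to -finfo(dtype).max,
# reshaped so it broadcasts as [B, 1, 1, L] over attention scores.
additive_mask = (mask - 1).reshape(mask.shape[0], 1, -1, mask.shape[-1])
additive_mask = additive_mask.to(text_dtype) * torch.finfo(text_dtype).max
```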

View File

@@ -0,0 +1,134 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fractions import Fraction
from typing import Optional
import torch
from ...utils import is_av_available
_CAN_USE_AV = is_av_available()
if _CAN_USE_AV:
import av
else:
raise ImportError(
"PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
)
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
"""
Prepare the audio stream for writing.
"""
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
audio_stream.codec_context.sample_rate = audio_sample_rate
audio_stream.codec_context.layout = "stereo"
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
return audio_stream
def _resample_audio(
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
) -> None:
cc = audio_stream.codec_context
# Use the encoder's format/layout/rate as the *target*
target_format = cc.format or "fltp" # AAC → usually fltp
target_layout = cc.layout or "stereo"
target_rate = cc.sample_rate or frame_in.sample_rate
audio_resampler = av.audio.resampler.AudioResampler(
format=target_format,
layout=target_layout,
rate=target_rate,
)
audio_next_pts = 0
for rframe in audio_resampler.resample(frame_in):
if rframe.pts is None:
rframe.pts = audio_next_pts
audio_next_pts += rframe.samples
rframe.sample_rate = frame_in.sample_rate
container.mux(audio_stream.encode(rframe))
# flush audio encoder
for packet in audio_stream.encode():
container.mux(packet)
def _write_audio(
container: av.container.Container,
audio_stream: av.audio.AudioStream,
samples: torch.Tensor,
audio_sample_rate: int,
) -> None:
if samples.ndim == 1:
samples = samples[:, None]
if samples.shape[1] != 2 and samples.shape[0] == 2:
samples = samples.T
if samples.shape[1] != 2:
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
if samples.dtype != torch.int16:
samples = torch.clip(samples, -1.0, 1.0)
samples = (samples * 32767.0).to(torch.int16)
frame_in = av.AudioFrame.from_ndarray(
samples.contiguous().reshape(1, -1).cpu().numpy(),
format="s16",
layout="stereo",
)
frame_in.sample_rate = audio_sample_rate
_resample_audio(container, audio_stream, frame_in)
def encode_video(
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
) -> None:
video_np = video.cpu().numpy()
_, height, width, _ = video_np.shape
container = av.open(output_path, mode="w")
stream = container.add_stream("libx264", rate=int(fps))
stream.width = width
stream.height = height
stream.pix_fmt = "yuv420p"
if audio is not None:
if audio_sample_rate is None:
raise ValueError("audio_sample_rate is required when audio is provided")
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
for frame_array in video_np:
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
for packet in stream.encode(frame):
container.mux(packet)
# Flush encoder
for packet in stream.encode():
container.mux(packet)
if audio is not None:
_write_audio(container, audio_stream, audio, audio_sample_rate)
container.close()
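As a quick orientation, a minimal usage sketch for `encode_video` (an assumption, not taken from the diff; it requires PyAV to be installed): write a one-second, 24 fps dummy clip without an audio track.

```py
import torch
from diffusers.pipelines.ltx2.export_utils import encode_video  # module path added in this diff

frames = (torch.rand(24, 64, 64, 3) * 255).to(torch.uint8)  # (num_frames, H, W, C), RGB uint8
encode_video(frames, fps=24, audio=None, audio_sample_rate=None, output_path="dummy.mp4")
```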

View File

@@ -0,0 +1,285 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
import torch
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
RATIONAL_RESAMPLER_SCALE_MAPPING = {
0.75: (3, 4),
1.5: (3, 2),
2.0: (2, 1),
4.0: (4, 1),
}
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock
class ResBlock(torch.nn.Module):
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
super().__init__()
if mid_channels is None:
mid_channels = channels
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
self.norm2 = torch.nn.GroupNorm(32, channels)
self.activation = torch.nn.SiLU()
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
residual = hidden_states
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm1(hidden_states)
hidden_states = self.activation(hidden_states)
hidden_states = self.conv2(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.activation(hidden_states + residual)
return hidden_states
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND
class PixelShuffleND(torch.nn.Module):
def __init__(self, dims, upscale_factors=(2, 2, 2)):
super().__init__()
self.dims = dims
self.upscale_factors = upscale_factors
if dims not in [1, 2, 3]:
raise ValueError("dims must be 1, 2, or 3")
def forward(self, x):
if self.dims == 3:
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
return (
x.unflatten(1, (-1, *self.upscale_factors[:3]))
.permute(0, 1, 5, 2, 6, 3, 7, 4)
.flatten(6, 7)
.flatten(4, 5)
.flatten(2, 3)
)
elif self.dims == 2:
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
return (
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
)
elif self.dims == 1:
# temporal: b (c p1) f h w -> b c (f p1) h w
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
class BlurDownsample(torch.nn.Module):
"""
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W.
Works for dims=2 or dims=3 (per-frame).
"""
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
super().__init__()
if dims not in (2, 3):
raise ValueError(f"`dims` must be either 2 or 3 but is {dims}")
if kernel_size < 3 or kernel_size % 2 != 1:
raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}")
self.dims = dims
self.stride = stride
self.kernel_size = kernel_size
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
# The 2D kernel is constructed as the outer product and normalized.
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
k2d = k[:, None] @ k[None, :]
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.stride == 1:
return x
if self.dims == 2:
c = x.shape[1]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
else:
# dims == 3: apply per-frame on H,W
b, c, f, _, _ = x.shape
x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
h2, w2 = x.shape[-2:]
x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W]
return x
class SpatialRationalResampler(torch.nn.Module):
"""
Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample
by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the
input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the
(integer) denominator.
"""
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
super().__init__()
self.scale = float(scale)
num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None)
if num_denom is None:
raise ValueError(
f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}"
)
self.num, self.den = num_denom
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
self.blur_down = BlurDownsample(dims=2, stride=self.den)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Expected x shape: [B * F, C, H, W]
# b, _, f, h, w = x.shape
# x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
x = self.conv(x)
x = self.pixel_shuffle(x)
x = self.blur_down(x)
# x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W]
return x
class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
"""
Model to spatially upsample VAE latents.
Args:
in_channels (`int`, defaults to `128`):
Number of channels in the input latent
mid_channels (`int`, defaults to `1024`):
Number of channels in the middle layers
num_blocks_per_stage (`int`, defaults to `4`):
Number of ResBlocks to use in each stage (pre/post upsampling)
dims (`int`, defaults to `3`):
Number of dimensions for convolutions (2 or 3)
spatial_upsample (`bool`, defaults to `True`):
Whether to spatially upsample the latent
temporal_upsample (`bool`, defaults to `False`):
Whether to temporally upsample the latent
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
mid_channels: int = 1024,
num_blocks_per_stage: int = 4,
dims: int = 3,
spatial_upsample: bool = True,
temporal_upsample: bool = False,
rational_spatial_scale: Optional[float] = 2.0,
):
super().__init__()
self.in_channels = in_channels
self.mid_channels = mid_channels
self.num_blocks_per_stage = num_blocks_per_stage
self.dims = dims
self.spatial_upsample = spatial_upsample
self.temporal_upsample = temporal_upsample
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
self.initial_activation = torch.nn.SiLU()
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
if spatial_upsample and temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(3),
)
elif spatial_upsample:
if rational_spatial_scale is not None:
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
else:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(2),
)
elif temporal_upsample:
self.upsampler = torch.nn.Sequential(
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
PixelShuffleND(1),
)
else:
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
self.post_upsample_res_blocks = torch.nn.ModuleList(
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
)
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
if self.dims == 2:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.upsampler(hidden_states)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
else:
hidden_states = self.initial_conv(hidden_states)
hidden_states = self.initial_norm(hidden_states)
hidden_states = self.initial_activation(hidden_states)
for block in self.res_blocks:
hidden_states = block(hidden_states)
if self.temporal_upsample:
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states[:, :, 1:, :, :]
else:
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.upsampler(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
for block in self.post_upsample_res_blocks:
hidden_states = block(hidden_states)
hidden_states = self.final_conv(hidden_states)
return hidden_states
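A quick equivalence check (an assumption about the module path; `PixelShuffleND` is copied from the LTX 1.x latent upsampler): for `dims=2` the rearrangement matches `torch.nn.PixelShuffle` with the same upscale factor.

```py
import torch
from diffusers.pipelines.ltx2.latent_upsampler import PixelShuffleND  # assumed path per this diff

x = torch.randn(1, 3 * 2 * 2, 8, 8)  # (B, C * r * r, H, W) with C=3, r=2
out = PixelShuffleND(2, upscale_factors=(2, 2))(x)
ref = torch.nn.PixelShuffle(2)(x)
assert out.shape == (1, 3, 16, 16) and torch.equal(out, ref)
```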

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View File

@@ -0,0 +1,442 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union
import torch
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLLTX2Video
from ...utils import get_logger, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..ltx.pipeline_output import LTXPipelineOutput
from ..pipeline_utils import DiffusionPipeline
from .latent_upsampler import LTX2LatentUpsamplerModel
logger = get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
>>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
>>> from diffusers.utils import load_image
>>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
>>> pipe.enable_model_cpu_offload()
>>> image = load_image(
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
... )
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
>>> frame_rate = 24.0
>>> video, audio = pipe(
... image=image,
... prompt=prompt,
... negative_prompt=negative_prompt,
... width=768,
... height=512,
... num_frames=121,
... frame_rate=frame_rate,
... num_inference_steps=40,
... guidance_scale=4.0,
... output_type="pil",
... return_dict=False,
... )
>>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16
... )
>>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
>>> upsample_pipe.vae.enable_tiling()
>>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16)
>>> video = upsample_pipe(
... video=video,
... width=768,
... height=512,
... output_type="np",
... return_dict=False,
... )[0]
>>> video = (video * 255).round().astype("uint8")
>>> video = torch.from_numpy(video)
>>> encode_video(
... video[0],
... fps=frame_rate,
... audio=audio[0].float().cpu(),
... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
... output_path="video.mp4",
... )
```
"""
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
class LTX2LatentUpsamplePipeline(DiffusionPipeline):
model_cpu_offload_seq = "vae->latent_upsampler"
def __init__(
self,
vae: AutoencoderKLLTX2Video,
latent_upsampler: LTX2LatentUpsamplerModel,
) -> None:
super().__init__()
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
self.vae_spatial_compression_ratio = (
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
)
self.vae_temporal_compression_ratio = (
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
def prepare_latents(
self,
video: Optional[torch.Tensor] = None,
batch_size: int = 1,
num_frames: int = 121,
height: int = 512,
width: int = 768,
spatial_patch_size: int = 1,
temporal_patch_size: int = 1,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if latents is not None:
if latents.ndim == 3:
# Convert token seq [B, S, D] to latent video [B, C, F, H, W]
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
latent_height = height // self.vae_spatial_compression_ratio
latent_width = width // self.vae_spatial_compression_ratio
latents = self._unpack_latents(
latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
)
return latents.to(device=device, dtype=dtype)
video = video.to(device=device, dtype=self.vae.dtype)
if isinstance(generator, list):
if len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
init_latents = [
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
]
else:
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
init_latents = torch.cat(init_latents, dim=0).to(dtype)
# NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
# init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
return init_latents
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
"""
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
tensor.
Args:
latents (`torch.Tensor`):
Input latents to normalize
reference_latents (`torch.Tensor`):
The reference latents providing style statistics.
factor (`float`):
Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
Returns:
torch.Tensor: The transformed latent tensor
"""
result = latents.clone()
for i in range(latents.size(0)):
for c in range(latents.size(1)):
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
result = torch.lerp(latents, result, factor)
return result
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
"""
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
smooth way using a sigmoid-based compression.
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
when controlling dynamic behavior with a `compression` factor.
Args:
latents : torch.Tensor
Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
compression : float
Compression strength in the range [0, 1].
- 0.0: No tone-mapping (identity transform)
- 1.0: Full compression effect
Returns:
torch.Tensor
The tone-mapped latent tensor of the same shape as input.
"""
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
scale_factor = compression * 0.75
abs_latents = torch.abs(latents)
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
filtered = latents * scales
return filtered
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
def _normalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Normalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = (latents - latents_mean) * scaling_factor / latents_std
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
def _denormalize_latents(
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
) -> torch.Tensor:
# Denormalize latents across the channel dimension [B, C, F, H, W]
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
latents = latents * latents_std / scaling_factor + latents_mean
return latents
@staticmethod
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
def _unpack_latents(
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
) -> torch.Tensor:
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
# what happens in the `_pack_latents` method.
batch_size = latents.size(0)
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
return latents
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
if video is not None and latents is not None:
raise ValueError("Only one of `video` or `latents` can be provided.")
if video is None and latents is None:
raise ValueError("One of `video` or `latents` has to be provided.")
if not (0 <= tone_map_compression_ratio <= 1):
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
video: Optional[List[PipelineImageInput]] = None,
height: int = 512,
width: int = 768,
num_frames: int = 121,
spatial_patch_size: int = 1,
temporal_patch_size: int = 1,
latents: Optional[torch.Tensor] = None,
latents_normalized: bool = False,
decode_timestep: Union[float, List[float]] = 0.0,
decode_noise_scale: Optional[Union[float, List[float]]] = None,
adain_factor: float = 0.0,
tone_map_compression_ratio: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
video (`List[PipelineImageInput]`, *optional*):
The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be
supplied.
height (`int`, *optional*, defaults to `512`):
The height in pixels of the input video (not the generated video, which will have a larger resolution).
width (`int`, *optional*, defaults to `768`):
The width in pixels of the input video (not the generated video, which will have a larger resolution).
num_frames (`int`, *optional*, defaults to `121`):
The number of frames in the input video.
spatial_patch_size (`int`, *optional*, defaults to `1`):
The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
temporal_patch_size (`int`, *optional*, defaults to `1`):
The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
necessary.
latents (`torch.Tensor`, *optional*):
Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
latent_channels, latent_frames, latent_height, latent_width)`.
latents_normalized (`bool`, *optional*, defaults to `False`):
If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
`True`, the `latents` will be denormalized before being supplied to the latent upsampler.
decode_timestep (`float`, defaults to `0.0`):
The timestep at which generated video is decoded.
decode_noise_scale (`float`, defaults to `None`):
The interpolation factor between random noise and denoised latents at the decode timestep.
adain_factor (`float`, *optional*, defaults to `0.0`):
Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
This is useful for regularizing high-variance latents or for conditioning outputs during generation.
Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
the full compression effect.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated video. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
Examples:
Returns:
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
returned where the first element is the upsampled video.
"""
self.check_inputs(
video=video,
height=height,
width=width,
latents=latents,
tone_map_compression_ratio=tone_map_compression_ratio,
)
if video is not None:
# Batched video input is not yet tested/supported. TODO: take a look later
batch_size = 1
else:
batch_size = latents.shape[0]
device = self._execution_device
if video is not None:
num_frames = len(video)
if num_frames % self.vae_temporal_compression_ratio != 1:
truncated_num_frames = (
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
)
logger.warning(
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {num_frames}. Truncating to {truncated_num_frames} frames."
)
num_frames = truncated_num_frames
video = video[:num_frames]
video = self.video_processor.preprocess_video(video, height=height, width=width)
video = video.to(device=device, dtype=torch.float32)
latents_supplied = latents is not None
latents = self.prepare_latents(
video=video,
batch_size=batch_size,
num_frames=num_frames,
height=height,
width=width,
spatial_patch_size=spatial_patch_size,
temporal_patch_size=temporal_patch_size,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
if latents_supplied and latents_normalized:
latents = self._denormalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
latents = latents.to(self.latent_upsampler.dtype)
latents_upsampled = self.latent_upsampler(latents)
if adain_factor > 0.0:
latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
else:
latents = latents_upsampled
if tone_map_compression_ratio > 0.0:
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
if output_type == "latent":
latents = self._normalize_latents(
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
)
video = latents
else:
if not self.vae.config.timestep_conditioning:
timestep = None
else:
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
if not isinstance(decode_timestep, list):
decode_timestep = [decode_timestep] * batch_size
if decode_noise_scale is None:
decode_noise_scale = decode_timestep
elif not isinstance(decode_noise_scale, list):
decode_noise_scale = [decode_noise_scale] * batch_size
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
:, None, None, None, None
]
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
video = self.vae.decode(latents, timestep, return_dict=False)[0]
video = self.video_processor.postprocess_video(video, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (video,)
return LTXPipelineOutput(frames=video)
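To make the tone-mapping behavior concrete, here is a standalone reproduction of the `tone_map_latents` formula (a sketch, not a call into the pipeline): compression 0.0 is the identity, while compression 1.0 squashes large-magnitude latents to roughly 0.4x and leaves small values nearly untouched.

```py
import torch

def tone_map(latents: torch.Tensor, compression: float) -> torch.Tensor:
    # Same formula as LTX2LatentUpsamplePipeline.tone_map_latents.
    scale = compression * 0.75
    scales = 1.0 - 0.8 * scale * torch.sigmoid(4.0 * scale * (latents.abs() - 1.0))
    return latents * scales

x = torch.tensor([0.1, 1.0, 3.0])
assert torch.allclose(tone_map(x, 0.0), x)  # identity at compression=0
print(tone_map(x, 1.0))                     # ~[0.10, 0.70, 1.20]: large values compressed
```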

View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class LTX2PipelineOutput(BaseOutput):
r"""
Output class for LTX 2.0 pipelines.
Args:
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
`(batch_size, num_frames, channels, height, width)`.
audio (`torch.Tensor`, `np.ndarray`):
TODO
"""
frames: torch.Tensor
audio: torch.Tensor

View File

@@ -0,0 +1,159 @@
import math
from typing import List, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...models.modeling_utils import ModelMixin
class ResBlock(nn.Module):
def __init__(
self,
channels: int,
kernel_size: int = 3,
stride: int = 1,
dilations: Tuple[int, ...] = (1, 3, 5),
leaky_relu_negative_slope: float = 0.1,
padding_mode: str = "same",
):
super().__init__()
self.dilations = dilations
self.negative_slope = leaky_relu_negative_slope
self.convs1 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
for dilation in dilations
]
)
self.convs2 = nn.ModuleList(
[
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
for _ in range(len(dilations))
]
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
for conv1, conv2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
xt = conv1(xt)
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
xt = conv2(xt)
x = x + xt
return x
class LTX2Vocoder(ModelMixin, ConfigMixin):
r"""
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
"""
@register_to_config
def __init__(
self,
in_channels: int = 128,
hidden_channels: int = 1024,
out_channels: int = 2,
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
upsample_factors: List[int] = [6, 5, 2, 2, 2],
resnet_kernel_sizes: List[int] = [3, 7, 11],
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
leaky_relu_negative_slope: float = 0.1,
output_sampling_rate: int = 24000,
):
super().__init__()
self.num_upsample_layers = len(upsample_kernel_sizes)
self.resnets_per_upsample = len(resnet_kernel_sizes)
self.out_channels = out_channels
self.total_upsample_factor = math.prod(upsample_factors)
self.negative_slope = leaky_relu_negative_slope
if self.num_upsample_layers != len(upsample_factors):
raise ValueError(
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
)
if self.resnets_per_upsample != len(resnet_dilations):
raise ValueError(
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
)
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
self.upsamplers = nn.ModuleList()
self.resnets = nn.ModuleList()
input_channels = hidden_channels
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
output_channels = input_channels // 2
self.upsamplers.append(
nn.ConvTranspose1d(
input_channels, # hidden_channels // (2 ** i)
output_channels, # hidden_channels // (2 ** (i + 1))
kernel_size,
stride=stride,
padding=(kernel_size - stride) // 2,
)
)
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
self.resnets.append(
ResBlock(
output_channels,
kernel_size,
dilations=dilations,
leaky_relu_negative_slope=leaky_relu_negative_slope,
)
)
input_channels = output_channels
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
r"""
Forward pass of the vocoder.
Args:
hidden_states (`torch.Tensor`):
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
`True`.
time_last (`bool`, *optional*, defaults to `False`):
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
Returns:
`torch.Tensor`:
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
"""
# Ensure that the time/frame dimension is last
if not time_last:
hidden_states = hidden_states.transpose(2, 3)
# Combine channels and frequency (mel bins) dimensions
hidden_states = hidden_states.flatten(1, 2)
hidden_states = self.conv_in(hidden_states)
for i in range(self.num_upsample_layers):
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
hidden_states = self.upsamplers[i](hidden_states)
# Run all resnets in parallel on hidden_states
start = i * self.resnets_per_upsample
end = (i + 1) * self.resnets_per_upsample
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
hidden_states = torch.mean(resnet_outputs, dim=0)
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
hidden_states = self.conv_out(hidden_states)
hidden_states = torch.tanh(hidden_states)
return hidden_states
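A small shape check (an assumption; the defaults come from the config above): with the default configuration the vocoder turns a 128-bin mel spectrogram into a stereo waveform, upsampling the time axis by prod([6, 5, 2, 2, 2]) = 240.

```py
import torch
from diffusers.pipelines.ltx2.vocoder import LTX2Vocoder  # module path added in this diff

vocoder = LTX2Vocoder()                 # default config: in_channels=128, out_channels=2
mel = torch.randn(1, 1, 50, 128)        # (batch, channels, time, mel_bins), time_last=False
waveform = vocoder(mel)
assert waveform.shape == (1, 2, 50 * 240)  # (batch, out_channels, audio_samples)
```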

View File

@@ -801,12 +801,6 @@ def load_sub_model(
# add kwargs to loading method
diffusers_module = importlib.import_module(__name__.split(".")[0])
loading_kwargs = {}
if issubclass(class_obj, torch.nn.Module):
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
loading_kwargs["provider_options"] = provider_options
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
@@ -821,6 +815,17 @@ def load_sub_model(
and transformers_version >= version.parse("4.20.0")
)
# For transformers models >= 4.56.0, use 'dtype' instead of 'torch_dtype' to avoid deprecation warnings
if issubclass(class_obj, torch.nn.Module):
if is_transformers_model and transformers_version >= version.parse("4.56.0"):
loading_kwargs["dtype"] = torch_dtype
else:
loading_kwargs["torch_dtype"] = torch_dtype
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
loading_kwargs["provider"] = provider
loading_kwargs["sess_options"] = sess_options
loading_kwargs["provider_options"] = provider_options
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
# This makes sure that the weights won't be initialized which significantly speeds up loading.

View File

@@ -67,6 +67,7 @@ from ..utils import (
logging,
numpy_to_pil,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
@@ -982,7 +983,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
# 7. Load each module in the pipeline
current_device_map = None
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
if not is_torch_dist_rank_zero():
logging_tqdm_kwargs["disable"] = True
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
# 7.1 device_map shenanigans
if final_device_map is not None:
if isinstance(final_device_map, dict) and len(final_device_map) > 0:
@@ -1908,10 +1913,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
)
progress_bar_config = dict(self._progress_bar_config)
if "disable" not in progress_bar_config:
progress_bar_config["disable"] = not is_torch_dist_rank_zero()
if iterable is not None:
return tqdm(iterable, **self._progress_bar_config)
return tqdm(iterable, **progress_bar_config)
elif total is not None:
return tqdm(total=total, **self._progress_bar_config)
return tqdm(total=total, **progress_bar_config)
else:
raise ValueError("Either `total` or `iterable` has to be defined.")

View File

@@ -545,22 +545,24 @@ class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
latent_model_input = latents.to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1
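The same change is repeated in the other SkyReelsV2 pipelines below. As a reading aid, a sketch of the resulting denoising-step pattern as a hypothetical helper (indentation inferred from the diff; the transformer, tensors, and flags are parameters here, not module state): the conditional and unconditional passes each run under their own named cache context so cached intermediates are kept separate.

def cfg_denoise_step(transformer, latent_model_input, timestep, prompt_embeds,
                     negative_prompt_embeds, guidance_scale,
                     do_classifier_free_guidance, attention_kwargs=None):
    # Conditional pass under its own named cache context.
    with transformer.cache_context("cond"):
        noise_pred = transformer(
            hidden_states=latent_model_input,
            timestep=timestep,
            encoder_hidden_states=prompt_embeds,
            attention_kwargs=attention_kwargs,
            return_dict=False,
        )[0]
    if do_classifier_free_guidance:
        # Unconditional pass under a separate context so its cached states are not mixed in.
        with transformer.cache_context("uncond"):
            noise_uncond = transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                encoder_hidden_states=negative_prompt_embeds,
                attention_kwargs=attention_kwargs,
                return_dict=False,
            )[0]
        # Standard classifier-free guidance combination.
        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
    return noise_pred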

View File

@@ -887,25 +887,28 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoader
)
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
update_mask_i = step_update_mask[i]

View File

@@ -966,25 +966,28 @@ class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, SkyReels
)
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
update_mask_i = step_update_mask[i]

View File

@@ -974,25 +974,28 @@ class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, SkyReels
)
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
enable_diffusion_forcing=True,
fps=fps_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
update_mask_i = step_update_mask[i]

View File

@@ -678,24 +678,26 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixi
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
timestep = t.expand(latents.shape[0])
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_image=image_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
noise_uncond = self.transformer(
with self.transformer.cache_context("cond"):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_image=image_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
if self.do_classifier_free_guidance:
with self.transformer.cache_context("uncond"):
noise_uncond = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=negative_prompt_embeds,
encoder_hidden_states_image=image_embeds,
attention_kwargs=attention_kwargs,
return_dict=False,
)[0]
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
# compute the previous noisy sample x_t -> x_t-1

View File

@@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
`float8_e4m3_tensor`, `float8_e4m3_row`,
- **Floating point X-bit quantization:**
- **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
- Full function names: `fpx_weight_only`
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
@@ -531,12 +531,18 @@ class TorchAoConfig(QuantizationConfigMixin):
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
is_floatx_quant_type = self.quant_type.startswith("fp")
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
)
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
raise ValueError(
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
)
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
@@ -622,7 +628,6 @@ class TorchAoConfig(QuantizationConfigMixin):
float8_dynamic_activation_float8_weight,
float8_static_activation_float8_weight,
float8_weight_only,
fpx_weight_only,
int4_weight_only,
int8_dynamic_activation_int4_weight,
int8_dynamic_activation_int8_weight,
@@ -630,6 +635,8 @@ class TorchAoConfig(QuantizationConfigMixin):
uintx_weight_only,
)
if is_torchao_version("<=", "0.14.1"):
from torchao.quantization import fpx_weight_only
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
from torchao.quantization.observer import PerRow, PerTensor
@@ -650,18 +657,21 @@ class TorchAoConfig(QuantizationConfigMixin):
return types
def generate_fpx_quantization_types(bits: int):
types = {}
if is_torchao_version("<=", "0.14.1"):
types = {}
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
for ebits in range(1, bits):
mbits = bits - ebits - 1
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
non_sign_bits = bits - 1
default_ebits = (non_sign_bits + 1) // 2
default_mbits = non_sign_bits - default_ebits
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
return types
return types
else:
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
INT4_QUANTIZATION_TYPES = {
# int4 weight + bfloat16/float16 activation
@@ -710,15 +720,15 @@ class TorchAoConfig(QuantizationConfigMixin):
**generate_float8dq_types(torch.float8_e4m3fn),
# float8 weight + float8 activation (static)
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
# For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
# fpx weight + bfloat16/float16 activation
**generate_fpx_quantization_types(3),
**generate_fpx_quantization_types(4),
**generate_fpx_quantization_types(5),
**generate_fpx_quantization_types(6),
**generate_fpx_quantization_types(7),
}
if is_torchao_version("<=", "0.14.1"):
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
UINTX_QUANTIZATION_DTYPES = {
"uintx_weight_only": uintx_weight_only,
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),

View File

@@ -66,6 +66,7 @@ else:
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
_import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
_import_structure["scheduling_lcm"] = ["LCMScheduler"]
_import_structure["scheduling_ltx_euler_ancestral_rf"] = ["LTXEulerAncestralRFScheduler"]
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
@@ -168,6 +169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
from .scheduling_lcm import LCMScheduler
from .scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler
from .scheduling_pndm import PNDMScheduler
from .scheduling_repaint import RePaintScheduler
from .scheduling_sasolver import SASolverScheduler
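For context, a minimal sketch of the lazy-import pattern this hunk extends (the real __init__.py also passes module_spec and registers many more entries): names listed in _import_structure are imported only on first attribute access, while the TYPE_CHECKING branch keeps static analyzers and IDEs working.

import sys
from typing import TYPE_CHECKING

from ..utils import _LazyModule

_import_structure = {"scheduling_ltx_euler_ancestral_rf": ["LTXEulerAncestralRFScheduler"]}

if TYPE_CHECKING:
    from .scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler
else:
    # Replace the module with a lazy proxy that resolves entries on demand.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure)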

View File

@@ -71,6 +71,22 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
"""
A scheduler for the consistency decoder used in Stable Diffusion pipelines.
This scheduler implements a two-step denoising process using consistency models for decoding latent representations
into images.
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
methods the library implements for all schedulers such as loading and saving.
Args:
num_train_timesteps (`int`, *optional*, defaults to `1024`):
The number of diffusion steps to train the model.
sigma_data (`float`, *optional*, defaults to `0.5`):
The standard deviation of the data distribution. Used for computing the skip and output scaling factors.
"""
order = 1
@register_to_config
@@ -78,7 +94,7 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
self,
num_train_timesteps: int = 1024,
sigma_data: float = 0.5,
):
) -> None:
betas = betas_for_alpha_bar(num_train_timesteps)
alphas = 1.0 - betas
@@ -98,8 +114,18 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
def set_timesteps(
self,
num_inference_steps: Optional[int] = None,
device: Union[str, torch.device] = None,
):
device: Optional[Union[str, torch.device]] = None,
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
Args:
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model. Currently, only
`2` inference steps are supported.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
"""
if num_inference_steps != 2:
raise ValueError("Currently more than 2 inference steps are not supported.")
@@ -111,7 +137,15 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
self.c_in = self.c_in.to(device)
@property
def init_noise_sigma(self):
def init_noise_sigma(self) -> torch.Tensor:
"""
Return the standard deviation of the initial noise distribution.
Returns:
`torch.Tensor`:
The initial noise sigma value from the precomputed `sqrt_one_minus_alphas_cumprod` at the first
timestep.
"""
return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
@@ -146,20 +180,20 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`float`):
timestep (`float` or `torch.Tensor`):
The current timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
A random number generator for reproducibility.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.
Returns:
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
If return_dict is `True`,
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
If `return_dict` is `True`,
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
a tuple is returned where the first element is the sample tensor.
"""
x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
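A hedged usage sketch of the scheduler documented above; the zero tensor stands in for a real consistency decoder's model output, and the module path follows the cross-references in the docstring.

import torch
from diffusers.schedulers.scheduling_consistency_decoder import ConsistencyDecoderScheduler

scheduler = ConsistencyDecoderScheduler()  # num_train_timesteps=1024, sigma_data=0.5
scheduler.set_timesteps(num_inference_steps=2, device="cpu")  # only 2 steps are supported

# Start from noise scaled by the initial sigma, then run the two-step consistency loop.
sample = torch.randn(1, 3, 64, 64) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = torch.zeros_like(model_input)  # stand-in for the decoder UNet output
    sample = scheduler.step(model_output, t, sample).prev_sample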

Some files were not shown because too many files have changed in this diff.