Compare commits

...

38 Commits

Author SHA1 Message Date
Sayak Paul
ed2ef83067 Apply suggestions from code review
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-04-03 08:20:48 +02:00
sayakpaul
1a3ffddc63 make regional compilation note clearer. 2026-04-03 08:10:16 +02:00
Sayak Paul
6d8e371061 Merge branch 'main' into profiling-workflow 2026-04-03 11:31:40 +05:30
Sayak Paul
5adc544b79 [tests] refactor wan autoencoder tests (#13371)
* refactor wan autoencoder tests

* up

* address dhruv's feedback.
2026-04-03 07:36:40 +02:00
jiqing-feng
a05c8e9452 Fix Dynamo lru_cache warnings during torch.compile (#13384)
* fix compile issue

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* compile friendly

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* add comments

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-04-03 07:12:54 +02:00
Zamuldinov Nikita
8070f6ec54 fix(ddim): validate eta is in [0, 1] in DDIMPipeline (#13367)
* fix(ddim): validate eta is in [0, 1] in DDIMPipeline.__call__

The DDIM paper defines η (eta) as a value that must lie in [0, 1]:
η=0 corresponds to deterministic DDIM, η=1 corresponds to DDPM.
The docstring already documented this constraint, but no runtime
validation was in place, so users could silently pass out-of-range
values (e.g. negative or >1) without any error.

Add an explicit ValueError check before the denoising loop so that
invalid eta values are caught early with a clear message.

Fixes #13362

Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com>

* fix(ddim): downgrade eta out-of-range from error to warning

Per maintainer feedback from @yiyixuxu — the documentation is
sufficient; a hard ValueError is too strict. Replace with a
UserWarning so callers are informed without breaking existing code
that passes eta outside [0, 1].

Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com>

* fix(ddim): use logger.warning instead of warnings.warn for eta validation

Address review request from @yiyixuxu: switch from warnings.warn() to
logger.warning() to be consistent with all other diffusers pipelines.

The eta validation check itself (0.0 <= eta <= 1.0) is unchanged.

Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com>

---------

Signed-off-by: NIK-TIGER-BILL <nik.tiger.bill@github.com>
Co-authored-by: NIK-TIGER-BILL <nik.tiger.bill@github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-04-02 16:07:28 -10:00
Samuel Meddin
3e53a383e1 Fix typos and grammar errors in documentation (#13391)
- Fix 'allows to generate' -> 'allows you to generate' in controlling_generation.md
- Fix 'it's refiner' -> 'its refiner' (possessive) in sdxl.md
- Fix 'it's state' -> 'its state' (possessive) in reusing_seeds.md
- Fix missing word 'you'll a function' -> 'you'll create a function' in sdxl.md
2026-04-02 13:42:32 -07:00
YiYi Xu
cf6af6b4f8 [docs] add auto docstring and parameter templates documentation for m… (#13382)
* [docs] add auto docstring and parameter templates documentation for modular diffusers

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/modular_diffusers/auto_docstring.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/_toctree.yml

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* up

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-04-02 10:34:45 -10:00
Pauline Bailly-Masson
3211cd9df0 🔒 Pin GitHub Actions to commit SHAs (#13385)
* 🔒 pin benchmark.yml actions to commit SHAs

* 🔒 pin nightly_tests.yml actions to commit SHAs

* 🔒 pin build_pr_documentation.yml actions to commit SHAs

* 🔒 pin typos.yml actions to commit SHAs

* 🔒 pin build_docker_images.yml actions to commit SHAs

* 🔒 pin build_documentation.yml actions to commit SHAs

* 🔒 pin upload_pr_documentation.yml actions to commit SHAs

* 🔒 pin pr_style_bot.yml actions to commit SHAs

* 🔒 pin codeql.yml actions to commit SHAs

* 🔒 pin ssh-pr-runner.yml actions to commit SHAs

* 🔒 pin trufflehog.yml actions to commit SHAs
2026-04-02 21:04:45 +05:30
Sayak Paul
131831ff20 Apply suggestions from code review
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-31 20:56:20 +05:30
Sayak Paul
3fc1a04526 Merge branch 'main' into profiling-workflow 2026-03-31 19:17:50 +05:30
sayakpaul
40c330a90d note on regional compilation 2026-03-31 19:17:35 +05:30
sayakpaul
fb6afa6da6 make important 2026-03-31 19:08:22 +05:30
sayakpaul
6cf142902a unavoidable gaps. 2026-03-31 19:07:58 +05:30
sayakpaul
3bdd529141 Merge branch 'main' into profiling-workflow 2026-03-31 09:54:38 +05:30
sayakpaul
40a525e784 table 2026-03-31 09:54:15 +05:30
sayakpaul
bfb19afd1e approach -> How the tooling works 2026-03-31 09:51:49 +05:30
sayakpaul
3ae7d9b4d7 add torch.compile link. 2026-03-31 09:50:36 +05:30
Sayak Paul
ed8241a394 Apply suggestions from code review
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-31 09:46:45 +05:30
sayakpaul
c642cd0e4f up 2026-03-30 09:02:49 +05:30
sayakpaul
1131acd6e1 cuda graphs. 2026-03-29 10:03:29 +05:30
sayakpaul
e26d5c6ee3 better title 2026-03-28 15:18:30 +05:30
sayakpaul
43e16fba40 up 2026-03-28 10:08:35 +05:30
sayakpaul
12ba8be720 add more traces. 2026-03-28 09:54:32 +05:30
sayakpaul
9ba98a2642 up 2026-03-27 17:27:11 +05:30
sayakpaul
142f417b66 more 2026-03-27 17:18:03 +05:30
sayakpaul
35437a897e wan fixes. 2026-03-27 17:01:37 +05:30
sayakpaul
a410b4958c up 2026-03-27 16:29:45 +05:30
sayakpaul
bfbaf079cd up 2026-03-27 13:39:49 +05:30
sayakpaul
bf5131fba9 propagate deletion. 2026-03-27 12:51:56 +05:30
sayakpaul
6a23a771aa improve readme. 2026-03-27 12:51:09 +05:30
sayakpaul
96506c85d0 cache hooks 2026-03-27 12:24:11 +05:30
sayakpaul
179fa51342 up 2026-03-27 11:44:21 +05:30
sayakpaul
60d4148529 add points. 2026-03-27 11:10:17 +05:30
sayakpaul
b2b6330a54 more clarification 2026-03-27 10:33:09 +05:30
sayakpaul
e4d6293b4d fix 2026-03-27 09:17:50 +05:30
sayakpaul
eddef12a54 fix 2026-03-27 09:13:39 +05:30
sayakpaul
af96109435 add a profiling worflow. 2026-03-26 17:01:41 +05:30
30 changed files with 1147 additions and 109 deletions

View File

@@ -148,5 +148,6 @@ ComponentSpec(
- [ ] Create pipeline class with `default_blocks_name`
- [ ] Assemble blocks in `modular_blocks_<model>.py`
- [ ] Wire up `__init__.py` with lazy imports
- [ ] Add `# auto_docstring` above all assembled blocks (SequentialPipelineBlocks, AutoPipelineBlocks, etc.), run `python utils/modular_auto_docstring.py --fix_and_overwrite`, and verify the generated docstrings — all parameters should have proper descriptions with no "TODO" placeholders indicating missing definitions
- [ ] Run `make style` and `make quality`
- [ ] Test all workflows for parity with reference

View File

@@ -28,7 +28,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -58,7 +58,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}

View File

@@ -25,14 +25,14 @@ jobs:
if: github.event_name == 'pull_request'
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
- name: Check out code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Find Changed Dockerfiles
id: file_changes
uses: jitterbit/get-changed-files@v1
uses: jitterbit/get-changed-files@b17fbb00bdc0c0f63fcf166580804b4d2cdc2a42 # v1
with:
format: "space-delimited"
token: ${{ secrets.GITHUB_TOKEN }}
@@ -99,16 +99,16 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # v3
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
with:
username: ${{ env.REGISTRY }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v6
uses: docker/build-push-action@10e90e3645eae34f1e60eeb005ba3a3d33f178e8 # v6
with:
no-cache: true
context: ./docker/${{ matrix.image-name }}
@@ -117,7 +117,7 @@ jobs:
- name: Post to a Slack channel
id: slack
uses: huggingface/hf-workflows/.github/actions/post-slack@main
uses: huggingface/hf-workflows/.github/actions/post-slack@a88e7fa2eaee28de5a4d6142381b1fb792349b67 # main
with:
# Slack channel id, channel name, or user id to post message.
# See also: https://api.slack.com/methods/chat.postMessage#channels

View File

@@ -14,7 +14,7 @@ on:
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
with:
commit_sha: ${{ github.sha }}
install_libgl1: true

View File

@@ -17,10 +17,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6
with:
python-version: '3.10'
@@ -39,7 +39,7 @@ jobs:
build:
needs: check-links
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
with:
commit_sha: ${{ github.event.pull_request.head.sha }}
pr_number: ${{ github.event.number }}

View File

@@ -10,7 +10,7 @@ on:
jobs:
codeql:
name: CodeQL Analysis
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@dc6ca34688e6876c2dd18750719b44d177586c17 # v1
permissions:
security-events: write
packages: read

View File

@@ -28,7 +28,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -97,7 +97,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -119,7 +119,7 @@ jobs:
module: [models, schedulers, lora, others, single_file, examples]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
@@ -167,7 +167,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: torch_${{ matrix.module }}_cuda_test_reports
path: reports
@@ -184,7 +184,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: torch_compile_test_reports
path: reports
@@ -228,7 +228,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -263,7 +263,7 @@ jobs:
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: torch_cuda_big_gpu_test_reports
path: reports
@@ -280,7 +280,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
@@ -321,7 +321,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -355,7 +355,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -391,7 +391,7 @@ jobs:
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: torch_cuda_${{ matrix.config.backend }}_reports
path: reports
@@ -408,7 +408,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -441,7 +441,7 @@ jobs:
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
@@ -466,7 +466,7 @@ jobs:
image: diffusers/diffusers-pytorch-cpu
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
@@ -474,7 +474,7 @@ jobs:
run: mkdir -p combined_reports
- name: Download all test reports
uses: actions/download-artifact@v7
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7
with:
path: artifacts
@@ -500,7 +500,7 @@ jobs:
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
- name: Upload consolidated report
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
with:
name: consolidated_test_report
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
@@ -514,7 +514,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v6
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
# with:
# fetch-depth: 2
#
@@ -554,7 +554,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v6
# uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
# with:
# name: torch_mps_test_reports
# path: reports
@@ -570,7 +570,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v6
# uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
# with:
# fetch-depth: 2
#
@@ -610,7 +610,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v6
# uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6
# with:
# name: torch_mps_test_reports
# path: reports

View File

@@ -10,7 +10,7 @@ permissions:
jobs:
style:
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@main
uses: huggingface/huggingface_hub/.github/workflows/style-bot-action.yml@e000c1c89c65aee188041723456ac3a479416d4c # main
with:
python_quality_dependencies: "[quality]"
secrets:

View File

@@ -27,12 +27,12 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 2
- name: Tailscale # In order to be able to SSH when a test fails
uses: huggingface/tailscale-action@main
uses: huggingface/tailscale-action@7d53c9737e53934c30290b5524d1c9b4a7c98c8a # main
with:
authkey: ${{ secrets.TAILSCALE_SSH_AUTHKEY }}
slackChannel: ${{ secrets.SLACK_CIFEEDBACK_CHANNEL }}

View File

@@ -8,11 +8,11 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
- name: Secret Scanning
uses: trufflesecurity/trufflehog@main
uses: trufflesecurity/trufflehog@6bd2d14f7a4bc1e569fa3550efa7ec632a4fa67b # main
with:
extra_args: --results=verified,unknown

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- name: typos-action
uses: crate-ci/typos@v1.42.1
uses: crate-ci/typos@65120634e79d8374d1aa2f27e54baa0c364fff5a # v1.42.1

View File

@@ -8,7 +8,7 @@ on:
jobs:
build:
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@90b4ee2c10b81b5c1a6367c4e6fc9e2fb510a7e3 # main
with:
package_name: diffusers
secrets:

View File

@@ -112,6 +112,8 @@
title: ModularPipeline
- local: modular_diffusers/components_manager
title: ComponentsManager
- local: modular_diffusers/auto_docstring
title: Auto docstring and parameter templates
- local: modular_diffusers/custom_blocks
title: Building Custom Blocks
- local: modular_diffusers/mellon

View File

@@ -0,0 +1,157 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Auto docstring and parameter templates
Every [`~modular_pipelines.ModularPipelineBlocks`] has a `doc` property that is automatically generated from its `description`, `inputs`, `intermediate_outputs`, `expected_components`, and `expected_configs`. The auto docstring system keeps docstrings in sync with the block's actual interface. Parameter templates provide standardized descriptions for parameters that appear across many pipelines.
## Auto docstring
Modular pipeline blocks are composable — you can nest them, chain them in sequences, and rearrange them freely. Their docstrings follow the same pattern. When a [`~modular_pipelines.SequentialPipelineBlocks`] aggregates inputs and outputs from its sub-blocks, the documentation should update automatically without manual rewrites.
The `# auto_docstring` marker generates docstrings from the block's properties. Add it above a class definition to mark the class for automatic docstring generation.
```py
# auto_docstring
class FluxTextEncoderStep(SequentialPipelineBlocks):
...
```
Run the following command to generate and insert the docstrings.
```bash
python utils/modular_auto_docstring.py --fix_and_overwrite
```
The utility reads the block's `doc` property and inserts it as the class docstring.
```py
# auto_docstring
class FluxTextEncoderStep(SequentialPipelineBlocks):
"""
Text input processing step that standardizes text embeddings for the pipeline.
Inputs:
prompt_embeds (`torch.Tensor`) *required*:
text embeddings used to guide the image generation.
...
Outputs:
prompt_embeds (`torch.Tensor`):
text embeddings used to guide the image generation.
...
"""
```
You can also check without overwriting, or target a specific file or directory.
```bash
# Check that all marked classes have up-to-date docstrings
python utils/modular_auto_docstring.py
# Check a specific file or directory
python utils/modular_auto_docstring.py src/diffusers/modular_pipelines/flux/
```
If any marked class is missing a docstring, the check fails and lists the classes that need updating.
```
Found the following # auto_docstring markers that need docstrings:
- src/diffusers/modular_pipelines/flux/encoders.py: FluxTextEncoderStep at line 42
Run `python utils/modular_auto_docstring.py --fix_and_overwrite` to fix them.
```
## Parameter templates
`InputParam` and `OutputParam` define a block's inputs and outputs. Create them directly or use `.template()` for standardized definitions of common parameters like `prompt`, `num_inference_steps`, or `latents`.
### InputParam
[`~modular_pipelines.InputParam`] describes a single input to a block.
| Field | Type | Description |
|---|---|---|
| `name` | `str` | Name of the parameter |
| `type_hint` | `Any` | Type annotation (e.g., `str`, `torch.Tensor`) |
| `default` | `Any` | Default value (if not set, parameter has no default) |
| `required` | `bool` | Whether the parameter is required |
| `description` | `str` | Human-readable description |
| `kwargs_type` | `str` | Group name for related parameters (e.g., `"denoiser_input_fields"`) |
| `metadata` | `dict` | Arbitrary additional information |
#### Creating InputParam directly
```py
from diffusers.modular_pipelines import InputParam
InputParam(
name="guidance_scale",
type_hint=float,
default=7.5,
description="Scale for classifier-free guidance.",
)
```
#### Using a template
```py
InputParam.template("prompt")
# Equivalent to:
# InputParam(name="prompt", type_hint=str, required=True,
# description="The prompt or prompts to guide image generation.")
```
Templates set `name`, `type_hint`, `default`, `required`, and `description` automatically. Override any field or add context with the `note` parameter.
```py
# Override the default value
InputParam.template("num_inference_steps", default=28)
# Add a note to the description
InputParam.template("prompt_embeds", note="batch-expanded")
# description becomes: "text embeddings used to guide the image generation. ... (batch-expanded)"
```
### OutputParam
[`~modular_pipelines.OutputParam`] describes a single output from a block.
| Field | Type | Description |
|---|---|---|
| `name` | `str` | Name of the parameter |
| `type_hint` | `Any` | Type annotation |
| `description` | `str` | Human-readable description |
| `kwargs_type` | `str` | Group name for related parameters |
| `metadata` | `dict` | Arbitrary additional information |
#### Creating OutputParam directly
```py
from diffusers.modular_pipelines import OutputParam
OutputParam(name="image_latents", type_hint=torch.Tensor, description="Encoded image latents.")
```
#### Using a template
```py
OutputParam.template("latents")
# Add a note to the description
OutputParam.template("prompt_embeds", note="batch-expanded")
```
## Available templates
`INPUT_PARAM_TEMPLATES` and `OUTPUT_PARAM_TEMPLATES` are defined in [modular_pipeline_utils.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/modular_pipelines/modular_pipeline_utils.py). They include common parameters like `prompt`, `image`, `num_inference_steps`, `latents`, `prompt_embeds`, and more. Refer to the source for the full list of available template names.
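If you want to see the available template names programmatically, you can list the dictionary keys. This is a small sketch; the exact import path is an assumption based on the file location mentioned above:
```py
# Hypothetical import path, inferred from modular_pipeline_utils.py referenced above.
from diffusers.modular_pipelines.modular_pipeline_utils import (
    INPUT_PARAM_TEMPLATES,
    OUTPUT_PARAM_TEMPLATES,
)

print(sorted(INPUT_PARAM_TEMPLATES))   # template names usable with InputParam.template()
print(sorted(OUTPUT_PARAM_TEMPLATES))  # template names usable with OutputParam.template()
```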

View File

@@ -100,7 +100,7 @@ accelerate launch train_text_to_image_sdxl.py \
The training script is also similar to the [Text-to-image](text2image#training-script) training guide, but it's been modified to support SDXL training. This guide will focus on the code that is unique to the SDXL training script.
It starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). Next, you'll a function to [generate the timesteps weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.
It starts by creating functions to [tokenize the prompts](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L478) to calculate the prompt embeddings, and to compute the image embeddings with the [VAE](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L519). Next, you'll create a function to [generate the timesteps weights](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L531) depending on the number of timesteps and the timestep bias strategy to apply.
Within the [`main()`](https://github.com/huggingface/diffusers/blob/aab6de22c33cc01fb7bc81c0807d6109e2c998c9/examples/text_to_image/train_text_to_image_sdxl.py#L572) function, in addition to loading a tokenizer, the script loads a second tokenizer and text encoder because the SDXL architecture uses two of each:
@@ -250,5 +250,5 @@ print(f'Inference time is {time()-start} sec after compilation')
Congratulations on training a SDXL model! To learn more about how to use your new model, the following guides may be helpful:
- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use it's refiner model, and the different types of micro-conditionings.
- Read the [Stable Diffusion XL](../using-diffusers/sdxl) guide to learn how to use it for a variety of different tasks (text-to-image, image-to-image, inpainting), how to use its refiner model, and the different types of micro-conditionings.
- Check out the [DreamBooth](dreambooth) and [LoRA](lora) training guides to learn how to train a personalized SDXL model with just a few example images. These two training techniques can even be combined!

View File

@@ -111,7 +111,7 @@ It conditions on a monocular depth estimate of the original image.
[Paper](https://huggingface.co/papers/2302.08113)
MultiDiffusion Panorama defines a new generation process over a pre-trained diffusion model. This process binds together multiple diffusion generation methods that can be readily applied to generate high quality and diverse images. Results adhere to user-provided controls, such as desired aspect ratio (e.g., panorama), and spatial guiding signals, ranging from tight segmentation masks to bounding boxes.
MultiDiffusion Panorama allows to generate high-quality images at arbitrary aspect ratios (e.g., panoramas).
MultiDiffusion Panorama allows you to generate high-quality images at arbitrary aspect ratios (e.g., panoramas).
## Fine-tuning your own models

View File

@@ -60,7 +60,7 @@ print(np.abs(image).sum())
</hfoption>
</hfoptions>
The `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because it's *state* has changed.
The `Generator` object should be passed to the pipeline instead of an integer seed. `Generator` maintains a *random state* that is consumed and modified when used. Once consumed, the same `Generator` object produces different results in subsequent calls, even across different pipelines, because its *state* has changed.
```py
generator = torch.manual_seed(0)

View File

@@ -0,0 +1,320 @@
# Profiling a `DiffusionPipeline` with the PyTorch Profiler
Educational materials for strategically profiling pipelines to improve their
runtime with `torch.compile`. To set these pipelines up for success with `torch.compile`,
we often have to get rid of device-to-host (DtoH) syncs, CPU overheads, kernel launch delays, and
graph breaks. Profiling is how we find those issues.
Thanks to Claude Code for pair-coding! We acknowledge the [Claude for OSS](https://claude.com/contact-sales/claude-for-oss) support provided to us.
## Table of contents
* [Context](#context)
* [Target pipelines](#target-pipelines)
* [How the tooling works](#how-the-tooling-works)
* [Verification](#verification)
* [Interpretation](#interpreting-traces-in-perfetto-ui)
* [Taking profiling-guided steps for improvements](#afterwards)
Jump to the "Verification" section to get started right away.
## Context
We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial when using [`torch.compile`](https://docs.pytorch.org/docs/stable/generated/torch.compile.html). The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses [`torch.profiler`](https://docs.pytorch.org/docs/stable/profiler.html) with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call).
## Target Pipelines
| Pipeline | Type | Checkpoint | Steps |
|----------|------|-----------|-------|
| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 2 |
| `Flux2KleinPipeline` | text-to-image | `black-forest-labs/FLUX.2-klein-base-9B` | 2 |
| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 2 |
| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 2 |
| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 2 |
> [!NOTE]
> We use realistic inference call hyperparameters that mimic how these pipelines will actually be used. This
> includes using classifier-free guidance (where applicable), reasonable dimensions such as 1024x1024, etc.
> But we keep the number of inference steps to a bare minimum.
## How the Tooling Works
Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome JSON trace.
### New Files
```bash
profiling_utils.py # Annotation helper + profiler setup
profiling_pipelines.py # CLI entry point with pipeline configs
run_profiling.sh # Bulk launch runs for multiple pipelines
```
### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure
**A) `annotate(func, name)` helper** (same pattern as flux-fast):
```python
def annotate(func, name):
"""Wrap a function with torch.profiler.record_function for trace annotation."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with torch.profiler.record_function(name):
return func(*args, **kwargs)
return wrapper
```
**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline:
- `pipe.transformer.forward` → `"transformer_forward"`
- `pipe.vae.decode` → `"vae_decode"` (if present)
- `pipe.vae.encode` → `"vae_encode"` (if present)
- `pipe.scheduler.step` → `"scheduler_step"`
- `pipe.encode_prompt` → `"encode_prompt"` (if present, for full-pipeline profiling)
This is non-invasive — it monkey-patches bound methods without modifying source.
**C) `PipelineProfiler` class:**
- `__init__(pipeline_config, output_dir, mode="eager"|"compile")`
- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()`
- `run()`:
1. Warm up with 1 unannotated run
2. Profile 1 run with `torch.profiler.profile`:
- `activities=[CPU, CUDA]`
- `record_shapes=True`
- `profile_memory=True`
- `with_stack=True`
3. Export Chrome trace JSON
4. Print `key_averages()` summary table (sorted by CUDA time) to stdout
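Put together, the profiled run boils down to a standard `torch.profiler` block. The following is a condensed sketch of what `run()` does (the full version lives in `profiling_utils.py`, reproduced further below); `pipe` and `call_kwargs` stand in for the loaded pipeline and its call arguments:
```python
import torch

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA],
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    # One annotated pipeline call (a warmup run happens before this, outside the profiler).
    with torch.profiler.record_function("pipeline_call"):
        pipe(**call_kwargs)

prof.export_chrome_trace("profiling_results/flux_eager.json")
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```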
### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs
**Pipeline config registry** — each entry specifies:
- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype`
- `call_kwargs` with pipeline-specific defaults:
| Pipeline | Resolution | Frames | Steps | Extra |
|----------|-----------|--------|-------|-------|
| Flux | 1024x1024 | — | 2 | `guidance_scale=3.5` |
| Flux2Klein | 1024x1024 | — | 2 | `guidance_scale=3.5` |
| Wan | 480x832 | 81 | 2 | — |
| LTX2 | 768x512 | 121 | 2 | `guidance_scale=4.0` |
| QwenImage | 1024x1024 | — | 2 | `true_cfg_scale=4.0` |
All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces).
**CLI flags:**
- `--pipeline flux|flux2|wan|ltx2|qwenimage|all`
- `--mode eager|compile|both`
- `--output_dir profiling_results/`
- `--num_steps N` (override, default 4)
- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE)
- `--compile_mode default|reduce-overhead|max-autotune`
- `--compile_regional` flag (uses [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html) to compile only the transformer forward pass instead of the full pipeline — faster compile times, ideal for iterative profiling)
- `--compile_fullgraph` flag
**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary.
### Step 3: Known Sync Issues to Validate
The profiling should surface these known/suspected issues:
1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines.
2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (line 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in trace.
3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces.
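To surface calls like these while iterating, PyTorch can warn (or raise) on implicit device synchronizations. This is a general debugging aid, not part of the tooling described above:
```python
import torch

# Warn whenever an op implicitly synchronizes the CPU with the GPU,
# e.g. a hidden .item() / .cpu() call in the denoising loop.
# Use "error" instead of "warn" to raise an exception at the offending call.
torch.cuda.set_sync_debug_mode("warn")
```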
## Verification
1. Run: `python examples/profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 2`
2. Verify `profiling_results/flux_eager.json` is produced
3. Open trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm:
- `transformer_forward` and `scheduler_step` annotations visible
- CPU and CUDA timelines present
- Stack traces visible on CPU events
4. Run with `--mode compile` and compare trace for fewer/fused CUDA kernels
You can also use the `run_profiling.sh` script to bulk launch runs for different pipelines.
## Interpreting Traces in Perfetto UI
Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom). In Perfetto, the CPU row is typically labeled with the process/thread name (e.g., `python (PID)` or `MainThread`) and appears at the top. The CUDA row is labeled `GPU 0` (or similar) and appears below the CPU rows.
**Navigation:** Use `W` to zoom in, `S` to zoom out, and `A`/`D` to pan left/right. You can also scroll to zoom and click-drag to pan. Use `Shift+scroll` to scroll vertically through rows.
> [!IMPORTANT]
> To keep the profiling iterations fast, we always use [regional compilation](https://pytorch.org/tutorials/recipes/regional_compilation.html). The observations below would largely still apply to full model
> compilation, too.
### What to look for
**1. Gaps between CUDA kernels**
Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes:
- Python overhead between ops (visible as CPU slices in the CPU row during the gap)
- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed
> [!IMPORTANT]
> Having no bubbles/gaps is ideal, but for small shapes (small model, small batch size, or both) some bubbles may be unavoidable.
**2. CPU stalls (DtoH syncs)**
These appear on the **CPU row** (not the CUDA row) — they are CPU-side blocking calls that wait for the GPU to finish. Look for long slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. To find them: zoom into the CPU row during a denoising step and look for unusually wide slices, or use Perfetto's search bar (press `/`) and type `cudaStreamSynchronize` to jump directly to matching events. Click on a slice — if `with_stack=True` was enabled, the bottom panel ("Current Selection") shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler).
**3. Annotated regions**
Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly:
- Measure how long each phase takes (click a span to see duration)
- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible)
- Spot unexpected CPU work between annotated regions
**4. Eager vs compile comparison**
Open both traces side by side (two Perfetto tabs). Key differences to look for:
- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager
- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead)
- **CUDA kernel count per step**: to compare, zoom into a single `transformer_forward` span on the CUDA row and count the distinct kernel slices within it. In eager mode you'll typically see many narrow slices (one per op); in compile mode these fuse into fewer, wider slices. A quick way to estimate: select a time range covering one denoising step on the CUDA row — Perfetto shows the number of slices in the selection summary at the bottom. If compile mode shows a similar kernel count to eager, fusion isn't happening effectively (likely due to graph breaks).
- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details
**5. Memory timeline**
In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not.
**6. Kernel launch latency**
Each CUDA kernel is launched from the CPU. The CPU-side launch calls (`cudaLaunchKernel`) appear as small slices on the **CPU row** — zoom in closely to a denoising step to see them. The corresponding GPU-side kernel executions appear on the **CUDA row** directly below. You can also use Perfetto's search bar (`/`) and type `cudaLaunchKernel` to find them. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution:
- The launch queue may be starved because of excessive Python work between ops
- There may be implicit syncs forcing serialization
- `torch.compile` should help here by batching launches — compare eager vs compile to confirm
To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it (there should be an arrow pointing from the CPU launch slice to the GPU kernel slice). The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).
### Quick checklist per pipeline
| Question | Where to look | Healthy | Unhealthy |
|----------|--------------|---------|-----------|
| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us |
| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step |
| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time |
| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager |
| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU |
| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step |
## What Profiling Revealed and Fixes
As one would expect, the trace with compilation shows fewer kernel launches than its eager counterpart.
_(Unless otherwise specified, the traces below were obtained with **Flux2**.)_
<table>
<tr>
<td align="center">
<img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2011.03.39%E2%80%AFAM.png" alt="Image 1"><br>
<em>Without compile</em>
</td>
<td align="center">
<img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2011.05.06%E2%80%AFAM.png" alt="Image 2"><br>
<em>With compile</em>
</td>
</tr>
</table>
### Spotting gaps between launches
A reasonable next step is to spot frequent gaps between kernel executions. In the compiled
case, none are apparent on the surface, but some become visible once we zoom in.
<table>
<tr>
<td align="center">
<img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2011.16.42%E2%80%AFAM.png" alt="Image 1"><br>
<em>Very small visible gaps in between compiled regions</em>
</td>
<td align="center">
<img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Flux2-Klein/Screenshot%202026-03-27%20at%2010.24.34%E2%80%AFAM.png" alt="Image 2"><br>
<em>Gaps become more visible when zoomed in</em>
</td>
</tr>
</table>
We then provided the profile trace file (with compilation) to Claude, asked it to find instances of
`cudaStreamSynchronize` and `cudaDeviceSynchronize`, and to propose potential fixes.
Claude came back with the following:
```
Issue 1 — Gap between transformer forwards:
- Root cause: tqdm progress bar update() calls between steps add CPU overhead (I/O, time calculations)
- Fix: profiling/profiling_utils.py — added pipe.set_progress_bar_config(disable=True) during profiling setup.
This eliminates the tqdm overhead from the trace. (The remaining gap from scheduler step + Python dispatch is
inherent to eager-mode execution and should shrink significantly under torch.compile.)
Issue 2 — cudaStreamSynchronize during last transformer forward:
- Root cause: _unpack_latents_with_ids() (called right after the denoising loop) computes h = torch.max(h_ids) +
1 and w = torch.max(w_ids) + 1 on GPU tensors, then uses them as shape args for torch.zeros((h * w, ch), ...).
This triggers an implicit .item() DtoH sync, blocking the CPU while the GPU is still finishing the last
transformer forward's kernels.
- Fix: Added height/width parameters to _unpack_latents_with_ids(), pre-computed from the known pixel dimensions
at the call site.
```
The changes looked reasonable based on our past experience, so we asked Claude to apply them to [`pipeline_flux2_klein.py`](../../src/diffusers/pipelines/flux2/pipeline_flux2_klein.py). We then profiled
the updated pipeline. It still didn't eliminate the gaps as expected, so we fed that back to Claude, and
it spotted something more crucial.
Under the [`cache_context`](https://github.com/huggingface/diffusers/blob/f2be8bd6b3dc4035bd989dc467f15d86bf3c9c12/src/diffusers/pipelines/flux2/pipeline_flux2_klein.py#L842) manager, `_set_context()` is called on
entry and exit. It calls `named_modules()` on the entire underlying model (in this case the Flux2 Klein DiT).
For large models that are invoked iteratively, as in our case, this adds to the latency because it involves traversing hundreds of submodules.
The fix was to build the list of hooked child registries once on the first call and cache it in `_child_registries_cache`. Subsequent calls then return the cached list directly without
any traversal. With the fix applied, the improvements were visible.
| | Before | After |
|------------------------|------------------------------|-----------------------------|
| `_set_context` total | 21.6ms (8 calls) | 0.0ms (8 calls) |
| `cache_context` total | 21.7ms | 0.1ms |
| CPU gaps | 5,523us / 8,007us / 5,508us | 158us / 2,777us / 136us |
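Conceptually, the fix is a simple memoization of the traversal. Below is a simplified sketch of the caching pattern; the actual change to `HookRegistry._get_child_registries` appears in the diff further below:
```python
def _get_child_registries(self):
    # Walk named_modules() only once; subsequent calls return the cached list.
    if getattr(self, "_child_registries_cache", None) is None:
        registries = []
        for name, module in unwrap_module(self._module_ref).named_modules():
            if name and hasattr(module, "_diffusers_hook"):
                registries.append(module._diffusers_hook)
        self._child_registries_cache = registries
    return self._child_registries_cache
```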
> [!NOTE]
> The fixes mentioned above and below are available in [this PR](https://github.com/huggingface/diffusers/pull/13356).
### DtoH syncs
We also profiled the **Wan** model and uncovered problems related to CPU DtoH syncs. Below is an
overview.
First, there was a Dynamo cache lookup delay that left the GPU idle, as reported [in this PR](https://github.com/huggingface/diffusers/pull/11696). The fix was to call `self.scheduler.set_begin_index(0)` before
the denoising loop. This tells the scheduler that the starting index is 0, so `_init_step_index()` skips the `nonzero().item()` path (which was causing the sync) entirely. This fix completely eliminated the ~2.3s of GPU idle time shown below:
![GPU idle](https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Wan/Screenshot%202026-03-27%20at%205.56.39%E2%80%AFPM.png)
The UniPC scheduler (used in Wan) creates small constant tensors via `torch.tensor([0.5], dtype=x.dtype, device=device)` during `step()`. This triggers a "cudaMemcpyAsync + cudaStreamSynchronize" to copy
the value from CPU to GPU. The sync itself is normally fast (~6us), but it forces the CPU to wait
until all pending GPU kernels finish before proceeding. Under torch.compile, the GPU has many queued
kernels, so this tiny sync balloons to 2.3s.
**Fix**: Replace with `torch.ones(1, dtype=x.dtype, device=device) * 0.5`. `torch.ones` allocates on GPU via "cudaMemsetAsync" (no sync), and `* 0.5` is a CUDA kernel launch (no sync). Same result, zero CPU-GPU synchronization. The duration of the scheduling step before and after this fix confirms this:
<table>
<tr>
<td align="center">
<img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Wan/Screenshot%25202026-03-27%2520at%25206.04.06%25E2%2580%25AFPM.png" alt="Image 1"><br>
<em>CPU<->GPU sync</em>
</td>
<td align="center">
<img src="https://huggingface.co/datasets/sayakpaul/torch-profiling-trace-diffusers/resolve/main/Wan/Screenshot%25202026-03-27%2520at%25206.04.29%25E2%2580%25AFPM.png" alt="Image 2"><br>
<em>Almost no sync</em>
</td>
</tr>
</table>
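For reference, the change amounts to swapping how the constant is created. A minimal sketch of the before/after described above:
```python
import torch

x = torch.randn(4, device="cuda")  # stand-in for a scheduler tensor

# Before: copies 0.5 from host to device, triggering cudaMemcpyAsync + cudaStreamSynchronize.
half = torch.tensor([0.5], dtype=x.dtype, device=x.device)

# After: allocated and filled on the GPU (cudaMemsetAsync + one multiply kernel), no CPU-GPU sync.
half = torch.ones(1, dtype=x.dtype, device=x.device) * 0.5
```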
### Notes
* As mentioned above, we profiled with regional compilation, so it's possible that
there are still some gaps outside the compiled regions. Full compilation
will likely mitigate them; if it doesn't, the observations above can help you
track down the remaining gaps.
* CUDA Graphs can also help mitigate CPU-overhead-related issues. They can be enabled by setting the `torch.compile` mode to `"reduce-overhead"` or `"max-autotune"`.
* Diffusers' integration of `torch.compile` is documented [here](https://huggingface.co/docs/diffusers/main/en/optimization/fp16#torchcompile).

View File

@@ -0,0 +1,181 @@
"""
Profile diffusers pipelines with torch.profiler.
Usage:
python profiling/profiling_pipelines.py --pipeline flux --mode eager
python profiling/profiling_pipelines.py --pipeline flux --mode compile
python profiling/profiling_pipelines.py --pipeline flux --mode both
python profiling/profiling_pipelines.py --pipeline all --mode eager
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
"""
import argparse
import copy
import logging
import torch
from profiling_utils import PipelineProfiler, PipelineProfilingConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
PROMPT = "A cat holding a sign that says hello world"
def build_registry():
"""Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront."""
from diffusers import Flux2KleinPipeline, FluxPipeline, LTX2Pipeline, QwenImagePipeline, WanPipeline
return {
"flux": PipelineProfilingConfig(
name="flux",
pipeline_cls=FluxPipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"height": 1024,
"width": 1024,
"num_inference_steps": 4,
"guidance_scale": 3.5,
"output_type": "latent",
},
),
"flux2": PipelineProfilingConfig(
name="flux2",
pipeline_cls=Flux2KleinPipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"height": 1024,
"width": 1024,
"num_inference_steps": 4,
"guidance_scale": 3.5,
"output_type": "latent",
},
),
"wan": PipelineProfilingConfig(
name="wan",
pipeline_cls=WanPipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
"height": 480,
"width": 832,
"num_frames": 81,
"num_inference_steps": 4,
"output_type": "latent",
},
),
"ltx2": PipelineProfilingConfig(
name="ltx2",
pipeline_cls=LTX2Pipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "Lightricks/LTX-2",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
"height": 512,
"width": 768,
"num_frames": 121,
"num_inference_steps": 4,
"guidance_scale": 4.0,
"output_type": "latent",
},
),
"qwenimage": PipelineProfilingConfig(
name="qwenimage",
pipeline_cls=QwenImagePipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "Qwen/Qwen-Image",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"negative_prompt": " ",
"height": 1024,
"width": 1024,
"num_inference_steps": 4,
"true_cfg_scale": 4.0,
"output_type": "latent",
},
),
}
def main():
parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
parser.add_argument(
"--pipeline",
choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
required=True,
help="Which pipeline to profile",
)
parser.add_argument(
"--mode",
choices=["eager", "compile", "both"],
default="eager",
help="Run in eager mode, compile mode, or both",
)
parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
parser.add_argument(
"--compile_mode",
default="default",
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode",
)
parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
parser.add_argument(
"--compile_regional",
action="store_true",
help="Use compile_repeated_blocks() instead of full model compile",
)
args = parser.parse_args()
registry = build_registry()
pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
modes = ["eager", "compile"] if args.mode == "both" else [args.mode]
for pipeline_name in pipeline_names:
for mode in modes:
config = copy.deepcopy(registry[pipeline_name])
# Apply overrides
if args.num_steps is not None:
config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
if args.full_decode:
config.pipeline_call_kwargs["output_type"] = "pil"
if mode == "compile":
config.compile_kwargs = {
"fullgraph": args.compile_fullgraph,
"mode": args.compile_mode,
}
config.compile_regional = args.compile_regional
logger.info(f"Profiling {pipeline_name} in {mode} mode...")
profiler = PipelineProfiler(config, args.output_dir)
try:
trace_file = profiler.run()
logger.info(f"Done: {trace_file}")
except Exception as e:
logger.error(f"Failed to profile {pipeline_name} ({mode}): {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,148 @@
import functools
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.profiler
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
def annotate(func, name):
"""Wrap a function with torch.profiler.record_function for trace annotation."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with torch.profiler.record_function(name):
return func(*args, **kwargs)
return wrapper
def annotate_pipeline(pipe):
"""Apply profiler annotations to key pipeline methods.
Monkey-patches bound methods so they appear as named spans in the trace.
Non-invasive — no source modifications required.
"""
annotations = [
("transformer", "forward", "transformer_forward"),
("vae", "decode", "vae_decode"),
("vae", "encode", "vae_encode"),
("scheduler", "step", "scheduler_step"),
]
# Annotate sub-component methods
for component_name, method_name, label in annotations:
component = getattr(pipe, component_name, None)
if component is None:
continue
method = getattr(component, method_name, None)
if method is None:
continue
setattr(component, method_name, annotate(method, label))
# Annotate pipeline-level methods
if hasattr(pipe, "encode_prompt"):
pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")
def flush():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
@dataclass
class PipelineProfilingConfig:
name: str
pipeline_cls: Any
pipeline_init_kwargs: dict[str, Any]
pipeline_call_kwargs: dict[str, Any]
compile_kwargs: dict[str, Any] | None = field(default=None)
compile_regional: bool = False
class PipelineProfiler:
def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
self.config = config
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
def setup_pipeline(self):
"""Load the pipeline from pretrained, optionally compile, and annotate."""
logger.info(f"Loading pipeline: {self.config.name}")
pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
pipe.to("cuda")
if self.config.compile_kwargs:
if self.config.compile_regional:
logger.info(
f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}"
)
pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
else:
logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
pipe.transformer.compile(**self.config.compile_kwargs)
# Disable tqdm progress bar to avoid CPU overhead / IO between steps
pipe.set_progress_bar_config(disable=True)
annotate_pipeline(pipe)
return pipe
def run(self):
"""Execute the profiling run: warmup, then profile one pipeline call."""
pipe = self.setup_pipeline()
flush()
mode = "compile" if self.config.compile_kwargs else "eager"
trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")
# Warmup (pipeline __call__ is already decorated with @torch.no_grad())
logger.info("Running warmup...")
pipe(**self.config.pipeline_call_kwargs)
flush()
# Profile
logger.info("Running profiled iteration...")
activities = [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
with torch.profiler.profile(
activities=activities,
record_shapes=True,
profile_memory=True,
with_stack=True,
) as prof:
with torch.profiler.record_function("pipeline_call"):
pipe(**self.config.pipeline_call_kwargs)
# Export trace
prof.export_chrome_trace(trace_file)
logger.info(f"Chrome trace saved to: {trace_file}")
# Print summary
print("\n" + "=" * 80)
print(f"Profile summary: {self.config.name} ({mode})")
print("=" * 80)
print(
prof.key_averages().table(
sort_by="cuda_time_total",
row_limit=20,
)
)
# Cleanup
pipe.to("cpu")
del pipe
flush()
return trace_file
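The profiler module above can also be driven directly from Python instead of through the CLI script. A minimal sketch follows; the `profiling_utils` import path, the Wan checkpoint id, and the call kwargs are assumptions for illustration, not part of this diff.

import torch
from diffusers import WanPipeline
from profiling_utils import PipelineProfiler, PipelineProfilingConfig  # hypothetical module name

config = PipelineProfilingConfig(
    name="wan",
    pipeline_cls=WanPipeline,
    pipeline_init_kwargs={
        # Assumed checkpoint; any Wan text-to-video checkpoint in diffusers format should work.
        "pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
        "torch_dtype": torch.bfloat16,
    },
    pipeline_call_kwargs={"prompt": "a cat", "num_inference_steps": 2, "output_type": "latent"},
    compile_kwargs={"fullgraph": True, "mode": "default"},
    compile_regional=True,
)
trace_file = PipelineProfiler(config, output_dir="profiling_results").run()
print(trace_file)  # e.g. profiling_results/wan_compile.json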

View File

@@ -0,0 +1,46 @@
#!/bin/bash
# Run profiling across all pipelines in eager and compile (regional) modes.
#
# Usage:
# bash profiling/run_profiling.sh
# bash profiling/run_profiling.sh --output_dir my_results
set -euo pipefail
OUTPUT_DIR="profiling_results"
while [[ $# -gt 0 ]]; do
case "$1" in
--output_dir) OUTPUT_DIR="$2"; shift 2 ;;
*) echo "Unknown arg: $1"; exit 1 ;;
esac
done
NUM_STEPS=2
# PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage")
PIPELINES=("wan")
MODES=("eager" "compile")
for pipeline in "${PIPELINES[@]}"; do
for mode in "${MODES[@]}"; do
echo "============================================================"
echo "Profiling: ${pipeline} | mode: ${mode}"
echo "============================================================"
COMPILE_ARGS=""
if [ "$mode" = "compile" ]; then
COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"
fi
python profiling/profiling_pipelines.py \
--pipeline "$pipeline" \
--mode "$mode" \
--output_dir "$OUTPUT_DIR" \
--num_steps "$NUM_STEPS" \
$COMPILE_ARGS
echo ""
done
done
echo "============================================================"
echo "All traces saved to: ${OUTPUT_DIR}/"
echo "============================================================"

View File

@@ -271,12 +271,31 @@ class HookRegistry:
if hook._is_stateful:
hook._set_context(self._module_ref, name)
for registry in self._get_child_registries():
registry._set_context(name)
def _get_child_registries(self) -> list["HookRegistry"]:
"""Return registries of child modules, using a cached list when available.
The cache is built on first call and reused for subsequent calls. This avoids the cost of walking the full
module tree via named_modules() on every _set_context call, which is significant for large models (e.g. ~2.7ms
per call on Flux2).
"""
if not hasattr(self, "_child_registries_cache"):
self._child_registries_cache = None
if self._child_registries_cache is not None:
return self._child_registries_cache
registries = []
for module_name, module in unwrap_module(self._module_ref).named_modules():
if module_name == "":
continue
module = unwrap_module(module)
if hasattr(module, "_diffusers_hook"):
module._diffusers_hook._set_context(name)
registries.append(module._diffusers_hook)
self._child_registries_cache = registries
return registries
def __repr__(self) -> str:
registry_repr = ""

View File

@@ -423,7 +423,9 @@ def dispatch_attention_fn(
**attention_kwargs,
"_parallel_config": parallel_config,
}
if is_torch_version(">=", "2.5.0"):
# Equivalent to `is_torch_version(">=", "2.5.0")` — use module-level constant to avoid
# Dynamo tracing into the lru_cache-wrapped `is_torch_version` during torch.compile.
if _CAN_USE_FLEX_ATTN:
kwargs["enable_gqa"] = enable_gqa
if _AttentionBackendRegistry._checks_enabled:
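The idea in this hunk is to hoist the version check out of the hot path: evaluate it once at import time so Dynamo sees a plain Python bool instead of tracing into the lru_cache-wrapped is_torch_version helper. A rough sketch of how such a module-level constant can be computed (this mirrors the intent of _CAN_USE_FLEX_ATTN, not its exact definition in the library):

import torch
from packaging import version

# Evaluated once at import time. Inside a compiled region this is a constant bool,
# so torch.compile never has to trace a cached version-comparison helper.
_TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
_CAN_USE_FLEX_ATTN = _TORCH_VERSION >= version.parse("2.5.0")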

View File

@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import torch
from ...models import UNet2DModel
@@ -21,6 +23,9 @@ from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
logger = logging.getLogger(__name__)
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
@@ -129,6 +134,13 @@ class DDIMPipeline(DiffusionPipeline):
else:
image_shape = (batch_size, self.unet.config.in_channels, *self.unet.config.sample_size)
if not 0.0 <= eta <= 1.0:
logger.warning(
f"`eta` should be between 0 and 1 (inclusive), but received {eta}. "
"A value of 0 corresponds to DDIM and 1 corresponds to DDPM. "
"Unexpected results may occur for values outside this range."
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"

View File

@@ -396,8 +396,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
return latents
@staticmethod
# Copied from diffusers.pipelines.flux2.pipeline_flux2.Flux2Pipeline._unpack_latents_with_ids
def _unpack_latents_with_ids(x: torch.Tensor, x_ids: torch.Tensor) -> list[torch.Tensor]:
def _unpack_latents_with_ids(
x: torch.Tensor, x_ids: torch.Tensor, height: int | None = None, width: int | None = None
) -> list[torch.Tensor]:
"""
Use the position ids to scatter tokens back into place.
"""
@@ -407,8 +408,9 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
h_ids = pos[:, 1].to(torch.int64)
w_ids = pos[:, 2].to(torch.int64)
h = torch.max(h_ids) + 1
w = torch.max(w_ids) + 1
# Use provided height/width to avoid DtoH sync from torch.max().item()
h = height if height is not None else torch.max(h_ids) + 1
w = width if width is not None else torch.max(w_ids) + 1
flat_ids = h_ids * w + w_ids
@@ -895,7 +897,10 @@ class Flux2KleinPipeline(DiffusionPipeline, Flux2LoraLoaderMixin):
self._current_timestep = None
latents = self._unpack_latents_with_ids(latents, latent_ids)
# Pass pre-computed latent height/width to avoid DtoH sync from torch.max().item()
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
latents = self._unpack_latents_with_ids(latents, latent_ids, latent_height // 2, latent_width // 2)
latents_bn_mean = self.vae.bn.running_mean.view(1, -1, 1, 1).to(latents.device, latents.dtype)
latents_bn_std = torch.sqrt(self.vae.bn.running_var.view(1, -1, 1, 1) + self.vae.config.batch_norm_eps).to(
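The DtoH sync mentioned in the comments comes from turning a device-resident max into a Python integer. A small, self-contained illustration (tensor contents and sizes are made up):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
h_ids = torch.randint(0, 16, (1024,), device=device)

# Data-dependent shape: converting the device scalar to a Python int forces a
# device-to-host copy and a graph break under torch.compile.
h_from_tensor = int(torch.max(h_ids)) + 1

# Shape derived from arguments the caller already has: no sync required.
height, vae_scale_factor = 512, 16
latent_height = 2 * (int(height) // (vae_scale_factor * 2))
h_from_args = latent_height // 2  # the value now passed into _unpack_latents_with_ids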

View File

@@ -574,6 +574,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
# We set the index here to remove a DtoH sync, which is especially helpful during compilation.
# Check out more details here: https://github.com/huggingface/diffusers/pull/11696
self.scheduler.set_begin_index(0)
if self.config.boundary_ratio is not None:
boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
else:
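The begin-index trick applies to any diffusers scheduler that exposes set_begin_index: once set, the scheduler advances its step index by counting rather than by searching self.timesteps for the current timestep, and it is that search which otherwise forces a device-to-host sync inside step(). A minimal standalone sketch, using UniPCMultistepScheduler only as an example and a zero tensor as a stand-in for the denoiser output:

import torch
from diffusers import UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=4)
scheduler.set_begin_index(0)  # step index now advances without a timestep lookup

sample = torch.randn(1, 4, 8, 8)
for t in scheduler.timesteps:
    model_output = torch.zeros_like(sample)  # stand-in for the transformer/UNet prediction
    sample = scheduler.step(model_output, t, sample).prev_sample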

View File

@@ -903,8 +903,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
rks.append(torch.ones((), device=device))
rks = torch.stack(rks)
R = []
b = []
@@ -929,13 +929,13 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1) # (B, K)
# for order 2, we use a simplified version
if order == 2:
rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
rhos_p = torch.ones(1, dtype=x.dtype, device=device) * 0.5
else:
rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
else:
@@ -1038,8 +1038,8 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
rks.append(rk)
D1s.append((mi - m0) / rk)
rks.append(1.0)
rks = torch.tensor(rks, device=device)
rks.append(torch.ones((), device=device))
rks = torch.stack(rks)
R = []
b = []
@@ -1064,7 +1064,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
h_phi_k = h_phi_k / hh - 1 / factorial_i
R = torch.stack(R)
b = torch.tensor(b, device=device)
b = torch.stack(b) if len(b) > 0 else torch.tensor(b, device=device)
if len(D1s) > 0:
D1s = torch.stack(D1s, dim=1)
@@ -1073,7 +1073,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
# for order 1, we use a simplified version
if order == 1:
rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
rhos_c = torch.ones(1, dtype=x.dtype, device=device) * 0.5
else:
rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
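The pattern across these hunks is the same: stop rebuilding tensors from Python lists of device tensors with torch.tensor(), which copies element by element through the host and breaks torch.compile graphs, and use torch.stack (or plain tensor ops) instead. A tiny illustration with arbitrary values:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
rks = [torch.rand((), device=device) for _ in range(3)]
rks.append(torch.ones((), device=device))  # was: rks.append(1.0)

stacked = torch.stack(rks)  # stays on device and traces cleanly under torch.compile
# host_roundtrip = torch.tensor([r.item() for r in rks], device=device)  # what the old path amounts to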

View File

@@ -347,7 +347,17 @@ def lru_cache_unless_export(maxsize=128, typed=False):
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
compiler = getattr(torch, "compiler", None)
is_exporting = bool(compiler and hasattr(compiler, "is_exporting") and compiler.is_exporting())
is_compiling = bool(compiler and hasattr(compiler, "is_compiling") and compiler.is_compiling())
# Fallback for older builds where compiler.is_compiling is unavailable.
if not is_compiling:
dynamo = getattr(torch, "_dynamo", None)
if dynamo is not None and hasattr(dynamo, "is_compiling"):
is_compiling = dynamo.is_compiling()
if is_exporting or is_compiling:
return fn(*args, **kwargs)
return cached(*args, **kwargs)
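For context, a self-contained sketch of the bypass pattern this hunk extends; the decorator name and structure are illustrative, not the library's exact code:

import functools
import torch

def cache_unless_tracing(fn):
    cached = functools.lru_cache(maxsize=128)(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        compiler = getattr(torch, "compiler", None)
        exporting = bool(compiler and hasattr(compiler, "is_exporting") and compiler.is_exporting())
        compiling = bool(compiler and hasattr(compiler, "is_compiling") and compiler.is_compiling())
        # Dynamo cannot trace through an lru_cache wrapper, so fall back to the raw
        # function whenever we are exporting or compiling.
        if exporting or compiling:
            return fn(*args, **kwargs)
        return cached(*args, **kwargs)

    return wrapper

@cache_unless_tracing
def expensive_lookup(x: int) -> int:
    return x * x

print(expensive_lookup(3))  # cached in eager mode, uncached under export/compile tracing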

View File

@@ -13,24 +13,34 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import pytest
import torch
from diffusers import AutoencoderKLWan
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLWan
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLWanTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLWan
def get_autoencoder_kl_wan_config(self):
@property
def output_shape(self):
return (3, 9, 16, 16)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self):
return {
"base_dim": 3,
"z_dim": 16,
@@ -39,54 +49,40 @@ class AutoencoderKLWanTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.T
"temperal_downsample": [False, True, True],
}
@property
def dummy_input(self):
def get_dummy_inputs(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
image = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
return {"sample": image}
@property
def dummy_input_tiling(self):
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (128, 128)
image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 9, 16, 16)
class TestAutoencoderKLWan(AutoencoderKLWanTesterConfig, ModelTesterMixin):
base_precision = 1e-2
@property
def output_shape(self):
return (3, 9, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_wan_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKLWanTraining(AutoencoderKLWanTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLWan."""
def prepare_init_args_and_inputs_for_tiling(self):
init_dict = self.get_autoencoder_kl_wan_config()
inputs_dict = self.dummy_input_tiling
return init_dict, inputs_dict
@unittest.skip("Gradient checkpointing has not been implemented yet")
@pytest.mark.skip(reason="Gradient checkpointing has not been implemented yet")
def test_gradient_checkpointing_is_applied(self):
pass
@unittest.skip("Test not supported")
def test_forward_with_norm_groups(self):
class TestAutoencoderKLWanMemory(AutoencoderKLWanTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLWan."""
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_memory(self):
pass
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_inference(self):
pass
@unittest.skip("RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
@pytest.mark.skip(reason="RuntimeError: fill_out not implemented for 'Float8_e4m3fn'")
def test_layerwise_casting_training(self):
pass
class TestAutoencoderKLWanSlicingTiling(AutoencoderKLWanTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLWan."""

View File

@@ -145,3 +145,138 @@ class AutoencoderTesterMixin:
output_without_slicing.detach().cpu().numpy().all(),
output_without_slicing_2.detach().cpu().numpy().all(),
), "Without slicing outputs should match with the outputs when slicing is manually disabled."
class NewAutoencoderTesterMixin:
@staticmethod
def _accepts_generator(model):
model_sig = inspect.signature(model.forward)
accepts_generator = "generator" in model_sig.parameters
return accepts_generator
@staticmethod
def _accepts_norm_num_groups(model_class):
model_sig = inspect.signature(model_class.__init__)
accepts_norm_groups = "norm_num_groups" in model_sig.parameters
return accepts_norm_groups
def test_forward_with_norm_groups(self):
if not self._accepts_norm_num_groups(self.model_class):
pytest.skip(f"Test not supported for {self.model_class.__name__}")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 32)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.to_tuple()[0]
assert output is not None
expected_shape = inputs_dict["sample"].shape
assert output.shape == expected_shape, "Input and output shapes do not match"
def test_enable_disable_tiling(self):
if not hasattr(self.model_class, "enable_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
if not hasattr(model, "use_tiling"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
with torch.no_grad():
torch.manual_seed(0)
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling = model(**inputs_dict)[0]
if isinstance(output_without_tiling, DecoderOutput):
output_without_tiling = output_without_tiling.sample
torch.manual_seed(0)
model.enable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_tiling = model(**inputs_dict)[0]
if isinstance(output_with_tiling, DecoderOutput):
output_with_tiling = output_with_tiling.sample
assert (output_without_tiling.cpu() - output_with_tiling.cpu()).abs().max() < 0.5, (
"VAE tiling should not affect the inference results"
)
torch.manual_seed(0)
model.disable_tiling()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_tiling_2 = model(**inputs_dict)[0]
if isinstance(output_without_tiling_2, DecoderOutput):
output_without_tiling_2 = output_without_tiling_2.sample
assert torch.allclose(output_without_tiling.cpu(), output_without_tiling_2.cpu()), (
"Without tiling outputs should match with the outputs when tiling is manually disabled."
)
def test_enable_disable_slicing(self):
if not hasattr(self.model_class, "enable_slicing"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support slicing.")
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
torch.manual_seed(0)
model = self.model_class(**init_dict).to(torch_device)
if not hasattr(model, "use_slicing"):
pytest.skip(f"Skipping test as {self.model_class.__name__} doesn't support tiling.")
inputs_dict.update({"return_dict": False})
_ = inputs_dict.pop("generator", None)
accepts_generator = self._accepts_generator(model)
with torch.no_grad():
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
torch.manual_seed(0)
output_without_slicing = model(**inputs_dict)[0]
if isinstance(output_without_slicing, DecoderOutput):
output_without_slicing = output_without_slicing.sample
torch.manual_seed(0)
model.enable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_with_slicing = model(**inputs_dict)[0]
if isinstance(output_with_slicing, DecoderOutput):
output_with_slicing = output_with_slicing.sample
assert (output_without_slicing.cpu() - output_with_slicing.cpu()).abs().max() < 0.5, (
"VAE slicing should not affect the inference results"
)
torch.manual_seed(0)
model.disable_slicing()
if accepts_generator:
inputs_dict["generator"] = torch.manual_seed(0)
output_without_slicing_2 = model(**inputs_dict)[0]
if isinstance(output_without_slicing_2, DecoderOutput):
output_without_slicing_2 = output_without_slicing_2.sample
assert torch.allclose(output_without_slicing.cpu(), output_without_slicing_2.cpu()), (
"Without slicing outputs should match with the outputs when slicing is manually disabled."
)