Compare commits

..

56 Commits

Author SHA1 Message Date
sayakpaul
c92cf75dfe up 2025-12-12 13:41:01 +05:30
Wang, Yi
218b17040f support CP in native flash attention (#12829)
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-12-12 13:40:50 +05:30
Sayak Paul
a7c7a270f6 [lora] Remove lora docs unneeded and add " # Copied from ..." (#12824)
* remove unneeded docs on load_lora_weights().

* remove more.

* up[

* up

* up
2025-12-12 13:40:50 +05:30
Sayak Paul
5455dd58e8 Update distributed_inference.md to correct syntax (#12827) 2025-12-12 13:40:50 +05:30
Sayak Paul
07084ef036 post release 0.36.0 (#12804)
* post release 0.36.0

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2025-12-12 13:40:50 +05:30
Sayak Paul
9e15c576cb [docs] improve distributed inference cp docs. (#12810)
* improve distributed inference cp docs.

* Apply suggestions from code review

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-12-12 13:40:50 +05:30
Dhruv Nair
8ddba1e082 [WIP] Add Flux2 modular (#12763)
* update

* update

* update

* update

* update

* update

* update

* update

* update

* update
2025-12-12 13:40:50 +05:30
Sayak Paul
d1b8202e42 Fix Qwen Edit Plus modular for multi-image input (#12601)
* try to fix qwen edit plus multi images (modular)

* up

* up

* test

* up

* up
2025-12-12 13:40:49 +05:30
YiYi Xu
f7439c30c9 [Modular]z-image (#12808)
* initiL

* up up

* fix: z_image -> z-image

* style

* copy

* fix more

* some docstring fix
2025-12-12 13:40:49 +05:30
David El Malih
b53bd8372b Improve docstrings and type hints in scheduling_dpmsolver_singlestep.py (#12798)
feat: add flow sigmas, dynamic shifting, and refine type hints in DPMSolverSinglestepScheduler
2025-12-12 13:40:49 +05:30
David Lacalle Castillo
a73981fe17 [PRX] Improve model compilation (#12787)
* Reimplement img2seq & seq2img in PRX to enable ONNX build without Col2Im (incompatible with TensorRT).

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2025-12-12 13:40:49 +05:30
Sayak Paul
d738ec4141 Merge branch 'main' into control-lora 2025-12-08 18:32:45 +08:00
Sayak Paul
03d1751cce Merge branch 'main' into control-lora 2025-12-05 21:55:26 +08:00
lavinal712
cd71418052 add doc 2025-12-05 10:24:20 +08:00
lavinal712
58559ecc7e no need modify as peft updated 2025-12-05 10:01:11 +08:00
Yuqian Hong
48eeeae1f7 Merge branch 'main' into control-lora 2025-12-05 09:44:22 +08:00
Yuqian Hong
2223722e5b Merge branch 'huggingface:main' into control-lora 2025-11-28 08:51:59 +08:00
lavinal712
4d1e8912d6 rename 2025-11-22 09:50:45 +08:00
Yuqian Hong
dfad05625e Merge branch 'huggingface:main' into control-lora 2025-11-22 09:23:09 +08:00
Yuqian Hong
9d94c377ef Merge branch 'huggingface:main' into control-lora 2025-09-22 11:44:52 +08:00
Yuqian Hong
1e8221ce39 Add files via upload 2025-08-20 21:23:22 +08:00
Yuqian Hong
00a26cd8dd Create control_lora.py 2025-08-20 21:23:04 +08:00
Yuqian Hong
a2eff1c668 Merge branch 'huggingface:main' into control-lora 2025-08-20 09:23:49 +08:00
Yuqian Hong
1c902725b0 Merge branch 'huggingface:main' into control-lora 2025-08-19 06:41:43 +08:00
Yuqian Hong
59a42b23d3 Merge branch 'huggingface:main' into control-lora 2025-08-17 16:32:42 +08:00
Yuqian Hong
4a64d64407 Merge branch 'main' into control-lora 2025-08-14 11:51:09 +08:00
Yuqian Hong
c6c13b6717 Merge branch 'huggingface:main' into control-lora 2025-08-09 00:24:18 +08:00
Yuqian Hong
af8255e934 Merge branch 'main' into control-lora 2025-07-30 15:48:15 +08:00
Yuqian Hong
d3a07558cf Merge branch 'main' into control-lora 2025-07-21 16:00:43 +08:00
lavinal712
23cba1804f fix alpha 2025-07-05 08:18:34 +00:00
lavinal712
53a06cc969 delete state_dict print 2025-07-05 07:52:01 +00:00
lavinal712
0a5bd74931 1 2025-07-05 05:12:57 +00:00
Yuqian Hong
d752992831 Merge branch 'huggingface:main' into control-lora 2025-07-05 09:53:39 +08:00
Yuqian Hong
39e9254208 Merge branch 'huggingface:main' into control-lora 2025-07-02 17:48:25 +08:00
lavinal712
c134bca767 change peft.py 2025-05-29 14:24:16 +00:00
lavinal712
63bafc88cd change peft.py 2025-05-29 14:23:41 +00:00
Yuqian Hong
8f7fc0ada0 Merge branch 'huggingface:main' into control-lora 2025-05-29 21:59:42 +08:00
lavinal712
6fff794e59 merged but bug 2025-04-09 07:56:40 +00:00
Yuqian Hong
ab9eeff757 Merge branch 'main' into control-lora 2025-04-09 15:41:28 +08:00
lavinal712
6a1ff82d08 resolve conflits 2025-04-09 07:41:14 +00:00
Yuqian Hong
ce2b34bba7 Merge branch 'main' into control-lora 2025-03-26 10:07:45 +08:00
Sayak Paul
2de1505e6e Merge branch 'main' into control-lora 2025-03-25 18:58:07 +01:00
lavinal712
81eed41b74 delete json print 2025-03-23 10:29:08 +00:00
lavinal712
0719c20f5e fix module_to_save bug 2025-03-23 10:27:40 +00:00
lavinal712
7c25a06591 fix PeftAdapterMixin 2025-03-23 09:36:00 +00:00
Yuqian Hong
280cf7fd38 Merge branch 'huggingface:main' into control-lora 2025-03-23 13:30:27 +08:00
Yuqian Hong
33288e667f Merge branch 'huggingface:main' into control-lora 2025-03-17 11:38:42 +08:00
Yuqian Hong
dd24464065 Merge branch 'huggingface:main' into control-lora 2025-02-23 23:14:08 +08:00
lavinal712
523967f396 1 2025-02-15 13:45:51 +00:00
lavinal712
10daac7e19 1 2025-02-15 13:41:30 +00:00
lavinal712
de61226385 1 2025-02-15 13:38:02 +00:00
lavinal712
39b3b84acc add control-lora 2025-02-07 18:43:19 +00:00
lavinal712
2453e149d2 1 2025-02-07 10:24:24 +00:00
lavinal712
9cf8ad7a73 test 2025-02-04 16:40:59 +00:00
lavinal712
e9d91e156d cannot load lora adapter 2025-02-01 01:29:19 +00:00
lavinal712
18de3adad1 run control-lora on diffusers 2025-01-30 03:00:59 +00:00
397 changed files with 6860 additions and 48392 deletions

View File

@@ -28,7 +28,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -58,7 +58,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}

View File

@@ -25,10 +25,10 @@ jobs:
if: github.event_name == 'pull_request'
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Find Changed Dockerfiles
id: file_changes
@@ -99,16 +99,16 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v3
uses: docker/login-action@v2
with:
username: ${{ env.REGISTRY }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Build and push
uses: docker/build-push-action@v6
uses: docker/build-push-action@v3
with:
no-cache: true
context: ./docker/${{ matrix.image-name }}

View File

@@ -17,10 +17,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: '3.10'

View File

@@ -1,22 +0,0 @@
---
name: CodeQL Security Analysis For Github Actions
on:
push:
branches: ["main"]
workflow_dispatch:
# pull_request:
jobs:
codeql:
name: CodeQL Analysis
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
permissions:
security-events: write
packages: read
actions: read
contents: read
with:
languages: '["actions","python"]'
queries: 'security-extended,security-and-quality'
runner: 'ubuntu-latest' #optional if need custom runner

View File

@@ -24,6 +24,7 @@ jobs:
mirror_community_pipeline:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
runs-on: ubuntu-22.04
steps:
# Checkout to correct ref
@@ -38,41 +39,37 @@ jobs:
# If ref is 'refs/heads/main' => set 'main'
# Else it must be a tag => set {tag}
- name: Set checkout_ref and path_in_repo
env:
EVENT_NAME: ${{ github.event_name }}
EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
GITHUB_REF: ${{ github.ref }}
run: |
if [ "$EVENT_NAME" == "workflow_dispatch" ]; then
if [ -z "$EVENT_INPUT_REF" ]; then
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
if [ -z "${{ github.event.inputs.ref }}" ]; then
echo "Error: Missing ref input"
exit 1
elif [ "$EVENT_INPUT_REF" == "main" ]; then
elif [ "${{ github.event.inputs.ref }}" == "main" ]; then
echo "CHECKOUT_REF=refs/heads/main" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
echo "CHECKOUT_REF=refs/tags/$EVENT_INPUT_REF" >> $GITHUB_ENV
echo "PATH_IN_REPO=$EVENT_INPUT_REF" >> $GITHUB_ENV
echo "CHECKOUT_REF=refs/tags/${{ github.event.inputs.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=${{ github.event.inputs.ref }}" >> $GITHUB_ENV
fi
elif [ "$GITHUB_REF" == "refs/heads/main" ]; then
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
# e.g. refs/tags/v0.28.1 -> v0.28.1
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
echo "PATH_IN_REPO=$(echo $GITHUB_REF | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
fi
- name: Print env vars
run: |
echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
- uses: actions/checkout@v6
- uses: actions/checkout@v3
with:
ref: ${{ env.CHECKOUT_REF }}
# Setup + install dependencies
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
@@ -102,4 +99,4 @@ jobs:
- name: Report failure status
if: ${{ failure() }}
run: |
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure

View File

@@ -28,7 +28,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -97,7 +97,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -119,7 +119,7 @@ jobs:
module: [models, schedulers, lora, others, single_file, examples]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -167,7 +167,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_${{ matrix.module }}_cuda_test_reports
path: reports
@@ -184,7 +184,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
path: reports
@@ -228,7 +228,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -263,7 +263,7 @@ jobs:
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_big_gpu_test_reports
path: reports
@@ -280,7 +280,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -321,7 +321,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -355,7 +355,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -391,7 +391,7 @@ jobs:
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_${{ matrix.config.backend }}_reports
path: reports
@@ -408,7 +408,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -441,7 +441,7 @@ jobs:
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
@@ -466,7 +466,7 @@ jobs:
image: diffusers/diffusers-pytorch-cpu
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -474,7 +474,7 @@ jobs:
run: mkdir -p combined_reports
- name: Download all test reports
uses: actions/download-artifact@v7
uses: actions/download-artifact@v4
with:
path: artifacts
@@ -500,7 +500,7 @@ jobs:
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
- name: Upload consolidated report
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: consolidated_test_report
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
@@ -514,7 +514,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v6
# uses: actions/checkout@v3
# with:
# fetch-depth: 2
#
@@ -554,7 +554,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v6
# uses: actions/upload-artifact@v4
# with:
# name: torch_mps_test_reports
# path: reports
@@ -570,7 +570,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v6
# uses: actions/checkout@v3
# with:
# fetch-depth: 2
#
@@ -610,7 +610,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v6
# uses: actions/upload-artifact@v4
# with:
# name: torch_mps_test_reports
# path: reports

View File

@@ -10,10 +10,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.8'

View File

@@ -18,9 +18,9 @@ jobs:
check_dependencies:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies

View File

@@ -1,4 +1,3 @@
name: Fast PR tests for Modular
on:
@@ -36,9 +35,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
@@ -56,9 +55,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
@@ -75,34 +74,26 @@ jobs:
if: ${{ failure() }}
run: |
echo "Repo consistency check failed. Please ensure the right dependency versions are installed with 'pip install -e .[quality]' and run 'make fix-copies'" >> $GITHUB_STEP_SUMMARY
check_auto_docs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install dependencies
run: |
pip install --upgrade pip
pip install .[quality]
- name: Check auto docs
run: make modular-autodoctrings
- name: Check if failure
if: ${{ failure() }}
run: |
echo "Auto docstring checks failed. Please run `python utils/modular_auto_docstring.py --fix_and_overwrite`." >> $GITHUB_STEP_SUMMARY
run_fast_tests:
needs: [check_code_quality, check_repository_consistency, check_auto_docs]
name: Fast PyTorch Modular Pipeline CPU tests
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
runs-on:
group: aws-highmemory-32-plus
group: ${{ matrix.config.runner }}
container:
image: diffusers/diffusers-pytorch-cpu
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
@@ -111,7 +102,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -127,19 +118,22 @@ jobs:
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_cpu_modular_pipelines \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -28,7 +28,7 @@ jobs:
test_map: ${{ steps.set_matrix.outputs.test_map }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install dependencies
@@ -42,7 +42,7 @@ jobs:
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
- name: Report fetched tests
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v3
with:
name: test_fetched
path: test_preparation.txt
@@ -83,7 +83,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -109,7 +109,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.modules }}_test_reports
path: reports
@@ -138,7 +138,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -164,7 +164,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -31,9 +31,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -51,9 +51,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -108,7 +108,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -153,7 +153,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
@@ -185,7 +185,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
@@ -236,7 +236,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -273,7 +273,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_main_test_reports
path: reports

View File

@@ -32,9 +32,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -52,9 +52,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -83,7 +83,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -100,7 +100,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -120,7 +120,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -170,7 +170,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -193,7 +193,7 @@ jobs:
module: [models, schedulers, lora, others]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -239,7 +239,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -255,7 +255,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -287,7 +287,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: examples_test_reports
path: reports

View File

@@ -18,9 +18,9 @@ jobs:
check_torch_dependencies:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies

View File

@@ -29,7 +29,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -46,7 +46,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -66,7 +66,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -98,7 +98,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -120,7 +120,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -155,7 +155,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -172,7 +172,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -199,7 +199,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
path: reports
@@ -216,7 +216,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -240,7 +240,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_xformers_test_reports
path: reports
@@ -256,7 +256,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -286,7 +286,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: examples_test_reports
path: reports

View File

@@ -54,7 +54,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -88,7 +88,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -65,7 +65,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_torch_mps_test_reports
path: reports

View File

@@ -15,10 +15,10 @@ jobs:
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
steps:
- name: Checkout Repo
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.8'
@@ -40,12 +40,12 @@ jobs:
steps:
- name: Checkout Repo
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
- name: Setup Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"

View File

@@ -27,7 +27,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -94,7 +94,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -116,7 +116,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -149,7 +149,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_${{ matrix.module }}_test_reports
path: reports
@@ -166,7 +166,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -205,7 +205,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -222,7 +222,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -247,7 +247,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
path: reports
@@ -264,7 +264,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -288,7 +288,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_xformers_test_reports
path: reports
@@ -305,7 +305,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -336,7 +336,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: examples_test_reports
path: reports

View File

@@ -57,7 +57,7 @@ jobs:
shell: bash -e {0}
- name: Checkout PR branch
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: refs/pull/${{ inputs.pr_number }}/head

View File

@@ -27,7 +27,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2

View File

@@ -35,7 +35,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2

View File

@@ -15,10 +15,10 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v6
uses: actions/setup-python@v1
with:
python-version: 3.8

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.42.1
uses: crate-ci/typos@v1.12.4

View File

@@ -15,7 +15,7 @@ jobs:
shell: bash -l {0}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Setup environment
run: |

View File

@@ -1 +0,0 @@
docs/source/en/conceptual/contribution.md

506
CONTRIBUTING.md Normal file
View File

@@ -0,0 +1,506 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# How to contribute to Diffusers 🧨
We ❤️ contributions from the open-source community! Everyone is welcome, and all types of participation not just code are valued and appreciated. Answering questions, helping others, reaching out, and improving the documentation are all immensely valuable to the community, so don't be afraid and get involved if you're up for it!
Everyone is encouraged to start by saying 👋 in our public Discord channel. We discuss the latest trends in diffusion models, ask questions, show off personal projects, help each other with contributions, or just hang out ☕. <a href="https://discord.gg/G7tWnz98XR"><img alt="Join us on Discord" src="https://img.shields.io/discord/823813159592001537?color=5865F2&logo=Discord&logoColor=white"></a>
Whichever way you choose to contribute, we strive to be part of an open, welcoming, and kind community. Please, read our [code of conduct](https://github.com/huggingface/diffusers/blob/main/CODE_OF_CONDUCT.md) and be mindful to respect it during your interactions. We also recommend you become familiar with the [ethical guidelines](https://huggingface.co/docs/diffusers/conceptual/ethical_guidelines) that guide our project and ask you to adhere to the same principles of transparency and responsibility.
We enormously value feedback from the community, so please do not be afraid to speak up if you believe you have valuable feedback that can help improve the library - every message, comment, issue, and pull request (PR) is read and considered.
## Overview
You can contribute in many ways ranging from answering questions on issues to adding new diffusion models to
the core library.
In the following, we give an overview of different ways to contribute, ranked by difficulty in ascending order. All of them are valuable to the community.
* 1. Asking and answering questions on [the Diffusers discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers) or on [Discord](https://discord.gg/G7tWnz98XR).
* 2. Opening new issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues/new/choose).
* 3. Answering issues on [the GitHub Issues tab](https://github.com/huggingface/diffusers/issues).
* 4. Fix a simple issue, marked by the "Good first issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
* 5. Contribute to the [documentation](https://github.com/huggingface/diffusers/tree/main/docs/source).
* 6. Contribute a [Community Pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3Acommunity-examples).
* 7. Contribute to the [examples](https://github.com/huggingface/diffusers/tree/main/examples).
* 8. Fix a more difficult issue, marked by the "Good second issue" label, see [here](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22).
* 9. Add a new pipeline, model, or scheduler, see ["New Pipeline/Model"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22) and ["New scheduler"](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22) issues. For this contribution, please have a look at [Design Philosophy](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md).
As said before, **all contributions are valuable to the community**.
In the following, we will explain each contribution a bit more in detail.
For all contributions 4-9, you will need to open a PR. It is explained in detail how to do so in [Opening a pull request](#how-to-open-a-pr).
### 1. Asking and answering questions on the Diffusers discussion forum or on the Diffusers Discord
Any question or comment related to the Diffusers library can be asked on the [discussion forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/) or on [Discord](https://discord.gg/G7tWnz98XR). Such questions and comments include (but are not limited to):
- Reports of training or inference experiments in an attempt to share knowledge
- Presentation of personal projects
- Questions to non-official training examples
- Project proposals
- General feedback
- Paper summaries
- Asking for help on personal projects that build on top of the Diffusers library
- General questions
- Ethical questions regarding diffusion models
- ...
Every question that is asked on the forum or on Discord actively encourages the community to publicly
share knowledge and might very well help a beginner in the future who has the same question you're
having. Please do pose any questions you might have.
In the same spirit, you are of immense help to the community by answering such questions because this way you are publicly documenting knowledge for everybody to learn from.
**Please** keep in mind that the more effort you put into asking or answering a question, the higher
the quality of the publicly documented knowledge. In the same way, well-posed and well-answered questions create a high-quality knowledge database accessible to everybody, while badly posed questions or answers reduce the overall quality of the public knowledge database.
In short, a high quality question or answer is *precise*, *concise*, *relevant*, *easy-to-understand*, *accessible*, and *well-formatted/well-posed*. For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
**NOTE about channels**:
[*The forum*](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) is much better indexed by search engines, such as Google. Posts are ranked by popularity rather than chronologically. Hence, it's easier to look up questions and answers that we posted some time ago.
In addition, questions and answers posted in the forum can easily be linked to.
In contrast, *Discord* has a chat-like format that invites fast back-and-forth communication.
While it will most likely take less time for you to get an answer to your question on Discord, your
question won't be visible anymore over time. Also, it's much harder to find information that was posted a while back on Discord. We therefore strongly recommend using the forum for high-quality questions and answers in an attempt to create long-lasting knowledge for the community. If discussions on Discord lead to very interesting answers and conclusions, we recommend posting the results on the forum to make the information more available for future readers.
### 2. Opening new issues on the GitHub issues tab
The 🧨 Diffusers library is robust and reliable thanks to the users who notify us of
the problems they encounter. So thank you for reporting an issue.
Remember, GitHub issues are reserved for technical questions directly related to the Diffusers library, bug reports, feature requests, or feedback on the library design.
In a nutshell, this means that everything that is **not** related to the **code of the Diffusers library** (including the documentation) should **not** be asked on GitHub, but rather on either the [forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
**Please consider the following guidelines when opening a new issue**:
- Make sure you have searched whether your issue has already been asked before (use the search bar on GitHub under Issues).
- Please never report a new issue on another (related) issue. If another issue is highly related, please
open a new issue nevertheless and link to the related issue.
- Make sure your issue is written in English. Please use one of the great, free online translation services, such as [DeepL](https://www.deepl.com/translator) to translate from your native language to English if you are not comfortable in English.
- Check whether your issue might be solved by updating to the newest Diffusers version. Before posting your issue, please make sure that `python -c "import diffusers; print(diffusers.__version__)"` is higher or matches the latest Diffusers version.
- Remember that the more effort you put into opening a new issue, the higher the quality of your answer will be and the better the overall quality of the Diffusers issues.
New issues usually include the following.
#### 2.1. Reproducible, minimal bug reports
A bug report should always have a reproducible code snippet and be as minimal and concise as possible.
This means in more detail:
- Narrow the bug down as much as you can, **do not just dump your whole code file**.
- Format your code.
- Do not include any external libraries except for Diffusers depending on them.
- **Always** provide all necessary information about your environment; for this, you can run: `diffusers-cli env` in your shell and copy-paste the displayed information to the issue.
- Explain the issue. If the reader doesn't know what the issue is and why it is an issue, she cannot solve it.
- **Always** make sure the reader can reproduce your issue with as little effort as possible. If your code snippet cannot be run because of missing libraries or undefined variables, the reader cannot help you. Make sure your reproducible code snippet is as minimal as possible and can be copy-pasted into a simple Python shell.
- If in order to reproduce your issue a model and/or dataset is required, make sure the reader has access to that model or dataset. You can always upload your model or dataset to the [Hub](https://huggingface.co) to make it easily downloadable. Try to keep your model and dataset as small as possible, to make the reproduction of your issue as effortless as possible.
For more information, please have a look through the [How to write a good issue](#how-to-write-a-good-issue) section.
You can open a bug report [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&projects=&template=bug-report.yml).
#### 2.2. Feature requests
A world-class feature request addresses the following points:
1. Motivation first:
* Is it related to a problem/frustration with the library? If so, please explain
why. Providing a code snippet that demonstrates the problem is best.
* Is it related to something you would need for a project? We'd love to hear
about it!
* Is it something you worked on and think could benefit the community?
Awesome! Tell us what problem it solved for you.
2. Write a *full paragraph* describing the feature;
3. Provide a **code snippet** that demonstrates its future use;
4. In case this is related to a paper, please attach a link;
5. Attach any additional information (drawings, screenshots, etc.) you think may help.
You can open a feature request [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feature_request.md&title=).
#### 2.3 Feedback
Feedback about the library design and why it is good or not good helps the core maintainers immensely to build a user-friendly library. To understand the philosophy behind the current design philosophy, please have a look [here](https://huggingface.co/docs/diffusers/conceptual/philosophy). If you feel like a certain design choice does not fit with the current design philosophy, please explain why and how it should be changed. If a certain design choice follows the design philosophy too much, hence restricting use cases, explain why and how it should be changed.
If a certain design choice is very useful for you, please also leave a note as this is great feedback for future design decisions.
You can open an issue about feedback [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=).
#### 2.4 Technical questions
Technical questions are mainly about why certain code of the library was written in a certain way, or what a certain part of the code does. Please make sure to link to the code in question and please provide detail on
why this part of the code is difficult to understand.
You can open an issue about a technical question [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=bug&template=bug-report.yml).
#### 2.5 Proposal to add a new model, scheduler, or pipeline
If the diffusion model community released a new model, pipeline, or scheduler that you would like to see in the Diffusers library, please provide the following information:
* Short description of the diffusion pipeline, model, or scheduler and link to the paper or public release.
* Link to any of its open-source implementation.
* Link to the model weights if they are available.
If you are willing to contribute to the model yourself, let us know so we can best guide you. Also, don't forget
to tag the original author of the component (model, scheduler, pipeline, etc.) by GitHub handle if you can find it.
You can open a request for a model/pipeline/scheduler [here](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=New+model%2Fpipeline%2Fscheduler&template=new-model-addition.yml).
### 3. Answering issues on the GitHub issues tab
Answering issues on GitHub might require some technical knowledge of Diffusers, but we encourage everybody to give it a try even if you are not 100% certain that your answer is correct.
Some tips to give a high-quality answer to an issue:
- Be as concise and minimal as possible.
- Stay on topic. An answer to the issue should concern the issue and only the issue.
- Provide links to code, papers, or other sources that prove or encourage your point.
- Answer in code. If a simple code snippet is the answer to the issue or shows how the issue can be solved, please provide a fully reproducible code snippet.
Also, many issues tend to be simply off-topic, duplicates of other issues, or irrelevant. It is of great
help to the maintainers if you can answer such issues, encouraging the author of the issue to be
more precise, provide the link to a duplicated issue or redirect them to [the forum](https://discuss.huggingface.co/c/discussion-related-to-httpsgithubcomhuggingfacediffusers/63) or [Discord](https://discord.gg/G7tWnz98XR).
If you have verified that the issued bug report is correct and requires a correction in the source code,
please have a look at the next sections.
For all of the following contributions, you will need to open a PR. It is explained in detail how to do so in the [Opening a pull request](#how-to-open-a-pr) section.
### 4. Fixing a "Good first issue"
*Good first issues* are marked by the [Good first issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22) label. Usually, the issue already
explains how a potential solution should look so that it is easier to fix.
If the issue hasn't been closed and you would like to try to fix this issue, you can just leave a message "I would like to try this issue.". There are usually three scenarios:
- a.) The issue description already proposes a fix. In this case and if the solution makes sense to you, you can open a PR or draft PR to fix it.
- b.) The issue description does not propose a fix. In this case, you can ask what a proposed fix could look like and someone from the Diffusers team should answer shortly. If you have a good idea of how to fix it, feel free to directly open a PR.
- c.) There is already an open PR to fix the issue, but the issue hasn't been closed yet. If the PR has gone stale, you can simply open a new PR and link to the stale PR. PRs often go stale if the original contributor who wanted to fix the issue suddenly cannot find the time anymore to proceed. This often happens in open-source and is very normal. In this case, the community will be very happy if you give it a new try and leverage the knowledge of the existing PR. If there is already a PR and it is active, you can help the author by giving suggestions, reviewing the PR or even asking whether you can contribute to the PR.
### 5. Contribute to the documentation
A good library **always** has good documentation! The official documentation is often one of the first points of contact for new users of the library, and therefore contributing to the documentation is a **highly
valuable contribution**.
Contributing to the library can have many forms:
- Correcting spelling or grammatical errors.
- Correct incorrect formatting of the docstring. If you see that the official documentation is weirdly displayed or a link is broken, we are very happy if you take some time to correct it.
- Correct the shape or dimensions of a docstring input or output tensor.
- Clarify documentation that is hard to understand or incorrect.
- Update outdated code examples.
- Translating the documentation to another language.
Anything displayed on [the official Diffusers doc page](https://huggingface.co/docs/diffusers/index) is part of the official documentation and can be corrected, adjusted in the respective [documentation source](https://github.com/huggingface/diffusers/tree/main/docs/source).
Please have a look at [this page](https://github.com/huggingface/diffusers/tree/main/docs) on how to verify changes made to the documentation locally.
### 6. Contribute a community pipeline
[Pipelines](https://huggingface.co/docs/diffusers/api/pipelines/overview) are usually the first point of contact between the Diffusers library and the user.
Pipelines are examples of how to use Diffusers [models](https://huggingface.co/docs/diffusers/api/models/overview) and [schedulers](https://huggingface.co/docs/diffusers/api/schedulers/overview).
We support two types of pipelines:
- Official Pipelines
- Community Pipelines
Both official and community pipelines follow the same design and consist of the same type of components.
Official pipelines are tested and maintained by the core maintainers of Diffusers. Their code
resides in [src/diffusers/pipelines](https://github.com/huggingface/diffusers/tree/main/src/diffusers/pipelines).
In contrast, community pipelines are contributed and maintained purely by the **community** and are **not** tested.
They reside in [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and while they can be accessed via the [PyPI diffusers package](https://pypi.org/project/diffusers/), their code is not part of the PyPI distribution.
The reason for the distinction is that the core maintainers of the Diffusers library cannot maintain and test all
possible ways diffusion models can be used for inference, but some of them may be of interest to the community.
Officially released diffusion pipelines,
such as Stable Diffusion are added to the core src/diffusers/pipelines package which ensures
high quality of maintenance, no backward-breaking code changes, and testing.
More bleeding edge pipelines should be added as community pipelines. If usage for a community pipeline is high, the pipeline can be moved to the official pipelines upon request from the community. This is one of the ways we strive to be a community-driven library.
To add a community pipeline, one should add a <name-of-the-community>.py file to [examples/community](https://github.com/huggingface/diffusers/tree/main/examples/community) and adapt the [examples/community/README.md](https://github.com/huggingface/diffusers/tree/main/examples/community/README.md) to include an example of the new pipeline.
An example can be seen [here](https://github.com/huggingface/diffusers/pull/2400).
Community pipeline PRs are only checked at a superficial level and ideally they should be maintained by their original authors.
Contributing a community pipeline is a great way to understand how Diffusers models and schedulers work. Having contributed a community pipeline is usually the first stepping stone to contributing an official pipeline to the
core package.
### 7. Contribute to training examples
Diffusers examples are a collection of training scripts that reside in [examples](https://github.com/huggingface/diffusers/tree/main/examples).
We support two types of training examples:
- Official training examples
- Research training examples
Research training examples are located in [examples/research_projects](https://github.com/huggingface/diffusers/tree/main/examples/research_projects) whereas official training examples include all folders under [examples](https://github.com/huggingface/diffusers/tree/main/examples) except the `research_projects` and `community` folders.
The official training examples are maintained by the Diffusers' core maintainers whereas the research training examples are maintained by the community.
This is because of the same reasons put forward in [6. Contribute a community pipeline](#6-contribute-a-community-pipeline) for official pipelines vs. community pipelines: It is not feasible for the core maintainers to maintain all possible training methods for diffusion models.
If the Diffusers core maintainers and the community consider a certain training paradigm to be too experimental or not popular enough, the corresponding training code should be put in the `research_projects` folder and maintained by the author.
Both official training and research examples consist of a directory that contains one or more training scripts, a `requirements.txt` file, and a `README.md` file. In order for the user to make use of the
training examples, it is required to clone the repository:
```bash
git clone https://github.com/huggingface/diffusers
```
as well as to install all additional dependencies required for training:
```bash
cd diffusers
pip install -r examples/<your-example-folder>/requirements.txt
```
Therefore when adding an example, the `requirements.txt` file shall define all pip dependencies required for your training example so that once all those are installed, the user can run the example's training script. See, for example, the [DreamBooth `requirements.txt` file](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/requirements.txt).
Training examples of the Diffusers library should adhere to the following philosophy:
- All the code necessary to run the examples should be found in a single Python file.
- One should be able to run the example from the command line with `python <your-example>.py --args`.
- Examples should be kept simple and serve as **an example** on how to use Diffusers for training. The purpose of example scripts is **not** to create state-of-the-art diffusion models, but rather to reproduce known training schemes without adding too much custom logic. As a byproduct of this point, our examples also strive to serve as good educational materials.
To contribute an example, it is highly recommended to look at already existing examples such as [dreambooth](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth.py) to get an idea of how they should look like.
We strongly advise contributors to make use of the [Accelerate library](https://github.com/huggingface/accelerate) as it's tightly integrated
with Diffusers.
Once an example script works, please make sure to add a comprehensive `README.md` that states how to use the example exactly. This README should include:
- An example command on how to run the example script as shown [here e.g.](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#running-locally-with-pytorch).
- A link to some training results (logs, models, ...) that show what the user can expect as shown [here e.g.](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
- If you are adding a non-official/research training example, **please don't forget** to add a sentence that you are maintaining this training example which includes your git handle as shown [here](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/intel_opts#diffusers-examples-with-intel-optimizations).
If you are contributing to the official training examples, please also make sure to add a test to [examples/test_examples.py](https://github.com/huggingface/diffusers/blob/main/examples/test_examples.py). This is not necessary for non-official training examples.
### 8. Fixing a "Good second issue"
*Good second issues* are marked by the [Good second issue](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22Good+second+issue%22) label. Good second issues are
usually more complicated to solve than [Good first issues](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22good+first+issue%22).
The issue description usually gives less guidance on how to fix the issue and requires
a decent understanding of the library by the interested contributor.
If you are interested in tackling a good second issue, feel free to open a PR to fix it and link the PR to the issue. If you see that a PR has already been opened for this issue but did not get merged, have a look to understand why it wasn't merged and try to open an improved PR.
Good second issues are usually more difficult to get merged compared to good first issues, so don't hesitate to ask for help from the core maintainers. If your PR is almost finished the core maintainers can also jump into your PR and commit to it in order to get it merged.
### 9. Adding pipelines, models, schedulers
Pipelines, models, and schedulers are the most important pieces of the Diffusers library.
They provide easy access to state-of-the-art diffusion technologies and thus allow the community to
build powerful generative AI applications.
By adding a new model, pipeline, or scheduler you might enable a new powerful use case for any of the user interfaces relying on Diffusers which can be of immense value for the whole generative AI ecosystem.
Diffusers has a couple of open feature requests for all three components - feel free to gloss over them
if you don't know yet what specific component you would like to add:
- [Model or pipeline](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+pipeline%2Fmodel%22)
- [Scheduler](https://github.com/huggingface/diffusers/issues?q=is%3Aopen+is%3Aissue+label%3A%22New+scheduler%22)
Before adding any of the three components, it is strongly recommended that you give the [Philosophy guide](https://github.com/huggingface/diffusers/blob/main/PHILOSOPHY.md) a read to better understand the design of any of the three components. Please be aware that
we cannot merge model, scheduler, or pipeline additions that strongly diverge from our design philosophy
as it will lead to API inconsistencies. If you fundamentally disagree with a design choice, please
open a [Feedback issue](https://github.com/huggingface/diffusers/issues/new?assignees=&labels=&template=feedback.md&title=) instead so that it can be discussed whether a certain design
pattern/design choice shall be changed everywhere in the library and whether we shall update our design philosophy. Consistency across the library is very important for us.
Please make sure to add links to the original codebase/paper to the PR and ideally also ping the
original author directly on the PR so that they can follow the progress and potentially help with questions.
If you are unsure or stuck in the PR, don't hesitate to leave a message to ask for a first review or help.
## How to write a good issue
**The better your issue is written, the higher the chances that it will be quickly resolved.**
1. Make sure that you've used the correct template for your issue. You can pick between *Bug Report*, *Feature Request*, *Feedback about API Design*, *New model/pipeline/scheduler addition*, *Forum*, or a blank issue. Make sure to pick the correct one when opening [a new issue](https://github.com/huggingface/diffusers/issues/new/choose).
2. **Be precise**: Give your issue a fitting title. Try to formulate your issue description as simple as possible. The more precise you are when submitting an issue, the less time it takes to understand the issue and potentially solve it. Make sure to open an issue for one issue only and not for multiple issues. If you found multiple issues, simply open multiple issues. If your issue is a bug, try to be as precise as possible about what bug it is - you should not just write "Error in diffusers".
3. **Reproducibility**: No reproducible code snippet == no solution. If you encounter a bug, maintainers **have to be able to reproduce** it. Make sure that you include a code snippet that can be copy-pasted into a Python interpreter to reproduce the issue. Make sure that your code snippet works, *i.e.* that there are no missing imports or missing links to images, ... Your issue should contain an error message **and** a code snippet that can be copy-pasted without any changes to reproduce the exact same error message. If your issue is using local model weights or local data that cannot be accessed by the reader, the issue cannot be solved. If you cannot share your data or model, try to make a dummy model or dummy data.
4. **Minimalistic**: Try to help the reader as much as you can to understand the issue as quickly as possible by staying as concise as possible. Remove all code / all information that is irrelevant to the issue. If you have found a bug, try to create the easiest code example you can to demonstrate your issue, do not just dump your whole workflow into the issue as soon as you have found a bug. E.g., if you train a model and get an error at some point during the training, you should first try to understand what part of the training code is responsible for the error and try to reproduce it with a couple of lines. Try to use dummy data instead of full datasets.
5. Add links. If you are referring to a certain naming, method, or model make sure to provide a link so that the reader can better understand what you mean. If you are referring to a specific PR or issue, make sure to link it to your issue. Do not assume that the reader knows what you are talking about. The more links you add to your issue the better.
6. Formatting. Make sure to nicely format your issue by formatting code into Python code syntax, and error messages into normal code syntax. See the [official GitHub formatting docs](https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax) for more information.
7. Think of your issue not as a ticket to be solved, but rather as a beautiful entry to a well-written encyclopedia. Every added issue is a contribution to publicly available knowledge. By adding a nicely written issue you not only make it easier for maintainers to solve your issue, but you are helping the whole community to better understand a certain aspect of the library.
## How to write a good PR
1. Be a chameleon. Understand existing design patterns and syntax and make sure your code additions flow seamlessly into the existing code base. Pull requests that significantly diverge from existing design patterns or user interfaces will not be merged.
2. Be laser focused. A pull request should solve one problem and one problem only. Make sure to not fall into the trap of "also fixing another problem while we're adding it". It is much more difficult to review pull requests that solve multiple, unrelated problems at once.
3. If helpful, try to add a code snippet that displays an example of how your addition can be used.
4. The title of your pull request should be a summary of its contribution.
5. If your pull request addresses an issue, please mention the issue number in
the pull request description to make sure they are linked (and people
consulting the issue know you are working on it);
6. To indicate a work in progress please prefix the title with `[WIP]`. These
are useful to avoid duplicated work, and to differentiate it from PRs ready
to be merged;
7. Try to formulate and format your text as explained in [How to write a good issue](#how-to-write-a-good-issue).
8. Make sure existing tests pass;
9. Add high-coverage tests. No quality testing = no merge.
- If you are adding new `@slow` tests, make sure they pass using
`RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
CircleCI does not run the slow tests, but GitHub Actions does every night!
10. All public methods must have informative docstrings that work nicely with markdown. See [`pipeline_latent_diffusion.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py) for an example.
11. Due to the rapidly growing repository, it is important to make sure that no files that would significantly weigh down the repository are added. This includes images, videos, and other non-text files. We prefer to leverage a hf.co hosted `dataset` like
[`hf-internal-testing`](https://huggingface.co/hf-internal-testing) or [huggingface/documentation-images](https://huggingface.co/datasets/huggingface/documentation-images) to place these files.
If an external contribution, feel free to add the images to your PR and ask a Hugging Face member to migrate your images
to this dataset.
## How to open a PR
Before writing code, we strongly advise you to search through the existing PRs or
issues to make sure that nobody is already working on the same thing. If you are
unsure, it is always a good idea to open an issue to get some feedback.
You will need basic `git` proficiency to be able to contribute to
🧨 Diffusers. `git` is not the easiest tool to use but it has the greatest
manual. Type `git --help` in a shell and enjoy. If you prefer books, [Pro
Git](https://git-scm.com/book/en/v2) is a very good reference.
Follow these steps to start contributing ([supported Python versions](https://github.com/huggingface/diffusers/blob/42f25d601a910dceadaee6c44345896b4cfa9928/setup.py#L270)):
1. Fork the [repository](https://github.com/huggingface/diffusers) by
clicking on the 'Fork' button on the repository's page. This creates a copy of the code
under your GitHub user account.
2. Clone your fork to your local disk, and add the base repository as a remote:
```bash
$ git clone git@github.com:<your GitHub handle>/diffusers.git
$ cd diffusers
$ git remote add upstream https://github.com/huggingface/diffusers.git
```
3. Create a new branch to hold your development changes:
```bash
$ git checkout -b a-descriptive-name-for-my-changes
```
**Do not** work on the `main` branch.
4. Set up a development environment by running the following command in a virtual environment:
```bash
$ pip install -e ".[dev]"
```
If you have already cloned the repo, you might need to `git pull` to get the most recent changes in the
library.
5. Develop the features on your branch.
As you work on the features, you should make sure that the test suite
passes. You should run the tests impacted by your changes like this:
```bash
$ pytest tests/<TEST_TO_RUN>.py
```
Before you run the tests, please make sure you install the dependencies required for testing. You can do so
with this command:
```bash
$ pip install -e ".[test]"
```
You can also run the full test suite with the following command, but it takes
a beefy machine to produce a result in a decent amount of time now that
Diffusers has grown a lot. Here is the command for it:
```bash
$ make test
```
🧨 Diffusers relies on `ruff` and `isort` to format its source code
consistently. After you make changes, apply automatic style corrections and code verifications
that can't be automated in one go with:
```bash
$ make style
```
🧨 Diffusers also uses `ruff` and a few custom scripts to check for coding mistakes. Quality
control runs in CI, however, you can also run the same checks with:
```bash
$ make quality
```
Once you're happy with your changes, add changed files using `git add` and
make a commit with `git commit` to record your changes locally:
```bash
$ git add modified_file.py
$ git commit -m "A descriptive message about your changes."
```
It is a good idea to sync your copy of the code with the original
repository regularly. This way you can quickly account for changes:
```bash
$ git pull upstream main
```
Push the changes to your account using:
```bash
$ git push -u origin a-descriptive-name-for-my-changes
```
6. Once you are satisfied, go to the
webpage of your fork on GitHub. Click on 'Pull request' to send your changes
to the project maintainers for review.
7. It's ok if maintainers ask you for changes. It happens to core contributors
too! So everyone can see the changes in the Pull request, work in your local
branch and push the changes to your fork. They will automatically appear in
the pull request.
### Tests
An extensive test suite is included to test the library behavior and several examples. Library tests can be found in
the [tests folder](https://github.com/huggingface/diffusers/tree/main/tests).
We like `pytest` and `pytest-xdist` because it's faster. From the root of the
repository, here's how to run tests with `pytest` for the library:
```bash
$ python -m pytest -n auto --dist=loadfile -s -v ./tests/
```
In fact, that's how `make test` is implemented!
You can specify a smaller set of tests in order to test only the feature
you're working on.
By default, slow tests are skipped. Set the `RUN_SLOW` environment variable to
`yes` to run them. This will download many gigabytes of models — make sure you
have enough disk space and a good Internet connection, or a lot of patience!
```bash
$ RUN_SLOW=yes python -m pytest -n auto --dist=loadfile -s -v ./tests/
```
`unittest` is fully supported, here's how to run tests with it:
```bash
$ python -m unittest discover -s tests -t . -v
$ python -m unittest discover -s examples -t examples -v
```
### Syncing forked main with upstream (HuggingFace) main
To avoid pinging the upstream repository which adds reference notes to each upstream PR and sends unnecessary notifications to the developers involved in these PRs,
when syncing the main branch of a forked repository, please, follow these steps:
1. When possible, avoid syncing with the upstream using a branch and PR on the forked repository. Instead, merge directly into the forked main.
2. If a PR is absolutely necessary, use the following steps after checking out your branch:
```bash
$ git checkout -b your-branch-for-syncing
$ git pull --squash --no-commit upstream main
$ git commit -m '<your message without GitHub references>'
$ git push --set-upstream origin your-branch-for-syncing
```
### Style guide
For documentation strings, 🧨 Diffusers follows the [Google style](https://google.github.io/styleguide/pyguide.html).

View File

@@ -70,10 +70,6 @@ fix-copies:
python utils/check_copies.py --fix_and_overwrite
python utils/check_dummies.py --fix_and_overwrite
# Auto docstrings in modular blocks
modular-autodoctrings:
python utils/modular_auto_docstring.py
# Run tests for the library
test:

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.11
ARG PYTHON_VERSION=3.12
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,12 +32,10 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
torchaudio
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -2,7 +2,7 @@ FROM nvidia/cuda:12.1.0-runtime-ubuntu20.04
LABEL maintainer="Hugging Face"
LABEL repository="diffusers"
ARG PYTHON_VERSION=3.11
ARG PYTHON_VERSION=3.12
ENV DEBIAN_FRONTEND=noninteractive
RUN apt-get -y update \
@@ -32,12 +32,10 @@ RUN uv venv --python ${PYTHON_VERSION} --seed ${VIRTUAL_ENV}
ENV PATH="$VIRTUAL_ENV/bin:$PATH"
# pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
# Install torch, torchvision, and torchaudio together to ensure compatibility
RUN uv pip install --no-cache-dir \
torch \
torchvision \
torchaudio \
--index-url https://download.pytorch.org/whl/cu121
torchaudio
RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.git@main#egg=diffusers[test]"

View File

@@ -54,8 +54,6 @@
title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- local: hybrid_inference/overview
title: Remote inference
title: Inference
- isExpanded: false
sections:
@@ -90,6 +88,17 @@
title: FreeU
title: Community optimizations
title: Inference optimization
- isExpanded: false
sections:
- local: hybrid_inference/overview
title: Overview
- local: hybrid_inference/vae_decode
title: VAE Decode
- local: hybrid_inference/vae_encode
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference
- isExpanded: false
sections:
- local: modular_diffusers/overview
@@ -261,8 +270,6 @@
title: Outputs
- local: api/quantization
title: Quantization
- local: hybrid_inference/api_reference
title: Remote inference
- local: api/parallel
title: Parallel inference
title: Main Classes
@@ -346,8 +353,6 @@
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -360,10 +365,6 @@
title: HunyuanVideoTransformer3DModel
- local: api/models/latte_transformer3d
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -401,7 +402,7 @@
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
- local: api/models/z_image_transformer2d
title: ZImageTransformer2DModel
title: ZImageTransformer2DModel
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
@@ -440,10 +441,6 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -496,8 +493,6 @@
title: Bria 3.2
- local: api/pipelines/bria_fibo
title: Bria Fibo
- local: api/pipelines/bria_fibo_edit
title: Bria Fibo Edit
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
@@ -544,8 +539,6 @@
title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/glm_image
title: GLM-Image
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
@@ -570,8 +563,6 @@
title: Latent Diffusion
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/longcat_image
title: LongCat-Image
- local: api/pipelines/lumina2
title: Lumina 2.0
- local: api/pipelines/lumina
@@ -683,8 +674,6 @@
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx2
title: LTX-2
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi

View File

@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] apply_faster_cache
## FirstBlockCacheConfig
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig

View File

@@ -33,7 +33,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LTX2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
> [!TIP]
@@ -63,10 +62,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
## LTX2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin

View File

@@ -1,29 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Audio
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTX2Audio
vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Audio
[[autodoc]] AutoencoderKLLTX2Audio
- encode
- decode
- all

View File

@@ -1,29 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Video
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import AutoencoderKLLTX2Video
vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Video
[[autodoc]] AutoencoderKLLTX2Video
- decode
- encode
- all

View File

@@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co
## FluxControlNetOutput
[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput
[[autodoc]] models.controlnet_flux.FluxControlNetOutput

View File

@@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr
## SparseControlNetOutput
[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput
[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput

View File

@@ -1,18 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# GlmImageTransformer2DModel
A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel] (TODO).
## GlmImageTransformer2DModel
[[autodoc]] GlmImageTransformer2DModel

View File

@@ -1,25 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LongCatImageTransformer2DModel
The model can be loaded with the following code snippet.
```python
from diffusers import LongCatImageTransformer2DModel
transformer = LongCatImageTransformer2DModel.from_pretrained("meituan-longcat/LongCat-Image ", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## LongCatImageTransformer2DModel
[[autodoc]] LongCatImageTransformer2DModel

View File

@@ -1,26 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# LTX2VideoTransformer3DModel
A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
from diffusers import LTX2VideoTransformer3DModel
transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTX2VideoTransformer3DModel
[[autodoc]] LTX2VideoTransformer3DModel

View File

@@ -1,33 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Bria Fibo Edit
Fibo Edit is an 8B parameter image-to-image model that introduces a new paradigm of structured control, operating on JSON inputs paired with source images to enable deterministic and repeatable editing workflows.
Featuring native masking for granular precision, it moves beyond simple prompt-based diffusion to offer explicit, interpretable control optimized for production environments.
Its lightweight architecture is designed for deep customization, empowering researchers to build specialized "Edit" models for domain-specific tasks while delivering top-tier aesthetic quality
## Usage
_As the model is gated, before using it with diffusers you first need to go to the [Bria Fibo Hugging Face page](https://huggingface.co/briaai/Fibo-Edit), fill in the form and accept the gate. Once you are in, you need to login so that your system knows youve accepted the gate._
Use the command below to log in:
```bash
hf auth login
```
## BriaFiboEditPipeline
[[autodoc]] BriaFiboEditPipeline
- all
- __call__

View File

@@ -99,9 +99,3 @@ image.save("chroma-single-file.png")
[[autodoc]] ChromaImg2ImgPipeline
- all
- __call__
## ChromaInpaintPipeline
[[autodoc]] ChromaInpaintPipeline
- all
- __call__

View File

@@ -30,10 +30,6 @@
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
Available Models/LoRAs:
- [nvidia/ChronoEdit-14B-Diffusers](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers)
- [nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora)
- [nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora)
### Image Editing
@@ -104,7 +100,6 @@ Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.pn
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
@@ -114,8 +109,9 @@ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encod
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
pipe.fuse_lora(adapter_names=["distill"], lora_scale=1.0)
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")
@@ -149,57 +145,6 @@ export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
### Inference with Multiple LoRAs
```py
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora", weight_name="paintbrush_lora_diffusers.safetensors", adapter_name="paintbrush")
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
pipe.fuse_lora(adapter_names=["paintbrush", "distill"], lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")
image = load_image(
"https://raw.githubusercontent.com/nv-tlabs/ChronoEdit/refs/heads/main/assets/images/input_paintbrush.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
"Turn the pencil sketch in the image into an actual object that is consistent with the images content. The user wants to change the sketch to a crown and a hat."
)
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=5,
num_inference_steps=8,
guidance_scale=1.0,
enable_temporal_reasoning=False,
num_temporal_reasoning_steps=0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output_1.png")
```
## ChronoEditPipeline
[[autodoc]] ChronoEditPipeline

View File

@@ -70,12 +70,6 @@ output.save("output.png")
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

View File

@@ -21,7 +21,7 @@ The abstract from the paper is:
*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*
The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).
The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️

View File

@@ -35,11 +35,5 @@ The [official implementation](https://github.com/black-forest-labs/flux2/blob/5a
## Flux2Pipeline
[[autodoc]] Flux2Pipeline
- all
- __call__
## Flux2KleinPipeline
[[autodoc]] Flux2KleinPipeline
- all
- __call__

View File

@@ -1,95 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# GLM-Image
## Overview
GLM-Image is an image generation model adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained details. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.
Model architecture: a hybrid autoregressive + diffusion decoder design、
+ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K4K tokens, corresponding to 1K2K high-resolution image outputs. You can check AR model in class `GlmImageForConditionalGeneration` of `transformers` library.
+ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images.
Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.
+ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
+ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.
GLM-Image supports both text-to-image and image-to-image generation within a single model
+ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
+ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image).
## Usage examples
### Text to Image Generation
```python
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy."
image = pipe(
prompt=prompt,
height=32 * 32,
width=36 * 32,
num_inference_steps=30,
guidance_scale=1.5,
generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
image.save("output_t2i.png")
```
### Image to Image Generation
```python
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline
from PIL import Image
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
image_path = "cond.jpg"
prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator."
image = Image.open(image_path).convert("RGB")
image = pipe(
prompt=prompt,
image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1]
height=33 * 32,
width=32 * 32,
num_inference_steps=30,
guidance_scale=1.5,
generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
image.save("output_i2i.png")
```
+ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model.
## GlmImagePipeline
[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline
- all
- __call__
## GlmImagePipelineOutput
[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput

View File

@@ -1,114 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LongCat-Image
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
We introduce LongCat-Image, a pioneering open-source and bilingual (Chinese-English) foundation model for image generation, designed to address core challenges in multilingual text rendering, photorealism, deployment efficiency, and developer accessibility prevalent in current leading models.
### Key Features
- 🌟 **Exceptional Efficiency and Performance**: With only **6B parameters**, LongCat-Image surpasses numerous open-source models that are several times larger across multiple benchmarks, demonstrating the immense potential of efficient model design.
- 🌟 **Superior Editing Performance**: LongCat-Image-Edit model achieves state-of-the-art performance among open-source models, delivering leading instruction-following and image quality with superior visual consistency.
- 🌟 **Powerful Chinese Text Rendering**: LongCat-Image demonstrates superior accuracy and stability in rendering common Chinese characters compared to existing SOTA open-source models and achieves industry-leading coverage of the Chinese dictionary.
- 🌟 **Remarkable Photorealism**: Through an innovative data strategy and training framework, LongCat-Image achieves remarkable photorealism in generated images.
- 🌟 **Comprehensive Open-Source Ecosystem**: We provide a complete toolchain, from intermediate checkpoints to full training code, significantly lowering the barrier for further research and development.
For more details, please refer to the comprehensive [***LongCat-Image Technical Report***](https://arxiv.org/abs/2412.11963)
## Usage Example
```py
import torch
import diffusers
from diffusers import LongCatImagePipeline
weight_dtype = torch.bfloat16
pipe = LongCatImagePipeline.from_pretrained("meituan-longcat/LongCat-Image", torch_dtype=torch.bfloat16 )
pipe.to('cuda')
# pipe.enable_model_cpu_offload()
prompt = '一个年轻的亚裔女性,身穿黄色针织衫,搭配白色项链。她的双手放在膝盖上,表情恬静。背景是一堵粗糙的砖墙,午后的阳光温暖地洒在她身上,营造出一种宁静而温馨的氛围。镜头采用中距离视角,突出她的神态和服饰的细节。光线柔和地打在她的脸上,强调她的五官和饰品的质感,增加画面的层次感与亲和力。整个画面构图简洁,砖墙的纹理与阳光的光影效果相得益彰,突显出人物的优雅与从容。'
image = pipe(
prompt,
height=768,
width=1344,
guidance_scale=4.0,
num_inference_steps=50,
num_images_per_prompt=1,
generator=torch.Generator("cpu").manual_seed(43),
enable_cfg_renorm=True,
enable_prompt_rewrite=True,
).images[0]
image.save(f'./longcat_image_t2i_example.png')
```
This pipeline was contributed by LongCat-Image Team. The original codebase can be found [here](https://github.com/meituan-longcat/LongCat-Image).
Available models:
<div style="overflow-x: auto; margin-bottom: 16px;">
<table style="border-collapse: collapse; width: 100%;">
<thead>
<tr>
<th style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Models</th>
<th style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Type</th>
<th style="padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Description</th>
<th style="padding: 8px; border: 1px solid #d0d7de; background-color: #f6f8fa;">Download Link</th>
</tr>
</thead>
<tbody>
<tr>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">LongCat&#8209;Image</td>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">Text&#8209;to&#8209;Image</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">Final Release. The standard model for out&#8209;of&#8209;the&#8209;box inference.</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">
<span style="white-space: nowrap;">🤗&nbsp;<a href="https://huggingface.co/meituan-longcat/LongCat-Image">Huggingface</a></span>
</td>
</tr>
<tr>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">LongCat&#8209;Image&#8209;Dev</td>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">Text&#8209;to&#8209;Image</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">Development. Mid-training checkpoint, suitable for fine-tuning.</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">
<span style="white-space: nowrap;">🤗&nbsp;<a href="https://huggingface.co/meituan-longcat/LongCat-Image-Dev">Huggingface</a></span>
</td>
</tr>
<tr>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">LongCat&#8209;Image&#8209;Edit</td>
<td style="white-space: nowrap; padding: 8px; border: 1px solid #d0d7de;">Image Editing</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">Specialized model for image editing.</td>
<td style="padding: 8px; border: 1px solid #d0d7de;">
<span style="white-space: nowrap;">🤗&nbsp;<a href="https://huggingface.co/meituan-longcat/LongCat-Image-Edit">Huggingface</a></span>
</td>
</tr>
</tbody>
</table>
</div>
## LongCatImagePipeline
[[autodoc]] LongCatImagePipeline
- all
- __call__
## LongCatImagePipelineOutput
[[autodoc]] pipelines.longcat_image.pipeline_output.LongCatImagePipelineOutput

View File

@@ -1,47 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX-2
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
## LTX2Pipeline
[[autodoc]] LTX2Pipeline
- all
- __call__
## LTX2ImageToVideoPipeline
[[autodoc]] LTX2ImageToVideoPipeline
- all
- __call__
## LTX2LatentUpsamplePipeline
[[autodoc]] LTX2LatentUpsamplePipeline
- all
- __call__
## LTX2PipelineOutput
[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput

View File

@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)
<details>
<summary>Show example code</summary>
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
@@ -474,12 +474,6 @@ export_to_video(video, "output.mp4", fps=24)
</details>
## LTXI2VLongMultiPromptPipeline
[[autodoc]] LTXI2VLongMultiPromptPipeline
- all
- __call__
## LTXPipeline
[[autodoc]] LTXPipeline

View File

@@ -95,7 +95,7 @@ image.save("qwen_fewsteps.png")
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
```py
```
import torch
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
@@ -108,46 +108,12 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
num_inference_steps=50
).images[0]
```
## Performance
### torch.compile
Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s):
```python
import torch
from diffusers import QwenImagePipeline
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)
# First call triggers compilation (~7s overhead)
# Subsequent calls run at ~2.4x faster
image = pipe("a cat", num_inference_steps=50).images[0]
```
### Batched Inference with Variable-Length Prompts
When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.
```python
# CFG with different prompt lengths works correctly
image = pipe(
prompt="A cat",
negative_prompt="blurry, low quality, distorted",
true_cfg_scale=3.5,
num_inference_steps=50,
).images[0]
```
For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).
## QwenImagePipeline
[[autodoc]] QwenImagePipeline

View File

@@ -37,8 +37,7 @@ The following SkyReels-V2 models are supported in Diffusers:
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
> [!TIP]
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.

View File

@@ -250,6 +250,9 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
</hfoption>
</hfoptions>
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.

View File

@@ -1,11 +1,9 @@
# Remote inference
# Hybrid Inference API Reference
Remote inference provides access to an [Inference Endpoint](https://huggingface.co/docs/inference-endpoints/index) to offload local generation requirements for decoding and encoding.
## remote_decode
## Remote Decode
[[autodoc]] utils.remote_utils.remote_decode
## remote_encode
## Remote Encode
[[autodoc]] utils.remote_utils.remote_encode

View File

@@ -10,296 +10,51 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Remote inference
# Hybrid Inference
**Empowering local AI builders with Hybrid Inference**
> [!TIP]
> This is currently an experimental feature, and if you have any feedback, please feel free to leave it [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
Remote inference offloads the decoding and encoding process to a remote endpoint to relax the memory requirements for local inference with large models. This feature is powered by [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index). Refer to the table below for the supported models and endpoint.
| Model | Endpoint | Checkpoint | Support |
|---|---|---|---|
| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) | encode/decode |
| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) | encode/decode |
| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | encode/decode |
| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | decode |
This guide will show you how to encode and decode latents with remote inference.
## Encoding
Encoding converts images and videos into latent representations. Refer to the table below for the supported VAEs.
Pass an image to [`~utils.remote_encode`] to encode it. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
```py
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_encode
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.float16,
vae=None,
device_map="cuda"
)
init_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
init_image = init_image.resize((768, 512))
init_latent = remote_encode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
image=init_image,
scaling_factor=0.3611,
shift_factor=0.1159
)
```
## Decoding
Decoding converts latent representations back into images or videos. Refer to the table below for the available and supported VAEs.
Set the output type to `"latent"` in the pipeline and set the `vae` to `None`. Pass the latents to the [`~utils.remote_decode`] function. For Flux, the latents are packed so the `height` and `width` also need to be passed. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
<hfoptions id="decode">
<hfoption id="Flux">
```py
from diffusers import FluxPipeline
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
vae=None,
device_map="cuda"
)
prompt = """
A photorealistic Apollo-era photograph of a cat in a small astronaut suit with a bubble helmet, standing on the Moon and holding a flagpole planted in the dusty lunar soil. The flag shows a colorful paw-print emblem. Earth glows in the black sky above the stark gray surface, with sharp shadows and high-contrast lighting like vintage NASA photos.
"""
latent = pipeline(
prompt=prompt,
guidance_scale=0.0,
num_inference_steps=4,
output_type="latent",
).images
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
image.save("image.jpg")
```
</hfoption>
<hfoption id="HunyuanVideo">
```py
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
)
pipeline = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16, device_map="cuda"
)
latent = pipeline(
prompt="A cat walks on the grass, realistic",
height=320,
width=512,
num_frames=61,
num_inference_steps=30,
output_type="latent",
).frames
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
output_type="mp4",
)
if isinstance(video, bytes):
with open("video.mp4", "wb") as f:
f.write(video)
```
</hfoption>
</hfoptions>
## Queuing
Remote inference supports queuing to process multiple generation requests. While the current latent is being decoded, you can queue the next prompt.
```py
import queue
import threading
from IPython.display import display
from diffusers import StableDiffusionXLPipeline
def decode_worker(q: queue.Queue):
while True:
item = q.get()
if item is None:
break
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=item,
scaling_factor=0.13025,
)
display(image)
q.task_done()
q = queue.Queue()
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
thread.start()
def decode(latent: torch.Tensor):
q.put(latent)
prompts = [
"A grainy Apollo-era style photograph of a cat in a snug astronaut suit with a bubble helmet, standing on the lunar surface and gripping a flag with a paw-print emblem. The gray Moon landscape stretches behind it, Earth glowing vividly in the black sky, shadows crisp and high-contrast.",
"A vintage 1960s sci-fi pulp magazine cover illustration of a heroic cat astronaut planting a flag on the Moon. Bold, saturated colors, exaggerated space gear, playful typography floating in the background, Earth painted in bright blues and greens.",
"A hyper-detailed cinematic shot of a cat astronaut on the Moon holding a fluttering flag, fur visible through the helmet glass, lunar dust scattering under its feet. The vastness of space and Earth in the distance create an epic, awe-inspiring tone.",
"A colorful cartoon drawing of a happy cat wearing a chunky, oversized spacesuit, proudly holding a flag with a big paw print on it. The Moons surface is simplified with craters drawn like doodles, and Earth in the sky has a smiling face.",
"A monochrome 1969-style press photo of a “first cat on the Moon” moment. The cat, in a tiny astronaut suit, stands by a planted flag, with grainy textures, scratches, and a blurred Earth in the background, mimicking old archival space photos."
]
> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
pipeline = StableDiffusionXLPipeline.from_pretrained(
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
vae=None,
device_map="cuda"
)
pipeline.unet = pipeline.unet.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
## Why use Hybrid Inference?
_ = pipeline(
prompt=prompts[0],
output_type="latent",
)
Hybrid Inference offers a fast and simple way to offload local generation requirements.
for prompt in prompts:
latent = pipeline(
prompt=prompt,
output_type="latent",
).images
decode(latent)
- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
- 💰 **Cost Effective:** It's free! 🤑
- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
- 🔧 **Developer-Friendly:** Simple requests, fast responses.
q.put(None)
thread.join()
```
---
## Benchmarks
## Available Models
The tables demonstrate the memory requirements for encoding and decoding with Stable Diffusion v1.5 and SDXL on different GPUs.
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
For the majority of these GPUs, the memory usage dictates whether other models (text encoders, UNet/transformer) need to be offloaded or required tiled encoding. The latter two techniques increases inference time and impacts quality.
---
<details><summary>Encoding - Stable Diffusion v1.5</summary>
## Integrations
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
</details>
## Changelog
<details><summary>Encoding SDXL</summary>
- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
## Contents
</details>
The documentation is organized into three sections:
<details><summary>Decoding - Stable Diffusion v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
</details>
<details><summary>Decoding SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
</details>
## Resources
- Remote inference is also supported in [SD.Next](https://github.com/vladmandic/sdnext) and [ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae).
- Refer to the [Remote VAEs for decoding with Inference Endpoints](https://huggingface.co/blog/remote_vae) blog post to learn more.
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.

View File

@@ -0,0 +1,345 @@
# Getting Started: VAE Decode with Hybrid Inference
VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
## Memory
These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality.
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
</details>
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_decode
```
### Basic example
Here, we show how to use the remote VAE on random tensors.
<details><summary>Code</summary>
```python
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
scaling_factor=0.18215,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/output.png"/>
</figure>
Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
<details><summary>Code</summary>
```python
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/flux_random_latent.png"/>
</figure>
Finally, an example for HunyuanVideo.
<details><summary>Code</summary>
```python
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
output_type="mp4",
)
with open("video.mp4", "wb") as f:
f.write(video)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video_1.mp4" type="video/mp4">
</video>
</figure>
### Generation
But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
<details><summary>Code</summary>
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
vae=None,
).to("cuda")
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
latent = pipe(
prompt=prompt,
output_type="latent",
).images
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.18215,
)
image.save("test.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test.jpg"/>
</figure>
Heres another example with Flux.
<details><summary>Code</summary>
```python
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
vae=None,
).to("cuda")
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
latent = pipe(
prompt=prompt,
guidance_scale=0.0,
num_inference_steps=4,
output_type="latent",
).images
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
image.save("test.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test_1.jpg"/>
</figure>
Heres an example with HunyuanVideo.
<details><summary>Code</summary>
```python
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
).to("cuda")
latent = pipe(
prompt="A cat walks on the grass, realistic",
height=320,
width=512,
num_frames=61,
num_inference_steps=30,
output_type="latent",
).frames
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
output_type="mp4",
)
if isinstance(video, bytes):
with open("video.mp4", "wb") as f:
f.write(video)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video.mp4" type="video/mp4">
</video>
</figure>
### Queueing
One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
<details><summary>Code</summary>
```python
import queue
import threading
from IPython.display import display
from diffusers import StableDiffusionPipeline
def decode_worker(q: queue.Queue):
while True:
item = q.get()
if item is None:
break
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=item,
scaling_factor=0.18215,
)
display(image)
q.task_done()
q = queue.Queue()
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
thread.start()
def decode(latent: torch.Tensor):
q.put(latent)
prompts = [
"Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
"Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
"Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
"Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
"A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
"Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
]
pipe = StableDiffusionPipeline.from_pretrained(
"Lykon/dreamshaper-8",
torch_dtype=torch.float16,
vae=None,
).to("cuda")
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
_ = pipe(
prompt=prompts[0],
output_type="latent",
)
for prompt in prompts:
latent = pipe(
prompt=prompt,
output_type="latent",
).images
decode(latent)
q.put(None)
thread.join()
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/queue.mp4" type="video/mp4">
</video>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

View File

@@ -0,0 +1,183 @@
# Getting Started: VAE Encode with Hybrid Inference
VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
## Memory
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
</details>
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_encode
```
### Basic example
Let's encode an image, then decode it to demonstrate.
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>
<details><summary>Code</summary>
```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
latent = remote_encode(
endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
scaling_factor=0.3611,
shift_factor=0.1159,
)
decoded = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.3611,
shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>
### Generation
Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
<details><summary>Code</summary>
```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
vae=None,
).to("cuda")
init_image = load_image(
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
init_latent = remote_encode(
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
image=init_image,
scaling_factor=0.18215,
)
prompt = "A fantasy landscape, trending on artstation"
latent = pipe(
prompt=prompt,
image=init_latent,
strength=0.75,
output_type="latent",
).images
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

View File

@@ -140,7 +140,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
@@ -256,7 +256,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:

View File

@@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to
A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.
- It receives the iteration variable from the loop wrapper.
- It recieves the iteration variable from the loop wrapper.
- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].

View File

@@ -68,20 +68,6 @@ config = FasterCacheConfig(
pipeline.transformer.enable_cache(config)
```
## FirstBlockCache
[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser changes from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.
```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
pipeline = DiffusionPipeline.from_pretrained(
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16
)
apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
```
## TaylorSeer Cache
[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
@@ -101,7 +87,8 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
).to("cuda")
)
pipe.to("cuda")
config = TaylorSeerCacheConfig(
cache_interval=5,
@@ -110,4 +97,4 @@ config = TaylorSeerCacheConfig(
taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
```
```

View File

@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
quantzation_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
quantzation_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
quantzation_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)

View File

@@ -263,8 +263,8 @@ def main():
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)
@@ -314,35 +314,6 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```
### Unified Attention
[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
This hybrid approach leverages the strengths of both methods:
- **Ulysses Attention** efficiently parallelizes across attention heads
- **Ring Attention** handles very long sequences with minimal memory overhead
- Together, they enable 2D parallelization across both heads and sequence dimensions
[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`].
```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
```
> [!TIP]
> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).
We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|--------------------|------------------|-------------|------------------|
| ulysses | 6670.789 | 7.50 | 33.85 |
| ring | 13076.492 | 3.82 | 56.02 |
| unified_balanced | 11068.705 | 4.52 | 33.85 |
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.

View File

@@ -1929,8 +1929,6 @@ def main(args):
if args.cache_latents:
latents_cache = []
# Store vae config before potential deletion
vae_scaling_factor = vae.config.scaling_factor
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
@@ -1942,8 +1940,6 @@ def main(args):
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
vae_scaling_factor = vae.config.scaling_factor
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
@@ -2113,13 +2109,13 @@ def main(args):
model_input = vae.encode(pixel_values).latent_dist.sample()
if latents_mean is None and latents_std is None:
model_input = model_input * vae_scaling_factor
model_input = model_input * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
else:
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
model_input = model_input.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents

View File

@@ -149,13 +149,13 @@ def get_args():
"--validation_prompt",
type=str,
default=None,
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
)
parser.add_argument(
"--validation_images",
type=str,
default=None,
help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
)
parser.add_argument(
"--validation_prompt_separator",

View File

@@ -140,7 +140,7 @@ def get_args():
"--validation_prompt",
type=str,
default=None,
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
)
parser.add_argument(
"--validation_prompt_separator",

View File

@@ -21,8 +21,8 @@ from transformers import (
BertModel,
BertTokenizer,
CLIPImageProcessor,
MT5Tokenizer,
T5EncoderModel,
T5Tokenizer,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2=T5EncoderModel,
tokenizer_2=T5Tokenizer,
tokenizer_2=MT5Tokenizer,
):
super().__init__()

View File

@@ -1,844 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import ZImageTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from pipeline_z_image_differential_img2img import ZImageDifferentialImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> pipe = ZImageDifferentialImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> init_image = load_image(
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true",
>>> )
>>> mask = load_image(
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true",
>>> )
>>> prompt = "painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art"
>>> image = pipe(
... prompt,
... image=init_image,
... mask_image=mask,
... strength=0.75,
... num_inference_steps=9,
... guidance_scale=0.0,
... generator=torch.Generator("cuda").manual_seed(41),
... ).images[0]
>>> image.save("image.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ZImageDifferentialImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
r"""
The ZImage pipeline for image-to-image generation.
Args:
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`PreTrainedModel`]):
A text encoder model to encode text prompts.
tokenizer ([`AutoTokenizer`]):
A tokenizer to tokenize text prompts.
transformer ([`ZImageTransformer2DModel`]):
A ZImage transformer model to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: PreTrainedModel,
tokenizer: AutoTokenizer,
transformer: ZImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=False,
do_convert_grayscale=True,
)
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
if negative_prompt is None:
negative_prompt = ["" for _ in prompt]
else:
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
assert len(prompt) == len(negative_prompt)
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
)
else:
negative_prompt_embeds = []
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
def _encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
device = device or self._execution_device
if prompt_embeds is not None:
return prompt_embeds
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
for i in range(len(prompt_embeds)):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype)
# Encode the input image
image = image.to(device=device, dtype=dtype)
if image.shape[1] != num_channels_latents:
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
else:
image_latents = image
# Handle batch size expansion
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
# Add noise using flow matching scale_noise
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
return latents, noise, image_latents, latent_image_ids
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
num_images_per_prompt,
height,
width,
dtype,
device,
generator,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height, width))
mask = mask.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
masked_image = masked_image.to(device=device, dtype=dtype)
if masked_image.shape[1] == 16:
masked_image_latents = masked_image
else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
strength: float = 0.6,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
cfg_normalization: bool = False,
cfg_truncation: float = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for image-to-image generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
instead.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to mask `image`. Black pixels in the mask
are repainted while white pixels are preserved. If `mask_image` is a PIL image, it is converted to a
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
1)`, or `(H, W)`.
strength (`float`, *optional*, defaults to 0.6):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
height (`int`, *optional*, defaults to 1024):
The height in pixels of the generated image. If not provided, uses the input image height.
width (`int`, *optional*, defaults to 1024):
The width in pixels of the generated image. If not provided, uses the input image width.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
cfg_normalization (`bool`, *optional*, defaults to False):
Whether to apply configuration normalization.
cfg_truncation (`float`, *optional*, defaults to 1.0):
The truncation value for configuration.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# 1. Check inputs and validate strength
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
# 2. Preprocess image
init_image = self.image_processor.preprocess(image)
init_image = init_image.to(dtype=torch.float32)
# Get dimensions from the preprocessed image if not specified
if height is None:
height = init_image.shape[-2]
if width is None:
width = init_image.shape[-1]
vae_scale = self.vae_scale_factor * 2
if height % vae_scale != 0:
raise ValueError(
f"Height must be divisible by {vae_scale} (got {height}). "
f"Please adjust the height to a multiple of {vae_scale}."
)
if width % vae_scale != 0:
raise ValueError(
f"Width must be divisible by {vae_scale} (got {width}). "
f"Please adjust the width to a multiple of {vae_scale}."
)
device = self._execution_device
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
self._cfg_normalization = cfg_normalization
self._cfg_truncation = cfg_truncation
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = len(prompt_embeds)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided without `prompt`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
else:
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels
# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
if self.do_classifier_free_guidance and negative_prompt_embeds:
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt
# Calculate latent dimensions for image_seq_len
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
# 5. Prepare timesteps
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
self.scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
# 6. Adjust timesteps based on strength
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(actual_batch_size)
# 7. Prepare latents from image
latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
actual_batch_size,
num_channels_latents,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
latents,
)
resize_mode = "default"
crops_coords = None
# start diff diff preparation
original_mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
masked_image = init_image * original_mask
original_mask, _ = self.prepare_mask_latents(
original_mask,
masked_image,
batch_size,
num_images_per_prompt,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
)
mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)
masks = original_mask > mask_thresholds
# end diff diff preparation
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
timestep = (1000 - timestep) / 1000
# Normalized time for time-aware config (0 at start, 1 at end)
t_norm = timestep[0].item()
# Handle cfg truncation
current_guidance_scale = self.guidance_scale
if (
self.do_classifier_free_guidance
and self._cfg_truncation is not None
and float(self._cfg_truncation) <= 1
):
if t_norm > self._cfg_truncation:
current_guidance_scale = 0.0
# Run CFG only if configured AND scale is non-zero
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
if apply_cfg:
latents_typed = latents.to(self.transformer.dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
model_out_list = self.transformer(
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
)[0]
if apply_cfg:
# Perform CFG
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]
noise_pred = []
for j in range(actual_batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()
pred = pos + current_guidance_scale * (pos - neg)
# Renormalization
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(pos)
new_pos_norm = torch.linalg.vector_norm(pred)
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
if new_pos_norm > max_new_norm:
pred = pred * (max_new_norm / new_pos_norm)
noise_pred.append(pred)
noise_pred = torch.stack(noise_pred, dim=0)
else:
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred = noise_pred.squeeze(2)
noise_pred = -noise_pred
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
assert latents.dtype == torch.float32
# start diff diff
image_latent = original_image_latents
latents_dtype = latents.dtype
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.scale_noise(
original_image_latents, torch.tensor([noise_timestep]), noise
)
mask = masks[i].to(latents_dtype)
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ZImagePipelineOutput(images=image)

View File

@@ -1,22 +1,14 @@
# DreamBooth training example for FLUX.2 [dev] and FLUX 2 [klein]
# DreamBooth training example for FLUX.2 [dev]
[DreamBooth](https://huggingface.co/papers/2208.12242) is a method to personalize image generation models given just a few (3~5) images of a subject/concept.
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
The `train_dreambooth_lora_flux2.py`, `train_dreambooth_lora_flux2_klein.py` scripts shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://huggingface.co/black-forest-labs/FLUX.2-dev) and [FLUX 2 [klein]](https://huggingface.co/black-forest-labs/FLUX.2-klein).
> [!NOTE]
> **Model Variants**
>
> We support two FLUX model families:
> - **FLUX.2 [dev]**: The full-size model using Mistral Small 3.1 as the text encoder. Very capable but memory intensive.
> - **FLUX 2 [klein]**: Available in 4B and 9B parameter variants, using Qwen VL as the text encoder. Much more memory efficient and suitable for consumer hardware.
The `train_dreambooth_lora_flux2.py` script shows how to implement the training procedure for [LoRAs](https://huggingface.co/blog/lora) and adapt it for [FLUX.2 [dev]](https://github.com/black-forest-labs/flux2).
> [!NOTE]
> **Memory consumption**
>
> FLUX.2 [dev] can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. FLUX 2 [klein] models (4B and 9B) are significantly more memory efficient alternatives. Below we provide some tips and tricks to reduce memory consumption during training.
> Flux can be quite expensive to run on consumer hardware devices and as a result finetuning it comes with high memory requirements -
> a LoRA with a rank of 16 can exceed XXGB of VRAM for training. below we provide some tips and tricks to reduce memory consumption during training.
> For more tips & guidance on training on a resource-constrained device and general good practices please check out these great guides and trainers for FLUX:
> 1) [`@bghira`'s guide](https://github.com/bghira/SimpleTuner/blob/main/documentation/quickstart/FLUX2.md)
@@ -25,7 +17,7 @@ The `train_dreambooth_lora_flux2.py`, `train_dreambooth_lora_flux2_klein.py` scr
> [!NOTE]
> **Gated model**
>
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you've accepted the gate. Use the command below to log in:
> As the model is gated, before using it with diffusers you first need to go to the [FLUX.2 [dev] Hugging Face page](https://huggingface.co/black-forest-labs/FLUX.2-dev), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows youve accepted the gate. Use the command below to log in:
```bash
hf auth login
@@ -96,32 +88,20 @@ snapshot_download(
This will also allow us to push the trained LoRA parameters to the Hugging Face Hub platform.
As mentioned, Flux2 LoRA training is *very* memory intensive (especially for FLUX.2 [dev]). Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
As mentioned, Flux2 LoRA training is *very* memory intensive. Here are memory optimizations we can use (some still experimental) for a more memory efficient training:
## Memory Optimizations
> [!NOTE] many of these techniques complement each other and can be used together to further reduce memory consumption.
> However some techniques may be mutually exclusive so be sure to check before launching a training run.
### Remote Text Encoder
FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--remote_text_encoder` flag to enable remote computation of the prompt embeddings using the HuggingFace Inference API.
This way, the text encoder model is not loaded into memory during training.
> [!IMPORTANT]
> **Remote text encoder is only supported for FLUX.2 [dev]**. FLUX 2 [klein] models use the Qwen VL text encoder and do not support remote text encoding.
> [!NOTE]
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
### FSDP Text Encoder
FLUX.2 [dev] uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
This way, it distributes the memory cost across multiple nodes.
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
### Latent Caching
Pre-encode the training images with the vae, and then delete it to free up some memory. To enable `latent_caching` simply pass `--cache_latents`.
### QLoRA: Low Precision Training with Quantization
Perform low precision training using 8-bit or 4-bit quantization to reduce memory usage. You can use the following flags:
- **FP8 training** with `torchao`:
@@ -131,29 +111,22 @@ enable FP8 training by passing `--do_fp8_training`.
- **NF4 training** with `bitsandbytes`:
Alternatively, you can use 8-bit or 4-bit quantization with `bitsandbytes` by passing:
`--bnb_quantization_config_path` to enable 4-bit NF4 quantization.
### Gradient Checkpointing and Accumulation
* `--gradient accumulation` refers to the number of updates steps to accumulate before performing a backward/update pass.
by passing a value > 1 you can reduce the amount of backward/update passes and hence also memory reqs.
* with `--gradient checkpointing` we can save memory by not storing all intermediate activations during the forward pass.
Instead, only a subset of these activations (the checkpoints) are stored and the rest is recomputed as needed during the backward pass. Note that this comes at the expanse of a slower backward pass.
### 8-bit-Adam Optimizer
When training with `AdamW`(doesn't apply to `prodigy`) You can pass `--use_8bit_adam` to reduce the memory requirements of training.
Make sure to install `bitsandbytes` if you want to do so.
### Image Resolution
An easy way to mitigate some of the memory requirements is through `--resolution`. `--resolution` refers to the resolution for input images, all the images in the train/validation dataset are resized to this.
Note that by default, images are resized to resolution of 512, but it's good to keep in mind in case you're accustomed to training on higher resolutions.
### Precision of saved LoRA layers
By default, trained transformer layers are saved in the precision dtype in which training was performed. E.g. when training in mixed precision is enabled with `--mixed_precision="bf16"`, final finetuned layers will be saved in `torch.bfloat16` as well.
This reduces memory requirements significantly w/o a significant quality loss. Note that if you do wish to save the final layers in float32 at the expanse of more memory usage, you can do so by passing `--upcast_before_saving`.
## Training Examples
### FLUX.2 [dev] Training
To perform DreamBooth with LoRA on FLUX.2 [dev], run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
@@ -185,104 +158,19 @@ accelerate launch train_dreambooth_lora_flux2.py \
--push_to_hub
```
### FLUX 2 [klein] Training
FLUX 2 [klein] models are more memory efficient alternatives available in 4B and 9B parameter variants. They use the Qwen VL text encoder instead of Mistral Small 3.1.
> [!NOTE]
> The `--remote_text_encoder` flag is **not supported** for FLUX 2 [klein] models. The Qwen VL text encoder must be loaded locally, but offloading is still supported.
**FLUX 2 [klein] 4B:**
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-klein-4B"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2-klein-4b"
accelerate launch train_dreambooth_lora_flux2_klein.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
**FLUX 2 [klein] 9B:**
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-klein-9B"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux2-klein-9b"
accelerate launch train_dreambooth_lora_flux2_klein.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--do_fp8_training \
--gradient_checkpointing \
--cache_latents \
--instance_prompt="a photo of sks dog" \
--resolution=1024 \
--train_batch_size=1 \
--guidance_scale=1 \
--use_8bit_adam \
--gradient_accumulation_steps=4 \
--optimizer="adamW" \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=100 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25 \
--seed="0" \
--push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:
* `report_to="wandb` will ensure the training runs are tracked on [Weights and Biases](https://wandb.ai/site). To use it, be sure to install `wandb` with `pip install wandb`. Don't forget to call `wandb login <your_api_key>` before training if you haven't done it before.
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
> [!NOTE]
> If you want to train using long prompts, you can use `--max_sequence_length` to set the token limit. Note that this will use more resources and may slow down the training in some cases.
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
### FSDP on the transformer
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
## LoRA + DreamBooth
```shell
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFOMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
### Prodigy Optimizer
Prodigy is an adaptive optimizer that dynamically adjusts the learning rate learned parameters based on past gradients, allowing for more efficient convergence.
@@ -295,6 +183,8 @@ to use prodigy, first make sure to install the prodigyopt library: `pip install
> [!TIP]
> When using prodigy it's generally good practice to set- `--learning_rate=1.0`
To perform DreamBooth with LoRA, run:
```bash
export MODEL_NAME="black-forest-labs/FLUX.2-dev"
export INSTANCE_DIR="dog"
@@ -358,10 +248,13 @@ the exact modules for LoRA training. Here are some examples of target modules yo
> keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
## Training Image-to-Image
Flux.2 lets us perform image editing as well as image generation. We provide a simple script for image-to-image(I2I) LoRA fine-tuning in [train_dreambooth_lora_flux2_img2img.py](./train_dreambooth_lora_flux2_img2img.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
**important**
**Important**
To make sure you can successfully run the latest version of the image-to-image example script, we highly recommend installing from source, specifically from the commit mentioned below. To do this, execute the following steps in a new virtual environment:
@@ -418,6 +311,5 @@ we've added aspect ratio bucketing support which allows training on images with
To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a semicolon-separated list of height,width pairs, such as:
`--aspect_ratio_buckets="672,1568;688,1504;720,1456;752,1392;800,1328;832,1248;880,1184;944,1104;1024,1024;1104,944;1184,880;1248,832;1328,800;1392,752;1456,720;1504,688;1568,672"
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
`
Since Flux.2 finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗

View File

@@ -111,25 +111,6 @@ To better track our training experiments, we're using the following flags in the
## Notes
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- lora_alpha vs. rank:
This ratio dictates the LoRA's effective strength:
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
### Additional CLI arguments
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.

View File

@@ -1,262 +0,0 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
import sys
import tempfile
import safetensors
from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
class DreamBoothLoRAFlux2Klein(ExamplesTestsAccelerate):
instance_data_dir = "docs/source/en/imgs"
instance_prompt = "dog"
pretrained_model_name_or_path = "hf-internal-testing/tiny-flux2-klein"
script_path = "examples/dreambooth/train_dreambooth_lora_flux2_klein.py"
transformer_layer_type = "single_transformer_blocks.0.attn.to_qkv_mlp_proj"
def test_dreambooth_lora_flux2(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_latent_caching(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names.
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_layers(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--cache_latents
--learning_rate 5.0e-04
--scale_lr
--lora_layers {self.transformer_layer_type}
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when not training the text encoder, all the parameters in the state dict should start
# with `"transformer"` in their names. In this test, we only params of
# transformer.single_transformer_blocks.0.attn.to_k should be in the state dict
starts_with_transformer = all(
key.startswith(f"transformer.{self.transformer_layer_type}") for key in lora_state_dict.keys()
)
self.assertTrue(starts_with_transformer)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=6
--checkpoints_total_limit=2
--max_sequence_length 8
--checkpointing_steps=2
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
{"checkpoint-4", "checkpoint-6"},
)
def test_dreambooth_lora_flux2_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=4
--checkpointing_steps=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + test_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
resume_run_args = f"""
{self.script_path}
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
--instance_data_dir={self.instance_data_dir}
--output_dir={tmpdir}
--instance_prompt={self.instance_prompt}
--resolution=64
--train_batch_size=1
--gradient_accumulation_steps=1
--max_train_steps=8
--checkpointing_steps=2
--resume_from_checkpoint=checkpoint-4
--checkpoints_total_limit=2
--max_sequence_length 8
--text_encoder_out_layers 1
""".split()
run_command(self._launch_args + resume_run_args)
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
def test_dreambooth_lora_with_metadata(self):
# Use a `lora_alpha` that is different from `rank`.
lora_alpha = 8
rank = 4
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
{self.script_path}
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
--instance_data_dir {self.instance_data_dir}
--instance_prompt {self.instance_prompt}
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--lora_alpha={lora_alpha}
--rank={rank}
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--max_sequence_length 8
--text_encoder_out_layers 1
--output_dir {tmpdir}
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
self.assertTrue(os.path.isfile(state_dict_file))
# Check if the metadata was properly serialized.
with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
metadata = f.metadata() or {}
metadata.pop("format", None)
raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
if raw:
raw = json.loads(raw)
loaded_lora_alpha = raw["transformer.lora_alpha"]
self.assertTrue(loaded_lora_alpha == lora_alpha)
loaded_lora_rank = raw["transformer.r"]
self.assertTrue(loaded_lora_rank == rank)

View File

@@ -44,7 +44,6 @@ import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -76,16 +75,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -97,9 +93,6 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -729,7 +722,6 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1227,11 +1219,7 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1275,42 +1263,17 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
transformer_cls = type(unwrap_model(transformer))
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1322,20 +1285,13 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
if not is_fsdp:
while len(models) > 0:
model = models.pop()
while len(models) > 0:
model = models.pop()
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1551,21 +1507,6 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1595,8 +1536,6 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1838,7 +1777,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process or is_fsdp:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1897,41 +1836,15 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

View File

@@ -43,7 +43,6 @@ import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -75,16 +74,13 @@ from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -97,9 +93,6 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -127,7 +120,7 @@ def save_model_card(
)
model_description = f"""
# Flux.2 DreamBooth LoRA - {repo_id}
# Flux DreamBooth LoRA - {repo_id}
<Gallery />
@@ -346,7 +339,7 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
required=False,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -698,7 +691,6 @@ def parse_args(input_args=None):
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -835,28 +827,15 @@ class DreamBoothDataset(Dataset):
dest_image = self.cond_images[i]
image_width, image_height = dest_image.size
if image_width * image_height > 1024 * 1024:
dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
image_width, image_height = dest_image.size
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
image_processor = Flux2ImageProcessor()
dest_image = image_processor.preprocess(
dest_image = Flux2ImageProcessor.image_processor.preprocess(
dest_image, height=image_height, width=image_width, resize_mode="crop"
)
# Convert back to PIL
dest_image = dest_image.squeeze(0)
if dest_image.min() < 0:
dest_image = (dest_image + 1) / 2
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
if dest_image.shape[0] == 1:
# Gray scale image
dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
else:
# RGB scale image: (C, H, W) -> (H, W, C)
dest_image = TF.to_pil_image(dest_image)
dest_image = exif_transpose(dest_image)
if not dest_image.mode == "RGB":
@@ -1177,11 +1156,7 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1225,42 +1200,17 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
transformer_cls = type(unwrap_model(transformer))
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1272,20 +1222,13 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
if not is_fsdp:
while len(models) > 0:
model = models.pop()
while len(models) > 0:
model = models.pop()
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1476,9 +1419,9 @@ def main(args):
args.instance_prompt, text_encoding_pipeline
)
validation_image = load_image(args.validation_image_path).convert("RGB")
validation_kwargs = {"image": validation_image}
if args.validation_prompt is not None:
validation_image = load_image(args.validation_image_path).convert("RGB")
validation_kwargs = {"image": validation_image}
if args.remote_text_encoder:
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
else:
@@ -1487,21 +1430,6 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1533,8 +1461,6 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1695,13 +1621,9 @@ def main(args):
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
device=cond_model_input.device
)
cond_model_input_ids = cond_model_input_ids.view(
cond_model_input.shape[0], -1, model_input_ids.shape[-1]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1728,9 +1650,6 @@ def main(args):
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
orig_input_shape = packed_noisy_model_input.shape
orig_input_ids_shape = model_input_ids.shape
# concatenate the model inputs with the cond inputs
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
@@ -1749,8 +1668,7 @@ def main(args):
img_ids=model_input_ids, # B, image_seq_len, 4
return_dict=False,
)[0]
model_pred = model_pred[:, : orig_input_shape[1], :]
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
@@ -1782,7 +1700,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process or is_fsdp:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1841,41 +1759,15 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

File diff suppressed because it is too large Load Diff

View File

@@ -1513,12 +1513,14 @@ def main(args):
height=model_input.shape[3],
width=model_input.shape[4],
)
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
model_pred = transformer(
hidden_states=packed_noisy_model_input,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
timestep=timesteps / 1000,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)[0]
model_pred = QwenImagePipeline._unpack_latents(

View File

@@ -1,157 +0,0 @@
# Latent Perceptual Loss (LPL) for Stable Diffusion XL
This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada.
## Overview
LPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to:
- Loss of fine details in generated images
- Inconsistent image quality
- Structural artifacts
- Reduced sharpness and realism
LPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to:
- Improved image quality and consistency (6-20% FID improvement)
- Better preservation of fine details
- More stable training, especially at high noise levels
- Better handling of structural information
- Sharper and more realistic textures
## Implementation Details
The LPL implementation follows the paper's methodology and includes several key features:
1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:
- Middle block features
- Up block features (configurable number of blocks)
- Proper gradient checkpointing for memory efficiency
- Features are extracted only for timesteps below the threshold (high SNR)
2. **Feature Normalization**: Multiple normalization options as validated in the paper:
- `default`: Normalize each feature map independently
- `shared`: Cross-normalize features using target statistics (recommended)
- `batch`: Batch-wise normalization
3. **Outlier Handling**: Optional removal of outliers in feature maps using:
- Quantile-based filtering (2% quantiles)
- Morphological operations (opening/closing)
- Adaptive thresholding based on standard deviation
4. **Loss Types**:
- MSE loss (default)
- L1 loss
- Optional power law weighting (2^(-i) for layer i)
## Usage
To use LPL in your training, add the following arguments to your training command:
```bash
python examples/research_projects/lpl/train_sdxl_lpl.py \
--use_lpl \
--lpl_weight 1.0 \ # Weight for LPL loss (1.0-2.0 recommended)
--lpl_t_threshold 200 \ # Apply LPL only for timesteps < threshold (high SNR)
--lpl_loss_type mse \ # Loss type: "mse" or "l1"
--lpl_norm_type shared \ # Normalization type: "default", "shared" (recommended), or "batch"
--lpl_pow_law \ # Use power law weighting for layers
--lpl_num_blocks 4 \ # Number of up blocks to use (1-4)
--lpl_remove_outliers \ # Remove outliers in feature maps
--lpl_scale \ # Scale LPL loss by noise level weights
--lpl_start 0 \ # Step to start applying LPL
# ... other training arguments ...
```
### Key Parameters
- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.
- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) focus on more important timesteps.
- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases.
- `lpl_norm_type`: Feature normalization strategy. "shared" is recommended as it showed best results in the paper.
- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance.
- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.
- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.
- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps.
- `lpl_start`: Training step to start applying LPL. Can be used to warm up training.
## Recommendations
1. **Starting Point** (based on paper results):
```bash
--use_lpl \
--lpl_weight 1.0 \
--lpl_t_threshold 200 \
--lpl_loss_type mse \
--lpl_norm_type shared \
--lpl_pow_law \
--lpl_num_blocks 4 \
--lpl_remove_outliers \
--lpl_scale
```
2. **Memory Efficiency**:
- Use `--gradient_checkpointing` for memory efficiency (enabled by default)
- Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)
- Consider using `--lpl_scale` to focus on more important timesteps
- Features are extracted only for timesteps below threshold to save memory
3. **Quality vs Speed**:
- Higher `lpl_weight` (1.0-2.0) for better quality
- Lower `lpl_t_threshold` (100-200) for faster training
- Use `lpl_remove_outliers` for more stable training
- `lpl_norm_type shared` provides best quality/speed trade-off
## Technical Details
### Feature Extraction
The LPL implementation extracts features from the VAE decoder in the following order:
1. Middle block output
2. Up block outputs (configurable number of blocks)
Each feature map is processed with:
1. Optional outlier removal (2% quantiles, morphological operations)
2. Feature normalization (shared statistics recommended)
3. Loss calculation (MSE or L1)
4. Optional power law weighting (2^(-i) for layer i)
### Loss Calculation
For each feature map:
1. Features are normalized according to the chosen strategy
2. Loss is calculated between normalized features
3. Outliers are masked out (if enabled)
4. Loss is weighted by layer depth (if power law enabled)
5. Final loss is averaged across all layers
### Memory Considerations
- Gradient checkpointing is used by default
- Features are extracted only for timesteps below the threshold
- Outlier removal is done in-place to save memory
- Feature normalization is done efficiently using vectorized operations
- Memory usage scales linearly with number of blocks used
## Results
Based on the paper's findings, LPL provides:
- 6-20% improvement in FID scores
- Better preservation of fine details
- More realistic textures and structures
- Improved consistency across different resolutions
- Better performance on both small and large datasets
## Citation
If you use this implementation in your research, please cite:
```bibtex
@inproceedings{berrada2025boosting,
title={Boosting Latent Diffusion with Perceptual Objectives},
author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=y4DtzADzd1}
}
```

View File

@@ -1,215 +0,0 @@
# Copyright 2025 Berrada et al.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / (norm_factor + eps)
def cross_normalize(input, target, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
return input / (norm_factor + eps), target / (norm_factor + eps)
def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
opening = int(np.ceil(opening / down_f))
closing = int(np.ceil(closing / down_f))
if opening == 2:
opening = 3
if closing == 2:
closing = 1
# replace quantile with kth value here.
feat_flat = feat.flatten(-2, -1)
k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))
q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]
m = 2 * feat_flat.std(-1)[..., None, None].detach()
mask = (q1 - m < feat) * (feat < q2 + m)
# dilate the mask.
mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float()) # closing
mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool() # opening
feat = feat * mask
return mask, feat
class LatentPerceptualLoss(nn.Module):
def __init__(
self,
vae,
loss_type="mse",
grad_ckpt=True,
pow_law=False,
norm_type="default",
num_mid_blocks=4,
feature_type="feature",
remove_outliers=True,
):
super().__init__()
self.vae = vae
self.decoder = self.vae.decoder
# Store scaling factors as tensors on the correct device
device = next(self.vae.parameters()).device
# Get scaling factors with proper defaults and handle None values
scale_factor = getattr(self.vae.config, "scaling_factor", None)
shift_factor = getattr(self.vae.config, "shift_factor", None)
# Convert to tensors with proper defaults
self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)
self.gradient_checkpointing = grad_ckpt
self.pow_law = pow_law
self.norm_type = norm_type.lower()
self.outlier_mask = remove_outliers
self.last_feature_stats = [] # Store feature statistics for logging
assert feature_type in ["feature", "image"]
self.feature_type = feature_type
assert self.norm_type in ["default", "shared", "batch"]
assert num_mid_blocks >= 0 and num_mid_blocks <= 4
self.n_blocks = num_mid_blocks
assert loss_type in ["mse", "l1"]
if loss_type == "mse":
self.loss_fn = nn.MSELoss(reduction="none")
elif loss_type == "l1":
self.loss_fn = nn.L1Loss(reduction="none")
def get_features(self, z, latent_embeds=None, disable_grads=False):
with torch.set_grad_enabled(not disable_grads):
if self.gradient_checkpointing and not disable_grads:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
features = []
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
sample = z
sample = self.decoder.conv_in(sample)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.decoder.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
features.append(sample)
# up
for up_block in self.decoder.up_blocks[: self.n_blocks]:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
features.append(sample)
return features
else:
features = []
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
sample = z
sample = self.decoder.conv_in(sample)
# middle
sample = self.decoder.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
features.append(sample)
# up
for up_block in self.decoder.up_blocks[: self.n_blocks]:
sample = up_block(sample, latent_embeds)
features.append(sample)
return features
def get_loss(self, input, target, get_hist=False):
if self.feature_type == "feature":
inp_f = self.get_features(self.shift + input / self.scale)
tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
losses = []
self.last_feature_stats = [] # Reset feature stats
for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
my = torch.ones_like(y).bool()
outlier_ratio = 0.0
if self.outlier_mask:
with torch.no_grad():
if i == 2:
my, y = remove_outliers(y, down_f=2)
outlier_ratio = 1.0 - my.float().mean().item()
elif i in [3, 4, 5]:
my, y = remove_outliers(y, down_f=1)
outlier_ratio = 1.0 - my.float().mean().item()
# Store feature statistics before normalization
with torch.no_grad():
stats = {
"mean": y.mean().item(),
"std": y.std().item(),
"outlier_ratio": outlier_ratio,
}
self.last_feature_stats.append(stats)
# normalize feature tensors
if self.norm_type == "default":
x = normalize_tensor(x)
y = normalize_tensor(y)
elif self.norm_type == "shared":
x, y = cross_normalize(x, y, eps=1e-6)
term_loss = self.loss_fn(x, y) * my
# reduce loss term
loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
losses.append(term_loss.mean((1,)))
if get_hist:
return losses
else:
loss = sum(losses)
return loss / len(inp_f)
elif self.feature_type == "image":
inp_f = self.vae.decode(input / self.scale).sample
tar_f = self.vae.decode(target / self.scale).sample
return F.mse_loss(inp_f, tar_f)
def get_first_conv(self, z):
sample = self.decoder.conv_in(z)
return sample
def get_first_block(self, z):
sample = self.decoder.conv_in(z)
sample = self.decoder.mid_block(sample)
for resnet in self.decoder.up_blocks[0].resnets:
sample = resnet(sample, None)
return sample
def get_first_layer(self, input, target, target_layer="conv"):
if target_layer == "conv":
feat_in = self.get_first_conv(input)
with torch.no_grad():
feat_tar = self.get_first_conv(target)
else:
feat_in = self.get_first_block(input)
with torch.no_grad():
feat_tar = self.get_first_block(target)
feat_in, feat_tar = cross_normalize(feat_in, feat_tar)
return F.mse_loss(feat_in, feat_tar, reduction="mean")

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode
___Note___:
___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___
## Running locally with PyTorch

View File

@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
NUM_DEVICES = jax.device_count()
# 1. Let's start by downloading the model and loading it into our pipeline class
# Adhering to JAX's functional approach, the model's parameters are returned separately and
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
# will have to be passed to the pipeline during inference
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True

View File

@@ -7,12 +7,16 @@ import torch
from diffusers.utils import logging
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
logger = logging.get_logger(__name__)
def safe_tokenize(tokenizer, *args, lock, **kwargs):
with lock:
return tokenizer(*args, **kwargs)
class RequestScopedPipeline:
DEFAULT_MUTABLE_ATTRS = [
"_all_hooks",
@@ -34,40 +38,23 @@ class RequestScopedPipeline:
wrap_scheduler: bool = True,
):
self._base = pipeline
self.unet = getattr(pipeline, "unet", None)
self.vae = getattr(pipeline, "vae", None)
self.text_encoder = getattr(pipeline, "text_encoder", None)
self.components = getattr(pipeline, "components", None)
self.transformer = getattr(pipeline, "transformer", None)
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
self._vae_lock = threading.Lock()
self._image_lock = threading.Lock()
self._auto_detect_mutables = bool(auto_detect_mutables)
self._tensor_numel_threshold = int(tensor_numel_threshold)
self._auto_detected_attrs: List[str] = []
def _detect_kernel_pipeline(self, pipeline) -> bool:
kernel_indicators = [
"text_encoding_cache",
"memory_manager",
"enable_optimizations",
"_create_request_context",
"get_optimization_stats",
]
return any(hasattr(pipeline, attr) for attr in kernel_indicators)
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
base_sched = getattr(self._base, "scheduler", None)
if base_sched is None:
@@ -83,21 +70,11 @@ class RequestScopedPipeline:
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
)
except Exception as e:
logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
try:
if hasattr(wrapped_scheduler, "scheduler"):
try:
copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
return BaseAsyncScheduler(copied_scheduler)
except Exception:
return wrapped_scheduler
else:
copied_scheduler = copy.copy(wrapped_scheduler)
return BaseAsyncScheduler(copied_scheduler)
except Exception as e2:
logger.warning(
f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
)
return copy.deepcopy(wrapped_scheduler)
except Exception as e:
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
return wrapped_scheduler
def _autodetect_mutables(self, max_attrs: int = 40):
@@ -109,7 +86,6 @@ class RequestScopedPipeline:
candidates: List[str] = []
seen = set()
for name in dir(self._base):
if name.startswith("__"):
continue
@@ -117,7 +93,6 @@ class RequestScopedPipeline:
continue
if name in ("to", "save_pretrained", "from_pretrained"):
continue
try:
val = getattr(self._base, name)
except Exception:
@@ -125,9 +100,11 @@ class RequestScopedPipeline:
import types
# skip callables and modules
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
continue
# containers -> candidate
if isinstance(val, (dict, list, set, tuple, bytearray)):
candidates.append(name)
seen.add(name)
@@ -228,9 +205,6 @@ class RequestScopedPipeline:
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
def _should_wrap_tokenizers(self) -> bool:
return True
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
@@ -240,25 +214,6 @@ class RequestScopedPipeline:
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
local_pipe = copy.deepcopy(self._base)
try:
if (
hasattr(local_pipe, "vae")
and local_pipe.vae is not None
and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
):
local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
if (
hasattr(local_pipe, "image_processor")
and local_pipe.image_processor is not None
and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
):
local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
local_pipe.image_processor, self._image_lock
)
except Exception as e:
logger.debug(f"Could not wrap vae/image_processor: {e}")
if local_scheduler is not None:
try:
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
@@ -276,42 +231,47 @@ class RequestScopedPipeline:
self._clone_mutable_attrs(self._base, local_pipe)
original_tokenizers = {}
# 4) wrap tokenizers on the local pipe with the lock wrapper
tokenizer_wrappers = {} # name -> original_tokenizer
try:
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
for name in dir(local_pipe):
if "tokenizer" in name and not name.startswith("_"):
tok = getattr(local_pipe, name, None)
if tok is not None and self._is_tokenizer_component(tok):
tokenizer_wrappers[name] = tok
setattr(
local_pipe,
name,
lambda *args, tok=tok, **kwargs: safe_tokenize(
tok, *args, lock=self._tokenizer_lock, **kwargs
),
)
if self._should_wrap_tokenizers():
try:
for name in dir(local_pipe):
if "tokenizer" in name and not name.startswith("_"):
tok = getattr(local_pipe, name, None)
if tok is not None and self._is_tokenizer_component(tok):
if not isinstance(tok, ThreadSafeTokenizerWrapper):
original_tokenizers[name] = tok
wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
setattr(local_pipe, name, wrapped_tokenizer)
# b) wrap tokenizers in components dict
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
for key, val in local_pipe.components.items():
if val is None:
continue
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
for key, val in local_pipe.components.items():
if val is None:
continue
if self._is_tokenizer_component(val):
tokenizer_wrappers[f"components[{key}]"] = val
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
)
if self._is_tokenizer_component(val):
if not isinstance(val, ThreadSafeTokenizerWrapper):
original_tokenizers[f"components[{key}]"] = val
wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
local_pipe.components[key] = wrapped_tokenizer
except Exception as e:
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
except Exception as e:
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
result = None
cm = getattr(local_pipe, "model_cpu_offload_context", None)
try:
if callable(cm):
try:
with cm():
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
except TypeError:
# cm might be a context manager instance rather than callable
try:
with cm:
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
@@ -319,18 +279,18 @@ class RequestScopedPipeline:
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
else:
# no offload context available — call directly
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
return result
finally:
try:
for name, tok in original_tokenizers.items():
for name, tok in tokenizer_wrappers.items():
if name.startswith("components["):
key = name[len("components[") : -1]
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
local_pipe.components[key] = tok
local_pipe.components[key] = tok
else:
setattr(local_pipe, name, tok)
except Exception as e:
logger.debug(f"Error restoring original tokenizers: {e}")
logger.debug(f"Error restoring wrapped tokenizers: {e}")

View File

@@ -1,86 +0,0 @@
class ThreadSafeTokenizerWrapper:
def __init__(self, tokenizer, lock):
self._tokenizer = tokenizer
self._lock = lock
self._thread_safe_methods = {
"__call__",
"encode",
"decode",
"tokenize",
"encode_plus",
"batch_encode_plus",
"batch_decode",
}
def __getattr__(self, name):
attr = getattr(self._tokenizer, name)
if name in self._thread_safe_methods and callable(attr):
def wrapped_method(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)
return wrapped_method
return attr
def __call__(self, *args, **kwargs):
with self._lock:
return self._tokenizer(*args, **kwargs)
def __setattr__(self, name, value):
if name.startswith("_"):
super().__setattr__(name, value)
else:
setattr(self._tokenizer, name, value)
def __dir__(self):
return dir(self._tokenizer)
class ThreadSafeVAEWrapper:
def __init__(self, vae, lock):
self._vae = vae
self._lock = lock
def __getattr__(self, name):
attr = getattr(self._vae, name)
if name in {"decode", "encode", "forward"} and callable(attr):
def wrapped(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)
return wrapped
return attr
def __setattr__(self, name, value):
if name.startswith("_"):
super().__setattr__(name, value)
else:
setattr(self._vae, name, value)
class ThreadSafeImageProcessorWrapper:
def __init__(self, proc, lock):
self._proc = proc
self._lock = lock
def __getattr__(self, name):
attr = getattr(self._proc, name)
if name in {"postprocess", "preprocess"} and callable(attr):
def wrapped(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)
return wrapped
return attr
def __setattr__(self, name, value):
if name.startswith("_"):
super().__setattr__(name, value)
else:
setattr(self._proc, name, value)

View File

@@ -1,94 +1,11 @@
"""
# Cosmos 2 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2-2B-Text2Image
```
convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_ckpt_path $transformer_ckpt_path \
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
--text_encoder_path google-t5/t5-11b \
--tokenizer_path google-t5/t5-11b \
--vae_type wan2.1 \
--output_path converted/cosmos-p2-t2i-2b \
--save_pipeline
```
# Cosmos 2.5 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2.5-2B
```
Convert checkpoint
```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
--save_pipeline
# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
--save_pipeline
```
## 14B
```bash
hf download nvidia/Cosmos-Predict2.5-14B
```
```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-14B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
--save_pipeline
# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-14B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
--save_pipeline
```
"""
import argparse
import pathlib
import sys
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import (
AutoencoderKLCosmos,
@@ -100,9 +17,7 @@ from diffusers import (
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -318,44 +233,6 @@ TRANSFORMER_CONFIGS = {
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.5-Predict-Base-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
"Cosmos-2.5-Predict-Base-14B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}
VAE_KEYS_RENAME_DICT = {
@@ -457,9 +334,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
elif "Cosmos-2.5" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False
@@ -473,7 +347,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
@@ -482,21 +355,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
continue
handler_fn_inplace(key, original_state_dict)
expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
for k in missing_keys:
print(k)
sys.exit(1)
if unexpected_keys:
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
for k in unexpected_keys:
print(k)
sys.exit(2)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
@@ -586,34 +444,6 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos2_5(args, transformer, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
text_encoder_path, torch_dtype="auto", device_map="cpu"
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
scheduler = UniPCMultistepScheduler(
use_karras_sigmas=True,
use_flow_sigmas=True,
prediction_type="flow_prediction",
sigma_max=200.0,
sigma_min=0.01,
)
pipe = Cosmos2_5_PredictBasePipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
@@ -621,10 +451,10 @@ def get_args():
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument(
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--text_encoder_path", type=str, default=None)
parser.add_argument("--tokenizer_path", type=str, default=None)
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -647,6 +477,8 @@ if __name__ == "__main__":
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
@@ -658,26 +490,17 @@ if __name__ == "__main__":
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
else:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
else:
raise AssertionError(f"{args.transformer_type} not supported")
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
else:
raise AssertionError(f"{args.transformer_type} not supported")
assert False

View File

@@ -44,7 +44,7 @@ CTX = init_empty_weights if is_accelerate_available() else nullcontext
parser = argparse.ArgumentParser()
parser.add_argument("--original_state_dict_repo_id", default=None, type=str)
parser.add_argument("--vae_filename", default="flux2-vae.sft", type=str)
parser.add_argument("--dit_filename", default="flux2-dev.safetensors", type=str)
parser.add_argument("--dit_filename", default="flux-dev-dummy.sft", type=str)
parser.add_argument("--vae", action="store_true")
parser.add_argument("--dit", action="store_true")
parser.add_argument("--vae_dtype", type=str, default="fp32")
@@ -385,9 +385,9 @@ def update_state_dict(state_dict: Dict[str, Any], old_key: str, new_key: str) ->
def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
if model_type == "flux2-dev":
if model_type == "test" or model_type == "dummy-flux2":
config = {
"model_id": "black-forest-labs/FLUX.2-dev",
"model_id": "diffusers-internal-dev/dummy-flux2",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
@@ -405,53 +405,6 @@ def get_flux2_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "klein-4b":
config = {
"model_id": "diffusers-internal-dev/dummy0115",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
"num_layers": 5,
"num_single_layers": 20,
"attention_head_dim": 128,
"num_attention_heads": 24,
"joint_attention_dim": 7680,
"timestep_guidance_channels": 256,
"mlp_ratio": 3.0,
"axes_dims_rope": (32, 32, 32, 32),
"rope_theta": 2000,
"eps": 1e-6,
"guidance_embeds": False,
},
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
elif model_type == "klein-9b":
config = {
"model_id": "diffusers-internal-dev/dummy0115",
"diffusers_config": {
"patch_size": 1,
"in_channels": 128,
"num_layers": 8,
"num_single_layers": 24,
"attention_head_dim": 128,
"num_attention_heads": 32,
"joint_attention_dim": 12288,
"timestep_guidance_channels": 256,
"mlp_ratio": 3.0,
"axes_dims_rope": (32, 32, 32, 32),
"rope_theta": 2000,
"eps": 1e-6,
"guidance_embeds": False,
},
}
rename_dict = FLUX2_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = TRANSFORMER_SPECIAL_KEYS_REMAP
else:
raise ValueError(f"Unknown model_type: {model_type}. Choose from: flux2-dev, klein-4b, klein-9b")
return config, rename_dict, special_keys_remap
@@ -494,14 +447,7 @@ def main(args):
if args.dit:
original_dit_ckpt = load_original_checkpoint(args, filename=args.dit_filename)
if "klein-4b" in args.dit_filename:
model_type = "klein-4b"
elif "klein-9b" in args.dit_filename:
model_type = "klein-9b"
else:
model_type = "flux2-dev"
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, model_type)
transformer = convert_flux2_transformer_to_diffusers(original_dit_ckpt, "test")
if not args.full_pipe:
dit_dtype = torch.bfloat16 if args.dit_dtype == "bf16" else torch.float32
transformer.to(dit_dtype).save_pretrained(f"{args.output_path}/transformer")
@@ -519,15 +465,8 @@ def main(args):
"black-forest-labs/FLUX.1-dev", subfolder="scheduler"
)
if_distilled = "base" not in args.dit_filename
pipe = Flux2Pipeline(
vae=vae,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
if_distilled=if_distilled,
vae=vae, transformer=transformer, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler
)
pipe.save_pretrained(args.output_path)

View File

@@ -1,886 +0,0 @@
import argparse
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulatin Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight, bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
if key.startswith("per_channel_statistics"):
new_key = ".".join(["decoder", key])
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 4,
"audio_out_channels": 4,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 2,
"audio_attention_head_dim": 4,
"audio_cross_attention_dim": 8,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 2,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 16,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"caption_channels": 16,
"text_proj_in_factor": 3,
"video_connector_num_attention_heads": 4,
"video_connector_attention_head_dim": 8,
"video_connector_num_layers": 1,
"video_connector_num_learnable_registers": None,
"audio_connector_num_attention_heads": 4,
"audio_connector_attention_head_dim": 8,
"audio_connector_num_layers": 1,
"audio_connector_num_learnable_registers": None,
"connector_rope_base_seq_len": 32,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_temporal_positioning": False,
},
}
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
},
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def get_ltx2_spatial_latent_upsampler_config(version: str):
if version == "2.0":
config = {
"in_channels": 128,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
def convert_ltx2_spatial_latent_upsampler(
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
with init_empty_weights():
latent_upsampler = LTX2LatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
if repo_id is None and filename is None:
raise ValueError("Please supply at least one of `repo_id` or `filename`")
if repo_id is not None:
if filename is None:
raise ValueError("If repo_id is specified, filename must also be specified.")
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
else:
ckpt_path = filename
_, ext = os.path.splitext(ckpt_path)
if ext in [".safetensors", ".sft"]:
state_dict = safetensors.torch.load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
return state_dict
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
# Ensure that the key prefix ends with a dot (.)
if not prefix.endswith("."):
prefix = prefix + "."
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--original_state_dict_repo_id",
default="Lightricks/LTX-2",
type=str,
help="HF Hub repo id with LTX 2.0 checkpoint",
)
parser.add_argument(
"--checkpoint_path",
default=None,
type=str,
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
)
parser.add_argument(
"--version",
type=str,
default="2.0",
choices=["test", "2.0"],
help="Version of the LTX 2.0 model",
)
parser.add_argument(
"--combined_filename",
default="ltx-2-19b-dev.safetensors",
type=str,
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
)
parser.add_argument("--vae_prefix", default="vae.", type=str)
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
parser.add_argument(
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
)
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
parser.add_argument(
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
)
parser.add_argument(
"--text_encoder_model_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 base text encoder model",
)
parser.add_argument(
"--tokenizer_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename",
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
type=str,
help="Latent upsampler filename",
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
parser.add_argument(
"--full_pipeline",
action="store_true",
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
)
parser.add_argument(
"--upsample_pipeline",
action="store_true",
help="Whether to save a latent upsampling pipeline",
)
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
def main(args):
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
combined_ckpt = None
load_combined_models = any(
[
args.vae,
args.audio_vae,
args.dit,
args.vocoder,
args.text_encoder,
args.full_pipeline,
args.upsample_pipeline,
]
)
if args.combined_filename is not None and load_combined_models:
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
if args.vae or args.full_pipeline or args.upsample_pipeline:
if args.vae_filename is not None:
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
if not args.full_pipeline and not args.upsample_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
if args.audio_vae or args.full_pipeline:
if args.audio_vae_filename is not None:
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
elif combined_ckpt is not None:
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
if not args.full_pipeline:
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
if args.dit or args.full_pipeline:
if args.dit_filename is not None:
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
if not args.full_pipeline:
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.connectors or args.full_pipeline:
if args.dit_filename is not None:
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
if not args.full_pipeline:
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
if args.vocoder or args.full_pipeline:
if args.vocoder_filename is not None:
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
elif combined_ckpt is not None:
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
if not args.full_pipeline:
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.text_encoder or args.full_pipeline:
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
if not args.full_pipeline:
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
original_latent_upsampler_ckpt,
latent_upsampler_config,
dtype=vae_dtype,
)
if not args.full_pipeline and not args.upsample_pipeline:
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
if args.full_pipeline:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTX2Pipeline(
scheduler=scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.upsample_pipeline:
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
pipe.save_pretrained(
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
)
if __name__ == "__main__":
args = get_args()
main(args)

View File

@@ -274,7 +274,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -23,7 +23,6 @@ from .utils import (
is_torchao_available,
is_torchsde_available,
is_transformers_available,
is_transformers_version,
)
@@ -194,8 +193,6 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -226,7 +223,6 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -239,8 +235,6 @@ else:
"Kandinsky3UNet",
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
"LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -284,7 +278,6 @@ else:
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"ZImageControlNetModel",
"ZImageTransformer2DModel",
"attention_backend",
]
@@ -358,7 +351,6 @@ else:
"KDPM2AncestralDiscreteScheduler",
"KDPM2DiscreteScheduler",
"LCMScheduler",
"LTXEulerAncestralRFScheduler",
"PNDMScheduler",
"RePaintScheduler",
"SASolverScheduler",
@@ -413,9 +405,6 @@ else:
_import_structure["modular_pipelines"].extend(
[
"Flux2AutoBlocks",
"Flux2KleinAutoBlocks",
"Flux2KleinBaseAutoBlocks",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
@@ -426,8 +415,6 @@ else:
"QwenImageEditModularPipeline",
"QwenImageEditPlusAutoBlocks",
"QwenImageEditPlusModularPipeline",
"QwenImageLayeredAutoBlocks",
"QwenImageLayeredModularPipeline",
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
@@ -460,11 +447,9 @@ else:
"AuraFlowPipeline",
"BlipDiffusionControlNetPipeline",
"BlipDiffusionPipeline",
"BriaFiboEditPipeline",
"BriaFiboPipeline",
"BriaPipeline",
"ChromaImg2ImgPipeline",
"ChromaInpaintPipeline",
"ChromaPipeline",
"ChronoEditPipeline",
"CLIPImageProjection",
@@ -476,7 +461,6 @@ else:
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
@@ -485,7 +469,6 @@ else:
"EasyAnimateControlPipeline",
"EasyAnimateInpaintPipeline",
"EasyAnimatePipeline",
"Flux2KleinPipeline",
"Flux2Pipeline",
"FluxControlImg2ImgPipeline",
"FluxControlInpaintPipeline",
@@ -500,7 +483,6 @@ else:
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImagePipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -550,13 +532,7 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXI2VLongMultiPromptPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
@@ -585,7 +561,6 @@ else:
"QwenImageEditPlusPipeline",
"QwenImageImg2ImgPipeline",
"QwenImageInpaintPipeline",
"QwenImageLayeredPipeline",
"QwenImagePipeline",
"ReduxImageEncoder",
"SanaControlNetPipeline",
@@ -691,10 +666,7 @@ else:
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"ZImageControlNetInpaintPipeline",
"ZImageControlNetPipeline",
"ZImageImg2ImgPipeline",
"ZImageOmniPipeline",
"ZImagePipeline",
]
)
@@ -956,8 +928,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -988,7 +958,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -1001,8 +970,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky3UNet,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1045,7 +1012,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageControlNetModel,
ZImageTransformer2DModel,
attention_backend,
)
@@ -1111,7 +1077,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LCMScheduler,
LTXEulerAncestralRFScheduler,
PNDMScheduler,
RePaintScheduler,
SASolverScheduler,
@@ -1149,9 +1114,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .modular_pipelines import (
Flux2AutoBlocks,
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
Flux2KleinModularPipeline,
Flux2ModularPipeline,
FluxAutoBlocks,
FluxKontextAutoBlocks,
@@ -1162,8 +1124,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
QwenImageLayeredAutoBlocks,
QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
@@ -1192,11 +1152,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AudioLDM2UNet2DConditionModel,
AudioLDMPipeline,
AuraFlowPipeline,
BriaFiboEditPipeline,
BriaFiboPipeline,
BriaPipeline,
ChromaImg2ImgPipeline,
ChromaInpaintPipeline,
ChromaPipeline,
ChronoEditPipeline,
CLIPImageProjection,
@@ -1208,7 +1166,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
@@ -1217,7 +1174,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateControlPipeline,
EasyAnimateInpaintPipeline,
EasyAnimatePipeline,
Flux2KleinPipeline,
Flux2Pipeline,
FluxControlImg2ImgPipeline,
FluxControlInpaintPipeline,
@@ -1232,7 +1188,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
GlmImagePipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
@@ -1282,13 +1237,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
@@ -1317,7 +1266,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageEditPlusPipeline,
QwenImageImg2ImgPipeline,
QwenImageInpaintPipeline,
QwenImageLayeredPipeline,
QwenImagePipeline,
ReduxImageEncoder,
SanaControlNetPipeline,
@@ -1421,10 +1369,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)

View File

@@ -25,7 +25,6 @@ if is_torch_available():
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .magnitude_aware_guidance import MagnitudeAwareGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance

View File

@@ -1,159 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class MagnitudeAwareGuidance(BaseGuidance):
"""
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
Args:
guidance_scale (`float`, defaults to `10.0`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
alpha (`float`, defaults to `8.0`):
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
guidance scale when the magnitude of the guidance update is large.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 10.0,
alpha: float = 8.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.alpha = alpha
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if not self._is_mambo_g_enabled():
pred = pred_cond
else:
pred = mambo_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.alpha,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_mambo_g_enabled():
num_conditions += 1
return num_conditions
def _is_mambo_g_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def mambo_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
alpha: float = 8.0,
use_original_formulation: bool = False,
):
dim = list(range(1, len(pred_cond.shape)))
diff = pred_cond - pred_uncond
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
guidance_scale_final = (
guidance_scale * torch.exp(-alpha * ratio)
if use_original_formulation
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
)
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale_final * diff
return pred

View File

@@ -67,7 +67,6 @@ if is_torch_available():
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTX2LoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
@@ -122,7 +121,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTX2LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,

View File

@@ -2140,54 +2140,6 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
return converted_state_dict
def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
# Remove the prefix
state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
if non_diffusers_prefix == "diffusion_model":
rename_dict = {
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
"q_norm": "norm_q",
"k_norm": "norm_k",
}
else:
rename_dict = {"aggregate_embed": "text_proj_in"}
# Apply renaming
renamed_state_dict = {}
for key, value in converted_state_dict.items():
new_key = key[:]
for old_pattern, new_pattern in rename_dict.items():
new_key = new_key.replace(old_pattern, new_pattern)
renamed_state_dict[new_key] = value
# Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
final_state_dict = {}
for key, value in renamed_state_dict.items():
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
final_state_dict[new_key] = value
elif key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
final_state_dict[new_key] = value
else:
final_state_dict[key] = value
# Add transformer prefix
prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}
return final_state_dict
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:

View File

@@ -48,7 +48,6 @@ from .lora_conversion_utils import (
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltx2_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
@@ -75,7 +74,6 @@ logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LTX2_CONNECTOR_NAME = "connectors"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
@@ -214,7 +212,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
@@ -641,7 +639,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
@@ -1081,7 +1079,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -1377,7 +1375,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -1659,7 +1657,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
if not (has_lora_keys or has_norm_keys):
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
transformer_lora_state_dict = {
k: state_dict.get(k)
@@ -2506,7 +2504,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -2703,7 +2701,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -2906,7 +2904,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3013,233 +3011,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class LTX2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
"""
_lora_loadable_modules = ["transformer", "connectors"]
transformer_name = TRANSFORMER_NAME
connectors_name = LTX2_CONNECTOR_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
final_state_dict = state_dict
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
if is_non_diffusers_format:
final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
if has_connector:
connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
state_dict, "text_embedding_projection"
)
final_state_dict.update(connectors_state_dict)
out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
return out
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
transformer_peft_state_dict = {
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
}
connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}
self.load_lora_into_transformer(
transformer_peft_state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if connectors_peft_state_dict:
self.load_lora_into_transformer(
connectors_peft_state_dict,
transformer=getattr(self, self.connectors_name)
if not hasattr(self, "connectors")
else self.connectors,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
prefix=self.connectors_name,
)
@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
prefix: str = "transformer",
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {prefix}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
prefix=prefix,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
class SanaLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
@@ -3333,7 +3104,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3536,7 +3307,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3740,7 +3511,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3940,7 +3711,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -4194,7 +3965,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4471,7 +4242,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4691,7 +4462,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -4894,7 +4665,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -5100,7 +4871,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -5306,7 +5077,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -5509,7 +5280,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,

View File

@@ -63,12 +63,9 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}
@@ -478,7 +475,7 @@ class PeftAdapterMixin:
Args:
adapter_names (`List[str]` or `str`):
The names of the adapters to use.
weights (`Union[List[float], float]`, *optional*):
adapter_weights (`Union[List[float], float]`, *optional*):
The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
adapters.
@@ -495,7 +492,7 @@ class PeftAdapterMixin:
"jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
)
pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipeline.unet.set_adapters(["cinematic", "pixel"], weights=[0.5, 0.5])
pipeline.unet.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
```
"""
if not USE_PEFT_BACKEND:

View File

@@ -40,9 +40,6 @@ from .single_file_utils import (
convert_hunyuan_video_transformer_to_diffusers,
convert_ldm_unet_checkpoint,
convert_ldm_vae_checkpoint,
convert_ltx2_audio_vae_to_diffusers,
convert_ltx2_transformer_to_diffusers,
convert_ltx2_vae_to_diffusers,
convert_ltx_transformer_checkpoint_to_diffusers,
convert_ltx_vae_checkpoint_to_diffusers,
convert_lumina2_to_diffusers,
@@ -52,7 +49,6 @@ from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
convert_z_image_controlnet_checkpoint_to_diffusers,
convert_z_image_transformer_checkpoint_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
@@ -152,10 +148,6 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanAnimateTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
@@ -169,7 +161,7 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"default_subfolder": "transformer",
},
"QwenImageTransformer2DModel": {
"checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
"Flux2Transformer2DModel": {
@@ -180,30 +172,11 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"ZImageControlNetModel": {
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
},
"LTX2VideoTransformer3DModel": {
"checkpoint_mapping_fn": convert_ltx2_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLLTX2Video": {
"checkpoint_mapping_fn": convert_ltx2_vae_to_diffusers,
"default_subfolder": "vae",
},
"AutoencoderKLLTX2Audio": {
"checkpoint_mapping_fn": convert_ltx2_audio_vae_to_diffusers,
"default_subfolder": "audio_vae",
},
}
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
model_state_dict_keys = set(model_state_dict.keys())
checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
is_match = model_state_dict_keys == checkpoint_state_dict_keys
return not (is_subset and is_match)
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
def _get_single_file_loadable_mapping_class(cls):

View File

@@ -112,8 +112,7 @@ CHECKPOINT_KEY_NAMES = {
"model.diffusion_model.transformer_blocks.27.scale_shift_table",
"patchify_proj.weight",
"transformer_blocks.27.scale_shift_table",
"vae.decoder.last_scale_shift_table", # 0.9.1, 0.9.5, 0.9.7, 0.9.8
"vae.decoder.up_blocks.9.res_blocks.0.conv1.conv.weight", # 0.9.0
"vae.per_channel_statistics.mean-of-means",
],
"autoencoder-dc": "decoder.stages.1.op_list.0.main.conv.conv.bias",
"autoencoder-dc-sana": "encoder.project_in.conv.bias",
@@ -121,12 +120,7 @@ CHECKPOINT_KEY_NAMES = {
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"z-image-turbo": [
"model.diffusion_model.layers.0.adaLN_modulation.0.weight",
"layers.0.adaLN_modulation.0.weight",
],
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
"z-image-turbo": "cap_embedder.0.weight",
"sana": [
"blocks.0.cross_attn.q_linear.weight",
"blocks.0.cross_attn.q_linear.bias",
@@ -136,7 +130,6 @@ CHECKPOINT_KEY_NAMES = {
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"wan_vace": "vace_blocks.0.after_proj.bias",
"wan_animate": "motion_encoder.dec.direction.weight",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
"cosmos-1.0": [
"net.x_embedder.proj.1.weight",
@@ -149,11 +142,6 @@ CHECKPOINT_KEY_NAMES = {
"net.pos_embedder.dim_spatial_range",
],
"flux2": ["model.diffusion_model.single_stream_modulation.lin.weight", "single_stream_modulation.lin.weight"],
"ltx2": [
"model.diffusion_model.av_ca_a2v_gate_adaln_single.emb.timestep_embedder.linear_1.weight",
"vae.per_channel_statistics.mean-of-means",
"audio_vae.per_channel_statistics.mean-of-means",
],
}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -220,7 +208,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
@@ -233,10 +220,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
"ltx2-dev": {"pretrained_model_name_or_path": "Lightricks/LTX-2"},
}
# Use to configure model sample size when original config is provided
@@ -740,7 +723,10 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "instruct-pix2pix"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
elif (
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
):
model_type = "z-image-turbo"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
@@ -761,9 +747,6 @@ def infer_diffusers_model_type(checkpoint):
elif checkpoint[target_key].shape[0] == 5120:
model_type = "wan-vace-14B"
if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
model_type = "wan-animate-14B"
elif checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
@@ -796,21 +779,6 @@ def infer_diffusers_model_type(checkpoint):
else:
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
before_proj_weight = checkpoint.get("control_noise_refiner.0.before_proj.weight", None)
if before_proj_weight is None:
model_type = "z-image-turbo-controlnet-2.0"
elif before_proj_weight is not None and torch.all(before_proj_weight == 0.0):
model_type = "z-image-turbo-controlnet-2.0"
else:
model_type = "z-image-turbo-controlnet-2.1"
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
model_type = "z-image-turbo-controlnet"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx2"]):
model_type = "ltx2-dev"
else:
model_type = "v1"
@@ -3159,64 +3127,13 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
def generate_motion_encoder_mappings():
mappings = {
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
"motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
"motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
"motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
"motion_encoder.enc.fc": "motion_encoder.motion_network",
}
for i in range(7):
conv_idx = i + 1
mappings.update(
{
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight",
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias",
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight",
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias",
f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight",
}
)
return mappings
def generate_face_adapter_mappings():
return {
"face_adapter.fuser_blocks": "face_adapter",
".k_norm.": ".norm_k.",
".q_norm.": ".norm_q.",
".linear1_q.": ".to_q.",
".linear2.": ".to_out.",
"conv1_local.conv": "conv1_local",
"conv2.conv": "conv2",
"conv3.conv": "conv3",
}
def split_tensor_handler(key, state_dict, split_pattern, target_keys):
tensor = state_dict.pop(key)
split_idx = tensor.shape[0] // 2
new_key_1 = key.replace(split_pattern, target_keys[0])
new_key_2 = key.replace(split_pattern, target_keys[1])
state_dict[new_key_1] = tensor[:split_idx]
state_dict[new_key_2] = tensor[split_idx:]
def reshape_bias_handler(key, state_dict):
if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
state_dict[key] = state_dict[key][0, :, 0, 0]
converted_state_dict = {}
# Strip model.diffusion_model prefix
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
# Base transformer mappings
TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
@@ -3238,42 +3155,27 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# I2V model
# For the I2V model
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# VACE model
# For the VACE model
"before_proj": "proj_in",
"after_proj": "proj_out",
}
SPECIAL_KEYS_HANDLERS = {}
if any("face_adapter" in k for k in checkpoint.keys()):
TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings())
SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."])
if any("motion_encoder" in k for k in checkpoint.keys()):
TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings())
for key in list(checkpoint.keys()):
reshape_bias_handler(key, checkpoint)
for key in list(checkpoint.keys()):
new_key = key
new_key = key[:]
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
converted_state_dict[new_key] = checkpoint.pop(key)
for key in list(converted_state_dict.keys()):
for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items():
if pattern not in key:
continue
handler_fn(key, converted_state_dict, pattern, target_keys)
break
converted_state_dict[new_key] = checkpoint.pop(key)
return converted_state_dict
@@ -3940,7 +3842,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
".attention.k_norm.weight": ".attention.norm_k.weight",
".attention.q_norm.weight": ".attention.norm_q.weight",
".attention.out.weight": ".attention.to_out.0.weight",
"model.diffusion_model.": "",
}
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
@@ -3975,9 +3876,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
update_state_dict(converted_state_dict, key, new_key)
if "norm_final.weight" in converted_state_dict.keys():
_ = converted_state_dict.pop("norm_final.weight")
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
@@ -3987,175 +3885,3 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
if config["add_control_noise_refiner"] is None:
return checkpoint
elif config["add_control_noise_refiner"] == "control_noise_refiner":
return checkpoint
elif config["add_control_noise_refiner"] == "control_layers":
converted_state_dict = {
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
}
return converted_state_dict
else:
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
def convert_ltx2_transformer_to_diffusers(checkpoint, **kwargs):
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Transformer prefix
"model.diffusion_model.": "",
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulation Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict) -> None:
# Skip if not a weight, bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_ltx2_vae_to_diffusers(checkpoint, **kwargs):
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Video VAE prefix
"vae.": "",
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict) -> None:
state_dict.pop(key)
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in LTX_2_0_VIDEO_VAE_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(converted_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
for special_key, handler_fn_inplace in LTX_2_0_VAE_SPECIAL_KEYS_REMAP.items():
if special_key not in key:
continue
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_ltx2_audio_vae_to_diffusers(checkpoint, **kwargs):
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
# Audio VAE prefix
"audio_vae.": "",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
def update_state_dict_inplace(state_dict, old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
# Handle official code --> diffusers key remapping via the remap dict
for key in list(converted_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in LTX_2_0_AUDIO_VAE_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(converted_state_dict, key, new_key)
return converted_state_dict

View File

@@ -41,8 +41,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -68,7 +66,6 @@ if is_torch_available():
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"]
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
_import_structure["embeddings"] = ["ImageProjection"]
@@ -98,16 +95,13 @@ if is_torch_available():
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video_framepack"] = ["HunyuanVideoFramepackTransformer3DModel"]
_import_structure["transformers.transformer_hunyuanimage"] = ["HunyuanImageTransformer2DModel"]
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -157,8 +151,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -188,7 +180,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3MultiControlNetModel,
SparseControlNetModel,
UNetControlNetXSModel,
ZImageControlNetModel,
)
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
@@ -209,7 +200,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
@@ -218,8 +208,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoTransformer3DModel,
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,

View File

@@ -90,6 +90,10 @@ class ContextParallelConfig:
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."

View File

@@ -235,10 +235,6 @@ class _AttentionBackendRegistry:
def get_active_backend(cls):
return cls._active_backend, cls._backends[cls._active_backend]
@classmethod
def set_active_backend(cls, backend: str):
cls._active_backend = backend
@classmethod
def list_backends(cls):
return list(cls._backends.keys())
@@ -298,12 +294,12 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
_maybe_download_kernel_for_backend(backend)
old_backend = _AttentionBackendRegistry._active_backend
_AttentionBackendRegistry.set_active_backend(backend)
_AttentionBackendRegistry._active_backend = backend
try:
yield
finally:
_AttentionBackendRegistry.set_active_backend(old_backend)
_AttentionBackendRegistry._active_backend = old_backend
def dispatch_attention_fn(
@@ -352,7 +348,6 @@ def dispatch_attention_fn(
check(**kwargs)
kwargs = {k: v for k, v in kwargs.items() if k in _AttentionBackendRegistry._supported_arg_names[backend_name]}
return backend_fn(**kwargs)
@@ -1111,51 +1106,6 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
return out
# Not implemented yet.
def _npu_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
# ===== Context parallel =====
@@ -1182,103 +1132,6 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
return x
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
"""
Perform dimension sharding / reassembly across processes using _all_to_all_single.
This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or
head dimension flexibly by accepting scatter_idx and gather_idx.
Args:
x (torch.Tensor):
Input tensor. Expected shapes:
- When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
- When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
scatter_idx (int) :
Dimension along which the tensor is partitioned before all-to-all.
gather_idx (int):
Dimension along which the output is reassembled after all-to-all.
group :
Distributed process group for the Ulysses group.
Returns:
torch.Tensor: Tensor with globally exchanged dimensions.
- For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
- For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
"""
group_world_size = torch.distributed.get_world_size(group)
if scatter_idx == 2 and gather_idx == 1:
# Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
# dimension and scatters head dimension
batch_size, seq_len_local, num_heads, head_dim = x.shape
seq_len = seq_len_local * group_world_size
num_heads_local = num_heads // group_world_size
# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
x_temp = (
x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
.transpose(0, 2)
.contiguous()
)
if group_world_size > 1:
out = _all_to_all_single(x_temp, group=group)
else:
out = x_temp
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
return out
elif scatter_idx == 1 and gather_idx == 2:
# Used after ulysses sequence parallel in unified SP. gathers the head dimension
# scatters back the sequence dimension.
batch_size, seq_len, num_heads_local, head_dim = x.shape
num_heads = num_heads_local * group_world_size
seq_len_local = seq_len // group_world_size
# B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
x_temp = (
x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
.permute(1, 3, 2, 0, 4)
.reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
)
if group_world_size > 1:
output = _all_to_all_single(x_temp, group)
else:
output = x_temp
output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
return output
else:
raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
class SeqAllToAllDim(torch.autograd.Function):
"""
all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange
for more info.
"""
@staticmethod
def forward(ctx, group, input, scatter_id=2, gather_id=1):
ctx.group = group
ctx.scatter_id = scatter_id
ctx.gather_id = gather_id
return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
@staticmethod
def backward(ctx, grad_outputs):
grad_input = SeqAllToAllDim.apply(
ctx.group,
grad_outputs,
ctx.gather_id, # reversed
ctx.scatter_id, # reversed
)
return (None, grad_input, None, None)
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
@@ -1339,10 +1192,7 @@ class TemplatedRingAttention(torch.autograd.Function):
out = out.to(torch.float32)
lse = lse.to(torch.float32)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1)
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1403,7 +1253,7 @@ class TemplatedRingAttention(torch.autograd.Function):
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1498,69 +1348,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
def _templated_unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
scatter_idx: int = 2,
gather_idx: int = 1,
):
"""
Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
"""
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
ulysses_group = ulysses_mesh.get_group()
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
out = TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
if return_lse:
context_layer, lse, *_ = out
else:
context_layer = out
# context_layer is of shape (B, S, H_LOCAL, D)
output = SeqAllToAllDim.apply(
ulysses_group,
context_layer,
gather_idx,
scatter_idx,
)
if return_lse:
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)
return (output, lse)
return output
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
def _templated_context_parallel_attention(
@@ -1578,31 +1366,15 @@ def _templated_context_parallel_attention(
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
):
if attn_mask is not None:
raise ValueError("Attention mask is not yet supported for templated attention.")
if is_causal:
raise ValueError("Causal attention is not yet supported for templated attention.")
if enable_gqa:
raise ValueError("GQA is not yet supported for templated attention.")
# TODO: add support for unified attention with ring/ulysses degree both being > 1
if (
_parallel_config.context_parallel_config.ring_degree > 1
and _parallel_config.context_parallel_config.ulysses_degree > 1
):
return _templated_unified_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ring_degree > 1:
if _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
@@ -1648,7 +1420,6 @@ def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1656,9 +1427,6 @@ def _flash_attention(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
if _parallel_config is None:
out = flash_attn_func(
q=query,
@@ -1701,7 +1469,6 @@ def _flash_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1709,9 +1476,6 @@ def _flash_attention_hub(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
@@ -1848,15 +1612,11 @@ def _flash_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
@@ -1876,7 +1636,6 @@ def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
@@ -1887,8 +1646,6 @@ def _flash_attention_3_hub(
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
@@ -2028,16 +1785,12 @@ def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for aiter attention")
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
@@ -2128,43 +1881,6 @@ def _native_flex_attention(
return out
def _prepare_additive_attn_mask(
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
) -> torch.Tensor:
"""
Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
Args:
attn_mask: 2D tensor [batch_size, seq_len_k]
- Boolean: True means attend, False means mask out
- Additive: 0.0 means attend, -inf means mask out
target_dtype: The dtype to convert the mask to (usually query.dtype)
reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
Returns:
Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
"""
# Check if the mask is boolean or already additive
if attn_mask.dtype == torch.bool:
# Convert boolean to additive: True -> 0.0, False -> -inf
attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
# Convert to target dtype
attn_mask = attn_mask.to(dtype=target_dtype)
else:
# Already additive mask - just ensure correct dtype
attn_mask = attn_mask.to(dtype=target_dtype)
# Optionally reshape to 4D for broadcasting in attention mechanisms
if reshape_4d:
batch_size, seq_len_k = attn_mask.shape
attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
return attn_mask
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
@@ -2184,19 +1900,6 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
# Reshape 2D mask to 4D for SDPA
# SDPA accepts both boolean masks (torch.bool) and additive masks (float)
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
# Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
# SDPA handles both boolean and additive masks correctly
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
if _parallel_config is None:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
@@ -2325,7 +2028,6 @@ def _native_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -2333,9 +2035,6 @@ def _native_flash_attention(
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for aiter attention")
lse = None
if _parallel_config is None and not return_lse:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2409,52 +2108,34 @@ def _native_math_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_NPU,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _native_npu_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
None,
scale,
None,
return_lse,
forward_op=_npu_attention_forward_op,
backward_op=_npu_attention_backward_op,
_parallel_config=_parallel_config,
)
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
return out
@@ -2467,13 +2148,10 @@ def _native_xla_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for XLA attention")
if return_lse:
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2497,14 +2175,11 @@ def _sage_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
if _parallel_config is None:
out = sageattn(
@@ -2548,14 +2223,11 @@ def _sage_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
if _parallel_config is None:
@@ -2637,14 +2309,11 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
k=key,
@@ -2664,14 +2333,11 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
k=key,
@@ -2691,14 +2357,11 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
k=key,
@@ -2718,14 +2381,11 @@ def _sage_qk_int8_pv_fp16_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
@@ -2763,34 +2423,10 @@ def _xformers_attention(
attn_mask = xops.LowerTriangularMask()
elif attn_mask is not None:
if attn_mask.ndim == 2:
# Convert 2D mask to 4D for xformers
# Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
# xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
# Need memory alignment - create larger tensor and slice for alignment
original_seq_len = attn_mask.size(1)
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
# Create aligned 4D tensor and slice to ensure proper memory layout
aligned_mask = torch.zeros(
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
dtype=query.dtype,
device=query.device,
)
# Convert to 4D additive mask (handles both boolean and additive inputs)
mask_additive = _prepare_additive_attn_mask(
attn_mask, target_dtype=query.dtype
) # [batch, 1, 1, seq_len_k]
# Broadcast to [batch, heads, seq_q, seq_len_k]
aligned_mask[:, :, :, :original_seq_len] = mask_additive
# Mask out the padding (already -inf from zeros -> where with default)
aligned_mask[:, :, :, original_seq_len:] = float("-inf")
# Slice to actual size with proper alignment
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
elif attn_mask.ndim != 4:
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
elif attn_mask.ndim == 4:
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
if enable_gqa:
if num_heads_q % num_heads_kv != 0:

View File

@@ -10,8 +10,6 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

View File

@@ -102,14 +102,14 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_multiscales: Tuple[int, ...] = (),
qkv_mutliscales: Tuple[int, ...] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
)
else:
@@ -247,7 +247,7 @@ class Encoder(nn.Module):
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
qkv_multiscales=qkv_multiscales[i],
qkv_mutliscales=qkv_multiscales[i],
)
down_block_list.append(block)
@@ -339,7 +339,7 @@ class Decoder(nn.Module):
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
qkv_multiscales=qkv_multiscales[i],
qkv_mutliscales=qkv_multiscales[i],
)
up_block_list.append(block)

Some files were not shown because too many files have changed in this diff Show More