Compare commits


6 Commits

Author               SHA1        Message                                   Date
sayakpaul            49e39f8a4e  make fix-copies                           2025-12-17 16:01:07 +05:30
Sayak Paul           02851fade5  Merge branch 'main' into main             2025-12-17 18:26:09 +08:00
github-actions[bot]  c0b58138e5  Apply style fixes                         2025-12-17 10:25:33 +00:00
naykun               f3c62427cb  [qwen-image] fix pr comments              2025-12-17 18:14:05 +08:00
naykun               e630e6eca4  [qwen-image] update doc                   2025-12-17 15:01:58 +08:00
naykun               a89f13eff4  [qwen-image] qwen image layered support   2025-12-17 15:01:57 +08:00
334 changed files with 5508 additions and 29622 deletions

View File

@@ -28,7 +28,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -58,7 +58,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}

View File

@@ -28,7 +28,7 @@ jobs:
uses: docker/setup-buildx-action@v1
- name: Check out code
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Find Changed Dockerfiles
id: file_changes
@@ -99,7 +99,7 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub

View File

@@ -17,10 +17,10 @@ jobs:
steps:
- name: Checkout repository
uses: actions/checkout@v6
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v5
with:
python-version: '3.10'

View File

@@ -1,22 +0,0 @@
---
name: CodeQL Security Analysis For Github Actions
on:
push:
branches: ["main"]
workflow_dispatch:
# pull_request:
jobs:
codeql:
name: CodeQL Analysis
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
permissions:
security-events: write
packages: read
actions: read
contents: read
with:
languages: '["actions","python"]'
queries: 'security-extended,security-and-quality'
runner: 'ubuntu-latest' #optional if need custom runner

View File

@@ -24,6 +24,7 @@ jobs:
mirror_community_pipeline:
env:
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
runs-on: ubuntu-22.04
steps:
# Checkout to correct ref
@@ -38,41 +39,37 @@ jobs:
# If ref is 'refs/heads/main' => set 'main'
# Else it must be a tag => set {tag}
- name: Set checkout_ref and path_in_repo
env:
EVENT_NAME: ${{ github.event_name }}
EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
GITHUB_REF: ${{ github.ref }}
run: |
if [ "$EVENT_NAME" == "workflow_dispatch" ]; then
if [ -z "$EVENT_INPUT_REF" ]; then
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
if [ -z "${{ github.event.inputs.ref }}" ]; then
echo "Error: Missing ref input"
exit 1
elif [ "$EVENT_INPUT_REF" == "main" ]; then
elif [ "${{ github.event.inputs.ref }}" == "main" ]; then
echo "CHECKOUT_REF=refs/heads/main" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
echo "CHECKOUT_REF=refs/tags/$EVENT_INPUT_REF" >> $GITHUB_ENV
echo "PATH_IN_REPO=$EVENT_INPUT_REF" >> $GITHUB_ENV
echo "CHECKOUT_REF=refs/tags/${{ github.event.inputs.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=${{ github.event.inputs.ref }}" >> $GITHUB_ENV
fi
elif [ "$GITHUB_REF" == "refs/heads/main" ]; then
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
else
# e.g. refs/tags/v0.28.1 -> v0.28.1
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
echo "PATH_IN_REPO=$(echo $GITHUB_REF | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
fi
- name: Print env vars
run: |
echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
- uses: actions/checkout@v6
- uses: actions/checkout@v3
with:
ref: ${{ env.CHECKOUT_REF }}
# Setup + install dependencies
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
@@ -102,4 +99,4 @@ jobs:
- name: Report failure status
if: ${{ failure() }}
run: |
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure

View File

@@ -28,7 +28,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -97,7 +97,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -119,7 +119,7 @@ jobs:
module: [models, schedulers, lora, others, single_file, examples]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -167,7 +167,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_${{ matrix.module }}_cuda_test_reports
path: reports
@@ -184,7 +184,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
path: reports
@@ -228,7 +228,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -263,7 +263,7 @@ jobs:
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_big_gpu_test_reports
path: reports
@@ -280,7 +280,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -321,7 +321,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -355,7 +355,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -391,7 +391,7 @@ jobs:
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_${{ matrix.config.backend }}_reports
path: reports
@@ -408,7 +408,7 @@ jobs:
options: --shm-size "20gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -441,7 +441,7 @@ jobs:
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_pipeline_level_quant_reports
path: reports
@@ -466,7 +466,7 @@ jobs:
image: diffusers/diffusers-pytorch-cpu
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -474,7 +474,7 @@ jobs:
run: mkdir -p combined_reports
- name: Download all test reports
uses: actions/download-artifact@v7
uses: actions/download-artifact@v4
with:
path: artifacts
@@ -500,7 +500,7 @@ jobs:
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
- name: Upload consolidated report
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: consolidated_test_report
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
@@ -514,7 +514,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v6
# uses: actions/checkout@v3
# with:
# fetch-depth: 2
#
@@ -554,7 +554,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v6
# uses: actions/upload-artifact@v4
# with:
# name: torch_mps_test_reports
# path: reports
@@ -570,7 +570,7 @@ jobs:
#
# steps:
# - name: Checkout diffusers
# uses: actions/checkout@v6
# uses: actions/checkout@v3
# with:
# fetch-depth: 2
#
@@ -610,7 +610,7 @@ jobs:
#
# - name: Test suite reports artifacts
# if: ${{ always() }}
# uses: actions/upload-artifact@v6
# uses: actions/upload-artifact@v4
# with:
# name: torch_mps_test_reports
# path: reports

View File

@@ -10,10 +10,10 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.8'

View File

@@ -18,9 +18,9 @@ jobs:
check_dependencies:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies

View File

@@ -1,4 +1,3 @@
name: Fast PR tests for Modular
on:
@@ -36,9 +35,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
@@ -56,9 +55,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.10"
- name: Install dependencies
@@ -78,13 +77,23 @@ jobs:
run_fast_tests:
needs: [check_code_quality, check_repository_consistency]
name: Fast PyTorch Modular Pipeline CPU tests
strategy:
fail-fast: false
matrix:
config:
- name: Fast PyTorch Modular Pipeline CPU tests
framework: pytorch_pipelines
runner: aws-highmemory-32-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_cpu_modular_pipelines
name: ${{ matrix.config.name }}
runs-on:
group: aws-highmemory-32-plus
group: ${{ matrix.config.runner }}
container:
image: diffusers/diffusers-pytorch-cpu
image: ${{ matrix.config.image }}
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
defaults:
@@ -93,7 +102,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -109,19 +118,22 @@ jobs:
python utils/print_env.py
- name: Run fast PyTorch Pipeline CPU tests
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
run: |
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
-k "not Flax and not Onnx" \
--make-reports=tests_torch_cpu_modular_pipelines \
--make-reports=tests_${{ matrix.config.report }} \
tests/modular_pipelines
- name: Failure short reports
if: ${{ failure() }}
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -28,7 +28,7 @@ jobs:
test_map: ${{ steps.set_matrix.outputs.test_map }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 0
- name: Install dependencies
@@ -42,7 +42,7 @@ jobs:
run: |
python utils/tests_fetcher.py | tee test_preparation.txt
- name: Report fetched tests
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v3
with:
name: test_fetched
path: test_preparation.txt
@@ -83,7 +83,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -109,7 +109,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v3
with:
name: ${{ matrix.modules }}_test_reports
path: reports
@@ -138,7 +138,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -164,7 +164,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -31,9 +31,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -51,9 +51,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -108,7 +108,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -153,7 +153,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
path: reports
@@ -185,7 +185,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -211,7 +211,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports
@@ -236,7 +236,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -273,7 +273,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_main_test_reports
path: reports

View File

@@ -32,9 +32,9 @@ jobs:
check_code_quality:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -52,9 +52,9 @@ jobs:
needs: check_code_quality
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies
@@ -83,7 +83,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -100,7 +100,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -120,7 +120,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -170,7 +170,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -193,7 +193,7 @@ jobs:
module: [models, schedulers, lora, others]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -239,7 +239,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -255,7 +255,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -287,7 +287,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: examples_test_reports
path: reports

View File

@@ -18,9 +18,9 @@ jobs:
check_torch_dependencies:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"
- name: Install dependencies

View File

@@ -29,7 +29,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -46,7 +46,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -66,7 +66,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -98,7 +98,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -120,7 +120,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -155,7 +155,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_test_reports_${{ matrix.module }}
path: reports
@@ -172,7 +172,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -199,7 +199,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
path: reports
@@ -216,7 +216,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -240,7 +240,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_xformers_test_reports
path: reports
@@ -256,7 +256,7 @@ jobs:
options: --gpus all --shm-size "16gb" --ipc host
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -286,7 +286,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: examples_test_reports
path: reports

View File

@@ -54,7 +54,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -88,7 +88,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_${{ matrix.config.report }}_test_reports
path: reports

View File

@@ -23,7 +23,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -65,7 +65,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pr_torch_mps_test_reports
path: reports

View File

@@ -15,10 +15,10 @@ jobs:
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
steps:
- name: Checkout Repo
uses: actions/checkout@v6
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: '3.8'
@@ -40,12 +40,12 @@ jobs:
steps:
- name: Checkout Repo
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
- name: Setup Python
uses: actions/setup-python@v6
uses: actions/setup-python@v4
with:
python-version: "3.8"

View File

@@ -27,7 +27,7 @@ jobs:
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
- name: Pipeline Tests Artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: test-pipelines.json
path: reports
@@ -64,7 +64,7 @@ jobs:
options: --shm-size "16gb" --ipc host --gpus all
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
- name: NVIDIA-SMI
@@ -94,7 +94,7 @@ jobs:
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: pipeline_${{ matrix.module }}_test_reports
path: reports
@@ -116,7 +116,7 @@ jobs:
module: [models, schedulers, lora, others, single_file]
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -149,7 +149,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_cuda_${{ matrix.module }}_test_reports
path: reports
@@ -166,7 +166,7 @@ jobs:
shell: bash
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -205,7 +205,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_minimum_version_cuda_test_reports
path: reports
@@ -222,7 +222,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -247,7 +247,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_compile_test_reports
path: reports
@@ -264,7 +264,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -288,7 +288,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: torch_xformers_test_reports
path: reports
@@ -305,7 +305,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2
@@ -336,7 +336,7 @@ jobs:
- name: Test suite reports artifacts
if: ${{ always() }}
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v4
with:
name: examples_test_reports
path: reports

View File

@@ -57,7 +57,7 @@ jobs:
shell: bash -e {0}
- name: Checkout PR branch
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
ref: refs/pull/${{ inputs.pr_number }}/head

View File

@@ -27,7 +27,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2

View File

@@ -35,7 +35,7 @@ jobs:
steps:
- name: Checkout diffusers
uses: actions/checkout@v6
uses: actions/checkout@v3
with:
fetch-depth: 2

View File

@@ -15,10 +15,10 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v2
- name: Setup Python
uses: actions/setup-python@v6
uses: actions/setup-python@v1
with:
python-version: 3.8

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- name: Checkout code
uses: actions/checkout@v6
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Secret Scanning

View File

@@ -8,7 +8,7 @@ jobs:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: typos-action
uses: crate-ci/typos@v1.12.4

View File

@@ -15,7 +15,7 @@ jobs:
shell: bash -l {0}
steps:
- uses: actions/checkout@v6
- uses: actions/checkout@v3
- name: Setup environment
run: |

View File

@@ -54,8 +54,6 @@
title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- local: hybrid_inference/overview
title: Remote inference
title: Inference
- isExpanded: false
sections:
@@ -90,6 +88,17 @@
title: FreeU
title: Community optimizations
title: Inference optimization
- isExpanded: false
sections:
- local: hybrid_inference/overview
title: Overview
- local: hybrid_inference/vae_decode
title: VAE Decode
- local: hybrid_inference/vae_encode
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
title: Hybrid Inference
- isExpanded: false
sections:
- local: modular_diffusers/overview
@@ -261,8 +270,6 @@
title: Outputs
- local: api/quantization
title: Quantization
- local: hybrid_inference/api_reference
title: Remote inference
- local: api/parallel
title: Parallel inference
title: Main Classes
@@ -346,8 +353,6 @@
title: Flux2Transformer2DModel
- local: api/models/flux_transformer
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageTransformer2DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -362,8 +367,6 @@
title: LatteTransformer3DModel
- local: api/models/longcat_image_transformer2d
title: LongCatImageTransformer2DModel
- local: api/models/ltx2_video_transformer3d
title: LTX2VideoTransformer3DModel
- local: api/models/ltx_video_transformer3d
title: LTXVideoTransformer3DModel
- local: api/models/lumina2_transformer2d
@@ -440,10 +443,6 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
title: AutoencoderKLLTX2Video
- local: api/models/autoencoderkl_ltx_video
title: AutoencoderKLLTXVideo
- local: api/models/autoencoderkl_magvit
@@ -542,8 +541,6 @@
title: Flux2
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/glm_image
title: GLM-Image
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
@@ -681,8 +678,6 @@
title: Kandinsky 5.0 Video
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx2
title: LTX-2
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi

View File

@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
[[autodoc]] apply_faster_cache
## FirstBlockCacheConfig
### FirstBlockCacheConfig
[[autodoc]] FirstBlockCacheConfig
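To illustrate how such a cache config is applied in practice, a minimal sketch (assuming the model-level `enable_cache` helper; the checkpoint and threshold are illustrative, not taken from this diff):

```python
import torch
from diffusers import FluxPipeline, FirstBlockCacheConfig

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Reuse cached outputs of the later blocks whenever the first block's output
# changes by less than the threshold between denoising steps.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))

image = pipe("an astronaut riding a horse on the moon", num_inference_steps=28).images[0]
image.save("first_block_cache.png")
```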

View File

@@ -33,7 +33,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
- [`LTX2LoraLoaderMixin`] provides similar functions for [LTX-2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
> [!TIP]
@@ -63,10 +62,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
## LTX2LoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin
## CogVideoXLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
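These mixins are exposed directly on the corresponding pipelines, so loading a LoRA typically reduces to a couple of calls. A rough sketch (the LoRA repository, weight file, and adapter name below are placeholders):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")

# QwenImageLoraLoaderMixin methods are available on the pipeline itself.
# The repository and weight name are placeholders, not real checkpoints.
pipe.load_lora_weights("your-org/qwen-image-style-lora", weight_name="lora.safetensors", adapter_name="style")
pipe.set_adapters(["style"], adapter_weights=[0.8])

image = pipe("a watercolor painting of a fox", num_inference_steps=30).images[0]
```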

View File

@@ -1,29 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Audio
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTX2Audio
vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Audio
[[autodoc]] AutoencoderKLLTX2Audio
- encode
- decode
- all

View File

@@ -1,29 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# AutoencoderKLLTX2Video
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLLTX2Video
vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```
## AutoencoderKLLTX2Video
[[autodoc]] AutoencoderKLLTX2Video
- decode
- encode
- all

View File

@@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co
## FluxControlNetOutput
[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput
[[autodoc]] models.controlnet_flux.FluxControlNetOutput

View File

@@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr
## SparseControlNetOutput
[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput
[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput

View File

@@ -1,18 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# GlmImageTransformer2DModel
A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel] (TODO).
## GlmImageTransformer2DModel
[[autodoc]] GlmImageTransformer2DModel

View File

@@ -1,26 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# LTX2VideoTransformer3DModel
A Diffusion Transformer model for 3D data introduced by Lightricks in [LTX-2](https://huggingface.co/Lightricks/LTX-2).
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import LTX2VideoTransformer3DModel
transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```
## LTX2VideoTransformer3DModel
[[autodoc]] LTX2VideoTransformer3DModel

View File

@@ -30,10 +30,6 @@
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
Available Models/LoRAs:
- [nvidia/ChronoEdit-14B-Diffusers](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers)
- [nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora)
- [nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora)
### Image Editing
@@ -104,7 +100,6 @@ Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.pn
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
@@ -114,8 +109,9 @@ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encod
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
pipe.fuse_lora(adapter_names=["distill"], lora_scale=1.0)
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
pipe.load_lora_weights(lora_path)
pipe.fuse_lora(lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")
@@ -149,57 +145,6 @@ export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```
### Inference with Multiple LoRAs
```py
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
model_id = "nvidia/ChronoEdit-14B-Diffusers"
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora", weight_name="paintbrush_lora_diffusers.safetensors", adapter_name="paintbrush")
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
pipe.fuse_lora(adapter_names=["paintbrush", "distill"], lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")
image = load_image(
"https://raw.githubusercontent.com/nv-tlabs/ChronoEdit/refs/heads/main/assets/images/input_paintbrush.png"
)
max_area = 720 * 1280
aspect_ratio = image.height / image.width
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
print("width", width, "height", height)
image = image.resize((width, height))
prompt = (
"Turn the pencil sketch in the image into an actual object that is consistent with the images content. The user wants to change the sketch to a crown and a hat."
)
output = pipe(
image=image,
prompt=prompt,
height=height,
width=width,
num_frames=5,
num_inference_steps=8,
guidance_scale=1.0,
enable_temporal_reasoning=False,
num_temporal_reasoning_steps=0,
).frames[0]
export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output_1.png")
```
## ChronoEditPipeline
[[autodoc]] ChronoEditPipeline

View File

@@ -70,12 +70,6 @@ output.save("output.png")
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

View File

@@ -21,7 +21,7 @@ The abstract from the paper is:
*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*
The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).
The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️
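A minimal sketch of that three-step flow with [`StableDiffusionDiffEditPipeline`] (the base checkpoint, input image, and prompts are illustrative):

```python
import torch
from diffusers import StableDiffusionDiffEditPipeline, DDIMScheduler, DDIMInverseScheduler
from diffusers.utils import load_image

pipe = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

raw_image = load_image("path/or/url/to/fruit_bowl.png").resize((768, 768))  # placeholder input image
source_prompt = "a bowl of fruits"
target_prompt = "a bowl of pears"

# 1. Contrast predictions for the two prompts to generate the edit mask automatically.
mask_image = pipe.generate_mask(image=raw_image, source_prompt=source_prompt, target_prompt=target_prompt)
# 2. Invert the input image into latents conditioned on the source prompt.
inv_latents = pipe.invert(prompt=source_prompt, image=raw_image).latents
# 3. Denoise toward the target prompt, editing only inside the generated mask.
image = pipe(
    prompt=target_prompt, mask_image=mask_image, image_latents=inv_latents, negative_prompt=source_prompt
).images[0]
```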

View File

@@ -1,95 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-->
# GLM-Image
## Overview
GLM-Image is an image generation model that adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained detail. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.
Model architecture: a hybrid autoregressive + diffusion decoder design:
+ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands it to 1K-4K tokens, corresponding to 1K-2K high-resolution image outputs. The AR model is available as the `GlmImageForConditionalGeneration` class in the `transformers` library.
+ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images.
Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.
+ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
+ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.
GLM-Image supports both text-to-image and image-to-image generation within a single model
+ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
+ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image).
## Usage examples
### Text to Image Generation
```python
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy."
image = pipe(
prompt=prompt,
height=32 * 32,
width=36 * 32,
num_inference_steps=30,
guidance_scale=1.5,
generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
image.save("output_t2i.png")
```
### Image to Image Generation
```python
import torch
from diffusers.pipelines.glm_image import GlmImagePipeline
from PIL import Image
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
image_path = "cond.jpg"
prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator."
image = Image.open(image_path).convert("RGB")
image = pipe(
prompt=prompt,
image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1]
height=33 * 32,
width=32 * 32,
num_inference_steps=30,
guidance_scale=1.5,
generator=torch.Generator(device="cuda").manual_seed(42),
).images[0]
image.save("output_i2i.png")
```
+ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model.
## GlmImagePipeline
[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline
- all
- __call__
## GlmImagePipelineOutput
[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput

View File

@@ -1,47 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# LTX-2
<div class="flex flex-wrap space-x-1">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</div>
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
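A rough text-to-video sketch under the usual Diffusers conventions (the call arguments and the `frames` output attribute are assumptions here, and the synchronized audio output is not shown):

```python
import torch
from diffusers import LTX2Pipeline
from diffusers.utils import export_to_video

# Assumed loading and calling convention; check the pipeline docstrings for exact arguments.
pipe = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")

out = pipe(
    prompt="waves crashing against a rocky shore at sunset, cinematic",
    num_frames=121,
    num_inference_steps=40,
)
export_to_video(out.frames[0], "ltx2.mp4", fps=24)  # the output object also carries the audio track
```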
## LTX2Pipeline
[[autodoc]] LTX2Pipeline
- all
- __call__
## LTX2ImageToVideoPipeline
[[autodoc]] LTX2ImageToVideoPipeline
- all
- __call__
## LTX2LatentUpsamplePipeline
[[autodoc]] LTX2LatentUpsamplePipeline
- all
- __call__
## LTX2PipelineOutput
[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput

View File

@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)
<details>
<summary>Show example code</summary>
```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
@@ -474,12 +474,6 @@ export_to_video(video, "output.mp4", fps=24)
</details>
## LTXI2VLongMultiPromptPipeline
[[autodoc]] LTXI2VLongMultiPromptPipeline
- all
- __call__
## LTXPipeline
[[autodoc]] LTXPipeline

View File

@@ -95,7 +95,7 @@ image.save("qwen_fewsteps.png")
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
```py
```
import torch
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
@@ -108,46 +108,12 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
image = pipe(
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
image=[image_1, image_2],
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
num_inference_steps=50
).images[0]
```
## Performance
### torch.compile
Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s):
```python
import torch
from diffusers import QwenImagePipeline
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer = torch.compile(pipe.transformer)
# First call triggers compilation (~7s overhead)
# Subsequent calls run at ~2.4x faster
image = pipe("a cat", num_inference_steps=50).images[0]
```
### Batched Inference with Variable-Length Prompts
When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.
```python
# CFG with different prompt lengths works correctly
image = pipe(
prompt="A cat",
negative_prompt="blurry, low quality, distorted",
true_cfg_scale=3.5,
num_inference_steps=50,
).images[0]
```
For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).
## QwenImagePipeline
[[autodoc]] QwenImagePipeline

View File

@@ -37,8 +37,7 @@ The following SkyReels-V2 models are supported in Diffusers:
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
> [!TIP]
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.

View File

@@ -250,6 +250,9 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
</hfoption>
</hfoptions>
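To make that masking convention concrete, here is a small NumPy/PIL sketch that builds such a conditioning mask (the file names and preserved region are placeholders):

```python
import numpy as np
from PIL import Image

frame = Image.open("reference_frame.png").convert("RGB")  # placeholder conditioning frame

# Start fully white (everything is generated), then blacken the region that should
# only be used for conditioning and left untouched by the model.
mask = np.full((frame.height, frame.width), 255, dtype=np.uint8)
mask[: frame.height // 2, :] = 0  # keep the top half of the reference frame

Image.fromarray(mask, mode="L").save("reference_frame_mask.png")
```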
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.

View File

@@ -1,11 +1,9 @@
# Remote inference
# Hybrid Inference API Reference
Remote inference provides access to an [Inference Endpoint](https://huggingface.co/docs/inference-endpoints/index) to offload local generation requirements for decoding and encoding.
## remote_decode
## Remote Decode
[[autodoc]] utils.remote_utils.remote_decode
## remote_encode
## Remote Encode
[[autodoc]] utils.remote_utils.remote_encode

View File

@@ -10,296 +10,51 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# Remote inference
# Hybrid Inference
**Empowering local AI builders with Hybrid Inference**
> [!TIP]
> This is currently an experimental feature, and if you have any feedback, please feel free to leave it [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
Remote inference offloads the decoding and encoding process to a remote endpoint to relax the memory requirements for local inference with large models. This feature is powered by [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index). Refer to the table below for the supported models and endpoint.
| Model | Endpoint | Checkpoint | Support |
|---|---|---|---|
| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) | encode/decode |
| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) | encode/decode |
| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | encode/decode |
| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | decode |
This guide will show you how to encode and decode latents with remote inference.
## Encoding
Encoding converts images and videos into latent representations. Refer to the table below for the supported VAEs.
Pass an image to [`~utils.remote_encode`] to encode it. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
```py
import torch
from diffusers import FluxPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_encode
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.float16,
vae=None,
device_map="cuda"
)
init_image = load_image(
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
init_image = init_image.resize((768, 512))
init_latent = remote_encode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
image=init_image,
scaling_factor=0.3611,
shift_factor=0.1159
)
```
## Decoding
Decoding converts latent representations back into images or videos. Refer to the table below for the available and supported VAEs.
Set the output type to `"latent"` in the pipeline and set the `vae` to `None`. Pass the latents to the [`~utils.remote_decode`] function. For Flux, the latents are packed so the `height` and `width` also need to be passed. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
<hfoptions id="decode">
<hfoption id="Flux">
```py
import torch

from diffusers import FluxPipeline
from diffusers.utils.remote_utils import remote_decode
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
vae=None,
device_map="cuda"
)
prompt = """
A photorealistic Apollo-era photograph of a cat in a small astronaut suit with a bubble helmet, standing on the Moon and holding a flagpole planted in the dusty lunar soil. The flag shows a colorful paw-print emblem. Earth glows in the black sky above the stark gray surface, with sharp shadows and high-contrast lighting like vintage NASA photos.
"""
latent = pipeline(
prompt=prompt,
guidance_scale=0.0,
num_inference_steps=4,
output_type="latent",
).images
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
image.save("image.jpg")
```
</hfoption>
<hfoption id="HunyuanVideo">
```py
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils.remote_utils import remote_decode

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipeline = HunyuanVideoPipeline.from_pretrained(
    model_id, transformer=transformer, vae=None, torch_dtype=torch.float16, device_map="cuda"
)
latent = pipeline(
prompt="A cat walks on the grass, realistic",
height=320,
width=512,
num_frames=61,
num_inference_steps=30,
output_type="latent",
).frames
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
output_type="mp4",
)
if isinstance(video, bytes):
with open("video.mp4", "wb") as f:
f.write(video)
```
</hfoption>
</hfoptions>
## Queuing
Remote inference supports queuing to process multiple generation requests. While the current latent is being decoded, you can queue the next prompt.
```py
import queue
import threading

import torch
from IPython.display import display

from diffusers import StableDiffusionXLPipeline
from diffusers.utils.remote_utils import remote_decode
def decode_worker(q: queue.Queue):
while True:
item = q.get()
if item is None:
break
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=item,
scaling_factor=0.13025,
)
display(image)
q.task_done()
q = queue.Queue()
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
thread.start()
def decode(latent: torch.Tensor):
q.put(latent)
prompts = [
"A grainy Apollo-era style photograph of a cat in a snug astronaut suit with a bubble helmet, standing on the lunar surface and gripping a flag with a paw-print emblem. The gray Moon landscape stretches behind it, Earth glowing vividly in the black sky, shadows crisp and high-contrast.",
"A vintage 1960s sci-fi pulp magazine cover illustration of a heroic cat astronaut planting a flag on the Moon. Bold, saturated colors, exaggerated space gear, playful typography floating in the background, Earth painted in bright blues and greens.",
"A hyper-detailed cinematic shot of a cat astronaut on the Moon holding a fluttering flag, fur visible through the helmet glass, lunar dust scattering under its feet. The vastness of space and Earth in the distance create an epic, awe-inspiring tone.",
"A colorful cartoon drawing of a happy cat wearing a chunky, oversized spacesuit, proudly holding a flag with a big paw print on it. The Moons surface is simplified with craters drawn like doodles, and Earth in the sky has a smiling face.",
"A monochrome 1969-style press photo of a “first cat on the Moon” moment. The cat, in a tiny astronaut suit, stands by a planted flag, with grainy textures, scratches, and a blurred Earth in the background, mimicking old archival space photos."
]
> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
pipeline = StableDiffusionXLPipeline.from_pretrained(
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
vae=None,
device_map="cuda"
)
pipeline.unet = pipeline.unet.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
## Why use Hybrid Inference?
_ = pipeline(
prompt=prompts[0],
output_type="latent",
)
Hybrid Inference offers a fast and simple way to offload local generation requirements.
for prompt in prompts:
latent = pipeline(
prompt=prompt,
output_type="latent",
).images
decode(latent)
- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
- 💰 **Cost Effective:** It's free! 🤑
- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
- 🔧 **Developer-Friendly:** Simple requests, fast responses.
q.put(None)
thread.join()
```
---
## Benchmarks
## Available Models
The tables demonstrate the memory requirements for encoding and decoding with Stable Diffusion v1.5 and SDXL on different GPUs.
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
For the majority of these GPUs, the memory usage dictates whether the other models (text encoders, UNet/transformer) need to be offloaded or whether tiled encoding is required. Both techniques increase inference time and impact quality.
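For comparison, the local workarounds implied by these numbers are tiled VAE encoding/decoding and CPU offloading. The sketch below shows how they are enabled; the SDXL checkpoint is only an example.
```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Tiled VAE encoding/decoding trades extra time (and some quality) for a lower memory peak.
pipeline.vae.enable_tiling()
# Offloading keeps idle components on the CPU and moves them to the GPU only when needed.
pipeline.enable_model_cpu_offload()
```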
---
<details><summary>Encoding - Stable Diffusion v1.5</summary>
## Integrations
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
</details>
## Changelog
<details><summary>Encoding SDXL</summary>
- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
## Contents
</details>
The documentation is organized into three sections:
<details><summary>Decoding - Stable Diffusion v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
</details>
<details><summary>Decoding SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
</details>
## Resources
- Remote inference is also supported in [SD.Next](https://github.com/vladmandic/sdnext) and [ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae).
- Refer to the [Remote VAEs for decoding with Inference Endpoints](https://huggingface.co/blog/remote_vae) blog post to learn more.
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.

View File

@@ -0,0 +1,345 @@
# Getting Started: VAE Decode with Hybrid Inference
VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
## Memory
These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
For the majority of these GPUs, the memory usage dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled decoding has to be used, which increases the time taken and impacts quality.
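As a rough illustration of what the Memory (%) columns measure, peak VRAM for a single decode can be captured along the following lines. This is only a sketch with an assumed latent size, not necessarily the script used to produce these tables.
```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
latents = torch.randn(1, 4, 128, 128, dtype=torch.float16, device="cuda")  # SD v1 latents for a 1024x1024 image

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    image = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]

peak = torch.cuda.max_memory_allocated()
total = torch.cuda.get_device_properties(0).total_memory
print(f"peak VRAM during decode: {peak / total * 100:.1f}%")
```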
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
| --- | --- | --- | --- | --- | --- |
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
</details>
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_decode
```
### Basic example
Here, we show how to use the remote VAE on random tensors.
<details><summary>Code</summary>
```python
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
scaling_factor=0.18215,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/output.png"/>
</figure>
Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
<details><summary>Code</summary>
```python
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/flux_random_latent.png"/>
</figure>
Finally, an example for HunyuanVideo.
<details><summary>Code</summary>
```python
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
output_type="mp4",
)
with open("video.mp4", "wb") as f:
f.write(video)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video_1.mp4" type="video/mp4">
</video>
</figure>
### Generation
But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
<details><summary>Code</summary>
```python
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
vae=None,
).to("cuda")
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
latent = pipe(
prompt=prompt,
output_type="latent",
).images
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.18215,
)
image.save("test.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test.jpg"/>
</figure>
Here's another example with Flux.
<details><summary>Code</summary>
```python
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-schnell",
torch_dtype=torch.bfloat16,
vae=None,
).to("cuda")
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
latent = pipe(
prompt=prompt,
guidance_scale=0.0,
num_inference_steps=4,
output_type="latent",
).images
image = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
height=1024,
width=1024,
scaling_factor=0.3611,
shift_factor=0.1159,
)
image.save("test.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test_1.jpg"/>
</figure>
Here's an example with HunyuanVideo.
<details><summary>Code</summary>
```python
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
).to("cuda")
latent = pipe(
prompt="A cat walks on the grass, realistic",
height=320,
width=512,
num_frames=61,
num_inference_steps=30,
output_type="latent",
).frames
video = remote_decode(
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
output_type="mp4",
)
if isinstance(video, bytes):
with open("video.mp4", "wb") as f:
f.write(video)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video.mp4" type="video/mp4">
</video>
</figure>
### Queueing
One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
<details><summary>Code</summary>
```python
import queue
import threading
from IPython.display import display
from diffusers import StableDiffusionPipeline
def decode_worker(q: queue.Queue):
while True:
item = q.get()
if item is None:
break
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=item,
scaling_factor=0.18215,
)
display(image)
q.task_done()
q = queue.Queue()
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
thread.start()
def decode(latent: torch.Tensor):
q.put(latent)
prompts = [
"Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
"Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
"Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
"Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
"A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
"Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
]
pipe = StableDiffusionPipeline.from_pretrained(
"Lykon/dreamshaper-8",
torch_dtype=torch.float16,
vae=None,
).to("cuda")
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
_ = pipe(
prompt=prompts[0],
output_type="latent",
)
for prompt in prompts:
latent = pipe(
prompt=prompt,
output_type="latent",
).images
decode(latent)
q.put(None)
thread.join()
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<video
alt="queue.mp4"
autoplay loop autobuffer muted playsinline
>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/queue.mp4" type="video/mp4">
</video>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

View File

@@ -0,0 +1,183 @@
# Getting Started: VAE Encode with Hybrid Inference
VAE encode is used for training, image-to-image and image-to-video - turning images or videos into latent representations.
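For context, the sketch below shows roughly what happens locally when a VAE encodes an image; `remote_encode` moves this step to the endpoint. The SD v1 VAE is used here purely as an assumed example.
```python
import torch
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import load_image

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16).to("cuda")
image_processor = VaeImageProcessor(vae_scale_factor=8)

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)
pixel_values = image_processor.preprocess(image).to(device="cuda", dtype=torch.float16)

with torch.no_grad():
    # Sample a latent and apply the model's scaling factor, as done during training and inference.
    latent = vae.encode(pixel_values).latent_dist.sample() * vae.config.scaling_factor
```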
## Memory
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
For the majority of these GPUs, the memory usage dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding has to be used, which increases the time taken and impacts quality.
<details><summary>SD v1.5</summary>
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
</details>
<details><summary>SDXL</summary>
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
</details>
## Available VAEs
| | **Endpoint** | **Model** |
|:-:|:-----------:|:--------:|
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
> [!TIP]
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
## Code
> [!TIP]
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
A helper method simplifies interacting with Hybrid Inference.
```python
from diffusers.utils.remote_utils import remote_encode
```
### Basic example
Let's encode an image, then decode it to demonstrate.
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
</figure>
<details><summary>Code</summary>
```python
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")

latent = remote_encode(
    endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
    image=image,
    scaling_factor=0.3611,
    shift_factor=0.1159,
)
decoded = remote_decode(
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.3611,
shift_factor=0.1159,
)
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
</figure>
### Generation
Now let's look at a generation example: we'll encode the image, generate, and then remotely decode too!
<details><summary>Code</summary>
```python
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
"stable-diffusion-v1-5/stable-diffusion-v1-5",
torch_dtype=torch.float16,
variant="fp16",
vae=None,
).to("cuda")
init_image = load_image(
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
)
init_image = init_image.resize((768, 512))
init_latent = remote_encode(
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
image=init_image,
scaling_factor=0.18215,
)
prompt = "A fantasy landscape, trending on artstation"
latent = pipe(
prompt=prompt,
image=init_latent,
strength=0.75,
output_type="latent",
).images
image = remote_decode(
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
tensor=latent,
scaling_factor=0.18215,
)
image.save("fantasy_landscape.jpg")
```
</details>
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
</figure>
## Integrations
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

View File

@@ -140,7 +140,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:
@@ -256,7 +256,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
type_hint=str,
required=True,
default="mask_image",
description="""Output type from annotation predictions. Available options are
description="""Output type from annotation predictions. Availabe options are
mask_image:
-black and white mask image for the given image based on the task type
mask_overlay:

View File

@@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to
A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.
- It receives the iteration variable from the loop wrapper.
- It recieves the iteration variable from the loop wrapper.
- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].

View File

@@ -68,20 +68,6 @@ config = FasterCacheConfig(
pipeline.transformer.enable_cache(config)
```
## FirstBlockCache
[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser change from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.
```py
import torch
from diffusers import DiffusionPipeline
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
pipeline = DiffusionPipeline.from_pretrained(
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16
)
apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
```
## TaylorSeer Cache
[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
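The extrapolation idea is easy to see in isolation. The toy sketch below is not the diffusers implementation; it only illustrates how activations cached from earlier steps can feed a first-order, finite-difference Taylor prediction of the next step's activation.
```py
import torch

def taylor_predict(prev: torch.Tensor, prev_prev: torch.Tensor, dt: float = 1.0) -> torch.Tensor:
    # First-order Taylor expansion: f(t + dt) ≈ f(t) + f'(t) * dt,
    # with f'(t) approximated by a finite difference of cached activations.
    derivative = (prev - prev_prev) / dt
    return prev + derivative * dt

# Toy "activations" from two previous denoising steps.
a_prev_prev = torch.randn(1, 16)
a_prev = a_prev_prev + 0.05 * torch.randn(1, 16)

# At the next step, reuse the prediction instead of recomputing the block.
a_next_approx = taylor_predict(a_prev, a_prev_prev)
```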
@@ -101,7 +87,8 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
pipe = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
torch_dtype=torch.bfloat16,
).to("cuda")
)
pipe.to("cuda")
config = TaylorSeerCacheConfig(
cache_interval=5,
@@ -110,4 +97,4 @@ config = TaylorSeerCacheConfig(
taylor_factors_dtype=torch.bfloat16,
)
pipe.transformer.enable_cache(config)
```
```

View File

@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
quantzation_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
quantzation_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
)
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
quantization_config=pipeline_quant_config,
quantzation_config=pipeline_quant_config,
torch_dtype=torch.bfloat16,
device_map="cuda"
)

View File

@@ -263,8 +263,8 @@ def main():
world_size = dist.get_world_size()
pipeline = DiffusionPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to(device)
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
)
pipeline.transformer.set_attention_backend("_native_cudnn")
cp_config = ContextParallelConfig(ring_degree=world_size)
@@ -314,35 +314,6 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
```
### Unified Attention
[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
This hybrid approach leverages the strengths of both methods:
- **Ulysses Attention** efficiently parallelizes across attention heads
- **Ring Attention** handles very long sequences with minimal memory overhead
- Together, they enable 2D parallelization across both heads and sequence dimensions
[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to values greater than 1 to [`~ModelMixin.enable_parallelism`].
```py
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
```
> [!TIP]
> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).
We ran a benchmark with Ulysses, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|--------------------|------------------|-------------|------------------|
| ulysses | 6670.789 | 7.50 | 33.85 |
| ring | 13076.492 | 3.82 | 56.02 |
| unified_balanced | 11068.705 | 4.52 | 33.85 |
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
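As a rough, unofficial heuristic for splitting devices across the 2D grid, the Ulysses degree can be capped by the attention head count and the remaining devices absorbed by the ring dimension. The helper below is only an illustrative sketch.
```py
def pick_cp_degrees(world_size: int, num_attention_heads: int) -> tuple[int, int]:
    # Ulysses shards across attention heads, so its degree should divide the head count.
    ulysses = 1
    for candidate in range(min(world_size, num_attention_heads), 0, -1):
        if world_size % candidate == 0 and num_attention_heads % candidate == 0:
            ulysses = candidate
            break
    ring = world_size // ulysses
    return ulysses, ring

print(pick_cp_degrees(8, 24))   # (8, 1)  -> pure Ulysses fits
print(pick_cp_degrees(48, 24))  # (24, 2) -> heads are exhausted, ring takes the rest
```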
### parallel_config
Pass `parallel_config` during model initialization to enable context parallelism.

View File

@@ -1929,8 +1929,6 @@ def main(args):
if args.cache_latents:
latents_cache = []
# Store vae config before potential deletion
vae_scaling_factor = vae.config.scaling_factor
for batch in tqdm(train_dataloader, desc="Caching latents"):
with torch.no_grad():
batch["pixel_values"] = batch["pixel_values"].to(
@@ -1942,8 +1940,6 @@ def main(args):
del vae
if torch.cuda.is_available():
torch.cuda.empty_cache()
else:
vae_scaling_factor = vae.config.scaling_factor
# Scheduler and math around the number of training steps.
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
@@ -2113,13 +2109,13 @@ def main(args):
model_input = vae.encode(pixel_values).latent_dist.sample()
if latents_mean is None and latents_std is None:
model_input = model_input * vae_scaling_factor
model_input = model_input * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
else:
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
model_input = model_input.to(dtype=weight_dtype)
# Sample noise that we'll add to the latents

View File

@@ -149,13 +149,13 @@ def get_args():
"--validation_prompt",
type=str,
default=None,
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
)
parser.add_argument(
"--validation_images",
type=str,
default=None,
help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
)
parser.add_argument(
"--validation_prompt_separator",

View File

@@ -140,7 +140,7 @@ def get_args():
"--validation_prompt",
type=str,
default=None,
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
)
parser.add_argument(
"--validation_prompt_separator",

View File

@@ -21,8 +21,8 @@ from transformers import (
BertModel,
BertTokenizer,
CLIPImageProcessor,
MT5Tokenizer,
T5EncoderModel,
T5Tokenizer,
)
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
The HunyuanDiT model designed by Tencent Hunyuan.
text_encoder_2 (`T5EncoderModel`):
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
tokenizer_2 (`T5Tokenizer`):
tokenizer_2 (`MT5Tokenizer`):
The tokenizer for the mT5 embedder.
scheduler ([`DDPMScheduler`]):
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
feature_extractor: CLIPImageProcessor,
requires_safety_checker: bool = True,
text_encoder_2=T5EncoderModel,
tokenizer_2=T5Tokenizer,
tokenizer_2=MT5Tokenizer,
):
super().__init__()

View File

@@ -1,844 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Any, Callable, Dict, List, Optional, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from diffusers.models.autoencoders import AutoencoderKL
from diffusers.models.transformers import ZImageTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import logging, replace_example_docstring
from diffusers.utils.torch_utils import randn_tensor
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from pipeline_z_image_differential_img2img import ZImageDifferentialImg2ImgPipeline
>>> from diffusers.utils import load_image
>>> pipe = ZImageDifferentialImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> init_image = load_image(
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true",
>>> )
>>> mask = load_image(
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true",
>>> )
>>> prompt = "painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art"
>>> image = pipe(
... prompt,
... image=init_image,
... mask_image=mask,
... strength=0.75,
... num_inference_steps=9,
... guidance_scale=0.0,
... generator=torch.Generator("cuda").manual_seed(41),
... ).images[0]
>>> image.save("image.png")
```
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
return encoder_output.latent_dist.sample(generator)
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
return encoder_output.latent_dist.mode()
elif hasattr(encoder_output, "latents"):
return encoder_output.latents
else:
raise AttributeError("Could not access latents of provided encoder_output")
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
scheduler,
num_inference_steps: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
timesteps: Optional[List[int]] = None,
sigmas: Optional[List[float]] = None,
**kwargs,
):
r"""
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
Args:
scheduler (`SchedulerMixin`):
The scheduler to get timesteps from.
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
must be `None`.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
timesteps (`List[int]`, *optional*):
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
`num_inference_steps` and `sigmas` must be `None`.
sigmas (`List[float]`, *optional*):
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
`num_inference_steps` and `timesteps` must be `None`.
Returns:
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
second element is the number of inference steps.
"""
if timesteps is not None and sigmas is not None:
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
if timesteps is not None:
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accepts_timesteps:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" timestep schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
elif sigmas is not None:
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
if not accept_sigmas:
raise ValueError(
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
f" sigmas schedules. Please check whether you are using the correct scheduler."
)
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
timesteps = scheduler.timesteps
num_inference_steps = len(timesteps)
else:
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
timesteps = scheduler.timesteps
return timesteps, num_inference_steps
class ZImageDifferentialImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
r"""
The ZImage pipeline for image-to-image generation.
Args:
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`PreTrainedModel`]):
A text encoder model to encode text prompts.
tokenizer ([`AutoTokenizer`]):
A tokenizer to tokenize text prompts.
transformer ([`ZImageTransformer2DModel`]):
A ZImage transformer model to denoise the encoded image latents.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
def __init__(
self,
scheduler: FlowMatchEulerDiscreteScheduler,
vae: AutoencoderKL,
text_encoder: PreTrainedModel,
tokenizer: AutoTokenizer,
transformer: ZImageTransformer2DModel,
):
super().__init__()
self.register_modules(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
transformer=transformer,
)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
)
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
self.mask_processor = VaeImageProcessor(
vae_scale_factor=self.vae_scale_factor,
vae_latent_channels=latent_channels,
do_normalize=False,
do_binarize=False,
do_convert_grayscale=True,
)
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 512,
):
prompt = [prompt] if isinstance(prompt, str) else prompt
prompt_embeds = self._encode_prompt(
prompt=prompt,
device=device,
prompt_embeds=prompt_embeds,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance:
if negative_prompt is None:
negative_prompt = ["" for _ in prompt]
else:
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
assert len(prompt) == len(negative_prompt)
negative_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
device=device,
prompt_embeds=negative_prompt_embeds,
max_sequence_length=max_sequence_length,
)
else:
negative_prompt_embeds = []
return prompt_embeds, negative_prompt_embeds
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
def _encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
device = device or self._execution_device
if prompt_embeds is not None:
return prompt_embeds
if isinstance(prompt, str):
prompt = [prompt]
for i, prompt_item in enumerate(prompt):
messages = [
{"role": "user", "content": prompt_item},
]
prompt_item = self.tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True,
enable_thinking=True,
)
prompt[i] = prompt_item
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_sequence_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids.to(device)
prompt_masks = text_inputs.attention_mask.to(device).bool()
prompt_embeds = self.text_encoder(
input_ids=text_input_ids,
attention_mask=prompt_masks,
output_hidden_states=True,
).hidden_states[-2]
embeddings_list = []
for i in range(len(prompt_embeds)):
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
return embeddings_list
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
# get the original timestep using init_timestep
init_timestep = min(num_inference_steps * strength, num_inference_steps)
t_start = int(max(num_inference_steps - init_timestep, 0))
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(t_start * self.scheduler.order)
return timesteps, num_inference_steps - t_start
@staticmethod
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
latent_image_ids = latent_image_ids.reshape(
latent_image_id_height * latent_image_id_width, latent_image_id_channels
)
return latent_image_ids.to(device=device, dtype=dtype)
def prepare_latents(
self,
image,
timestep,
batch_size,
num_channels_latents,
height,
width,
dtype,
device,
generator,
latents=None,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
shape = (batch_size, num_channels_latents, height, width)
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
if latents is not None:
return latents.to(device=device, dtype=dtype)
# Encode the input image
image = image.to(device=device, dtype=dtype)
if image.shape[1] != num_channels_latents:
if isinstance(generator, list):
image_latents = [
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
for i in range(image.shape[0])
]
image_latents = torch.cat(image_latents, dim=0)
else:
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
else:
image_latents = image
# Handle batch size expansion
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
additional_image_per_prompt = batch_size // image_latents.shape[0]
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
raise ValueError(
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
)
# Add noise using flow matching scale_noise
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
return latents, noise, image_latents, latent_image_ids
def prepare_mask_latents(
self,
mask,
masked_image,
batch_size,
num_images_per_prompt,
height,
width,
dtype,
device,
generator,
):
height = 2 * (int(height) // (self.vae_scale_factor * 2))
width = 2 * (int(width) // (self.vae_scale_factor * 2))
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(mask, size=(height, width))
mask = mask.to(device=device, dtype=dtype)
batch_size = batch_size * num_images_per_prompt
masked_image = masked_image.to(device=device, dtype=dtype)
if masked_image.shape[1] == 16:
masked_image_latents = masked_image
else:
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
if mask.shape[0] < batch_size:
if not batch_size % mask.shape[0] == 0:
raise ValueError(
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
" of masks that you pass is divisible by the total requested batch size."
)
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
if masked_image_latents.shape[0] < batch_size:
if not batch_size % masked_image_latents.shape[0] == 0:
raise ValueError(
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
" Make sure the number of images that you pass is divisible by the total requested batch size."
)
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
# aligning device to prevent device errors when concating it with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
return mask, masked_image_latents
@property
def guidance_scale(self):
return self._guidance_scale
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs
@property
def num_timesteps(self):
return self._num_timesteps
@property
def interrupt(self):
return self._interrupt
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: PipelineImageInput = None,
mask_image: PipelineImageInput = None,
strength: float = 0.6,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 5.0,
cfg_normalization: bool = False,
cfg_truncation: float = 1.0,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
):
r"""
Function invoked when calling the pipeline for image-to-image generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to mask `image`. Black pixels in the mask
are repainted while white pixels are preserved. If `mask_image` is a PIL image, it is converted to a
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
1)`, or `(H, W)`.
strength (`float`, *optional*, defaults to 0.6):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
essentially ignores `image`.
height (`int`, *optional*):
The height in pixels of the generated image. If not provided, the height of the input image is used.
width (`int`, *optional*):
The width in pixels of the generated image. If not provided, the width of the input image is used.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 5.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
cfg_normalization (`bool`, *optional*, defaults to False):
Whether to rescale the classifier-free guided noise prediction so that its norm does not exceed the
norm of the conditional prediction.
cfg_truncation (`float`, *optional*, defaults to 1.0):
Fraction of the denoising schedule during which classifier-free guidance is applied; guidance is
disabled for the remaining steps. The default of `1.0` keeps guidance active for all steps.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
not greater than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.z_image.ZImagePipelineOutput`] instead of a plain
tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference, with the following
arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int`, *optional*, defaults to 512):
Maximum sequence length to use with the `prompt`.
Examples:
Returns:
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
"""
# 1. Check inputs and validate strength
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
# 2. Preprocess image
init_image = self.image_processor.preprocess(image)
init_image = init_image.to(dtype=torch.float32)
# Get dimensions from the preprocessed image if not specified
if height is None:
height = init_image.shape[-2]
if width is None:
width = init_image.shape[-1]
vae_scale = self.vae_scale_factor * 2
if height % vae_scale != 0:
raise ValueError(
f"Height must be divisible by {vae_scale} (got {height}). "
f"Please adjust the height to a multiple of {vae_scale}."
)
if width % vae_scale != 0:
raise ValueError(
f"Width must be divisible by {vae_scale} (got {width}). "
f"Please adjust the width to a multiple of {vae_scale}."
)
device = self._execution_device
self._guidance_scale = guidance_scale
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
self._cfg_normalization = cfg_normalization
self._cfg_truncation = cfg_truncation
# 3. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = len(prompt_embeds)
# If prompt_embeds is provided and prompt is None, skip encoding
if prompt_embeds is not None and prompt is None:
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided without `prompt`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
else:
(
prompt_embeds,
negative_prompt_embeds,
) = self.encode_prompt(
prompt=prompt,
negative_prompt=negative_prompt,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
device=device,
max_sequence_length=max_sequence_length,
)
# 4. Prepare latent variables
num_channels_latents = self.transformer.in_channels
# Repeat prompt_embeds for num_images_per_prompt
if num_images_per_prompt > 1:
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
if self.do_classifier_free_guidance and negative_prompt_embeds:
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt
# Calculate latent dimensions for image_seq_len
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
image_seq_len = (latent_height // 2) * (latent_width // 2)
# 5. Prepare timesteps
mu = calculate_shift(
image_seq_len,
self.scheduler.config.get("base_image_seq_len", 256),
self.scheduler.config.get("max_image_seq_len", 4096),
self.scheduler.config.get("base_shift", 0.5),
self.scheduler.config.get("max_shift", 1.15),
)
self.scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler,
num_inference_steps,
device,
sigmas=sigmas,
**scheduler_kwargs,
)
# 6. Adjust timesteps based on strength
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
if num_inference_steps < 1:
raise ValueError(
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
)
latent_timestep = timesteps[:1].repeat(actual_batch_size)
# 7. Prepare latents from image
latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
init_image,
latent_timestep,
actual_batch_size,
num_channels_latents,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
latents,
)
resize_mode = "default"
crops_coords = None
# start diff diff preparation
original_mask = self.mask_processor.preprocess(
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
)
masked_image = init_image * original_mask
original_mask, _ = self.prepare_mask_latents(
original_mask,
masked_image,
batch_size,
num_images_per_prompt,
height,
width,
prompt_embeds[0].dtype,
device,
generator,
)
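# Differential-diffusion schedule: one boolean mask per denoising step. At step i, pixels whose mask
# value exceeds i / num_inference_steps are re-anchored to the (re-noised) original image, so brighter
# mask regions follow the input for more steps before being released to the denoiser.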
mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)
masks = original_mask > mask_thresholds
# end diff diff preparation
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
self._num_timesteps = len(timesteps)
# 8. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0])
timestep = (1000 - timestep) / 1000
# Normalized time for time-aware config (0 at start, 1 at end)
t_norm = timestep[0].item()
# Handle cfg truncation
current_guidance_scale = self.guidance_scale
if (
self.do_classifier_free_guidance
and self._cfg_truncation is not None
and float(self._cfg_truncation) <= 1
):
if t_norm > self._cfg_truncation:
current_guidance_scale = 0.0
# Run CFG only if configured AND scale is non-zero
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
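# When CFG is active, the conditional and unconditional branches share a single forward pass:
# the latents are duplicated along the batch dimension and the positive and negative embedding
# lists are concatenated so the transformer sees both in one call.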
if apply_cfg:
latents_typed = latents.to(self.transformer.dtype)
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
timestep_model_input = timestep.repeat(2)
else:
latent_model_input = latents.to(self.transformer.dtype)
prompt_embeds_model_input = prompt_embeds
timestep_model_input = timestep
latent_model_input = latent_model_input.unsqueeze(2)
latent_model_input_list = list(latent_model_input.unbind(dim=0))
model_out_list = self.transformer(
latent_model_input_list,
timestep_model_input,
prompt_embeds_model_input,
)[0]
if apply_cfg:
# Perform CFG
pos_out = model_out_list[:actual_batch_size]
neg_out = model_out_list[actual_batch_size:]
noise_pred = []
for j in range(actual_batch_size):
pos = pos_out[j].float()
neg = neg_out[j].float()
pred = pos + current_guidance_scale * (pos - neg)
# Renormalization
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
ori_pos_norm = torch.linalg.vector_norm(pos)
new_pos_norm = torch.linalg.vector_norm(pred)
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
if new_pos_norm > max_new_norm:
pred = pred * (max_new_norm / new_pos_norm)
noise_pred.append(pred)
noise_pred = torch.stack(noise_pred, dim=0)
else:
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
noise_pred = noise_pred.squeeze(2)
noise_pred = -noise_pred
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
assert latents.dtype == torch.float32
# start diff diff
image_latent = original_image_latents
latents_dtype = latents.dtype
if i < len(timesteps) - 1:
noise_timestep = timesteps[i + 1]
image_latent = self.scheduler.scale_noise(
original_image_latents, torch.tensor([noise_timestep]), noise
)
mask = masks[i].to(latents_dtype)
latents = image_latent * mask + latents * (1 - mask)
# end diff diff
if latents.dtype != latents_dtype:
if torch.backends.mps.is_available():
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
latents = latents.to(latents_dtype)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if output_type == "latent":
image = latents
else:
latents = latents.to(self.vae.dtype)
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
image = self.vae.decode(latents, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return ZImagePipelineOutput(images=image)

View File

@@ -98,9 +98,6 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
This way, the text encoder model is not loaded into memory during training.
> [!NOTE]
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
### FSDP Text Encoder
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
This way, the memory cost of the text encoder is sharded across the participating GPUs.
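As a rough illustration of the idea (not the helper used by the training script), an already-loaded text encoder can be sharded with PyTorch FSDP once a process group has been initialized, e.g. by `accelerate launch` or `torchrun`:

```python
# Illustrative sketch only: shard a loaded text encoder with PyTorch FSDP.
# Assumes torch.distributed is already initialized by the launcher.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def shard_text_encoder(text_encoder: torch.nn.Module, offload_to_cpu: bool = False) -> FSDP:
    device = torch.device("cuda", dist.get_rank() % torch.cuda.device_count())
    sharded = FSDP(
        text_encoder,
        device_id=device,                                       # place this rank's shard on its GPU
        cpu_offload=CPUOffload(offload_params=offload_to_cpu),  # optionally park shards in CPU RAM
        limit_all_gathers=True,
        use_orig_params=True,
    )
    dist.barrier()  # every rank finishes wrapping before prompts are encoded
    return sharded
```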
### CPU Offloading
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
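A minimal sketch of the underlying pattern (a hypothetical helper, not the diffusers implementation): keep heavy modules on CPU and move them to the GPU only for the duration of a call.

```python
# Sketch of on-demand offloading: modules live on CPU and visit the GPU briefly.
from contextlib import contextmanager

import torch


@contextmanager
def on_device(module: torch.nn.Module, device: torch.device):
    module.to(device)
    try:
        yield module
    finally:
        module.to("cpu")
        torch.cuda.empty_cache()

# usage (illustrative): encode prompts on GPU, then free the memory again
# with on_device(text_encoder, torch.device("cuda")):
#     prompt_embeds = encode(text_encoder, prompts)
```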
### Latent Caching
@@ -169,26 +166,6 @@ To better track our training experiments, we're using the following flags in the
> [!NOTE]
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
### FSDP on the transformer
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
```shell
distributed_type: FSDP
fsdp_config:
fsdp_version: 2
fsdp_offload_params: false
fsdp_sharding_strategy: HYBRID_SHARD
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
fsdp_forward_prefetch: true
fsdp_sync_module_states: false
fsdp_state_dict_type: FULL_STATE_DICT
fsdp_use_orig_params: false
fsdp_activation_checkpointing: true
fsdp_reshard_after_forward: true
fsdp_cpu_ram_efficient_loading: false
```
## LoRA + DreamBooth
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.

View File

@@ -111,25 +111,6 @@ To better track our training experiments, we're using the following flags in the
## Notes
### LoRA Rank and Alpha
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower-rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
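To make the ratio concrete, here is a minimal sketch of how `rank` and `lora_alpha` typically enter a PEFT `LoraConfig`; the target modules shown are illustrative:

```python
# Illustrative only: the effective LoRA scale is lora_alpha / r.
from peft import LoraConfig

lora_config = LoraConfig(
    r=16,           # rank of the LoRA update matrices
    lora_alpha=16,  # alpha == rank -> scaling factor of 1.0
    init_lora_weights="gaussian",
    target_modules=["to_k", "to_q", "to_v", "to_out.0"],  # attention projections, as with --lora_layers
)
# e.g. lora_alpha=32 with r=16 doubles the update's influence without adding parameters.
```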
### Additional CLI arguments
Additionally, we welcome you to explore the following CLI arguments:
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string, e.g. "to_k,to_q,to_v" will result in LoRA training of the attention layers only.

View File

@@ -44,7 +44,6 @@ import shutil
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -76,16 +75,13 @@ from diffusers import (
from diffusers.optimization import get_scheduler
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -97,9 +93,6 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -729,7 +722,6 @@ def parse_args(input_args=None):
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -1227,11 +1219,7 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1275,42 +1263,17 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
transformer_cls = type(unwrap_model(transformer))
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1322,20 +1285,13 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
if not is_fsdp:
while len(models) > 0:
model = models.pop()
while len(models) > 0:
model = models.pop()
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1551,21 +1507,6 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1595,8 +1536,6 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1838,7 +1777,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process or is_fsdp:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1897,41 +1836,15 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

View File

@@ -43,7 +43,6 @@ import random
import shutil
from contextlib import nullcontext
from pathlib import Path
from typing import Any
import numpy as np
import torch
@@ -75,16 +74,13 @@ from diffusers.optimization import get_scheduler
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
from diffusers.training_utils import (
_collate_lora_metadata,
_to_cpu_contiguous,
cast_training_params,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
find_nearest_bucket,
free_memory,
get_fsdp_kwargs_from_accelerator,
offload_models,
parse_buckets_string,
wrap_with_fsdp,
)
from diffusers.utils import (
check_min_version,
@@ -97,9 +93,6 @@ from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import is_compiled_module
if getattr(torch, "distributed", None) is not None:
import torch.distributed as dist
if is_wandb_available():
import wandb
@@ -346,7 +339,7 @@ def parse_args(input_args=None):
"--instance_prompt",
type=str,
default=None,
required=False,
required=True,
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
)
parser.add_argument(
@@ -698,7 +691,6 @@ def parse_args(input_args=None):
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
if input_args is not None:
args = parser.parse_args(input_args)
@@ -835,28 +827,15 @@ class DreamBoothDataset(Dataset):
dest_image = self.cond_images[i]
image_width, image_height = dest_image.size
if image_width * image_height > 1024 * 1024:
dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
image_width, image_height = dest_image.size
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
image_width = (image_width // multiple_of) * multiple_of
image_height = (image_height // multiple_of) * multiple_of
image_processor = Flux2ImageProcessor()
dest_image = image_processor.preprocess(
dest_image = Flux2ImageProcessor.image_processor.preprocess(
dest_image, height=image_height, width=image_width, resize_mode="crop"
)
# Convert back to PIL
dest_image = dest_image.squeeze(0)
if dest_image.min() < 0:
dest_image = (dest_image + 1) / 2
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
if dest_image.shape[0] == 1:
# Gray scale image
dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
else:
# RGB scale image: (C, H, W) -> (H, W, C)
dest_image = TF.to_pil_image(dest_image)
dest_image = exif_transpose(dest_image)
if not dest_image.mode == "RGB":
@@ -1177,11 +1156,7 @@ def main(args):
if args.bnb_quantization_config_path is not None
else {"device": accelerator.device, "dtype": weight_dtype}
)
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
if not is_fsdp:
transformer.to(**transformer_to_kwargs)
transformer.to(**transformer_to_kwargs)
if args.do_fp8_training:
convert_to_float8_training(
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
@@ -1225,42 +1200,17 @@ def main(args):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
transformer_cls = type(unwrap_model(transformer))
# 1) Validate and pick the transformer model
modules_to_save: dict[str, Any] = {}
transformer_model = None
for model in models:
if isinstance(unwrap_model(model), transformer_cls):
transformer_model = model
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
if transformer_model is None:
raise ValueError("No transformer model found in 'models'")
# 2) Optionally gather FSDP state dict once
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
# 3) Only main process materializes the LoRA state dict
transformer_lora_layers_to_save = None
if accelerator.is_main_process:
peft_kwargs = {}
if is_fsdp:
peft_kwargs["state_dict"] = state_dict
transformer_lora_layers_to_save = None
modules_to_save = {}
for model in models:
if isinstance(model, type(unwrap_model(transformer))):
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
modules_to_save["transformer"] = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
transformer_lora_layers_to_save = get_peft_model_state_dict(
unwrap_model(transformer_model) if is_fsdp else transformer_model,
**peft_kwargs,
)
if is_fsdp:
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
# make sure to pop weight so that corresponding model is not saved again
if weights:
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
Flux2Pipeline.save_lora_weights(
@@ -1272,20 +1222,13 @@ def main(args):
def load_model_hook(models, input_dir):
transformer_ = None
if not is_fsdp:
while len(models) > 0:
model = models.pop()
while len(models) > 0:
model = models.pop()
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
transformer_ = unwrap_model(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
else:
transformer_ = Flux2Transformer2DModel.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="transformer",
)
transformer_.add_adapter(transformer_lora_config)
if isinstance(model, type(unwrap_model(transformer))):
transformer_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
@@ -1476,9 +1419,9 @@ def main(args):
args.instance_prompt, text_encoding_pipeline
)
validation_image = load_image(args.validation_image_path).convert("RGB")
validation_kwargs = {"image": validation_image}
if args.validation_prompt is not None:
validation_image = load_image(args.validation_image_path).convert("RGB")
validation_kwargs = {"image": validation_image}
if args.remote_text_encoder:
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
else:
@@ -1487,21 +1430,6 @@ def main(args):
args.validation_prompt, text_encoding_pipeline
)
# Init FSDP for text encoder
if args.fsdp_text_encoder:
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
text_encoder_fsdp = wrap_with_fsdp(
model=text_encoding_pipeline.text_encoder,
device=accelerator.device,
offload=args.offload,
limit_all_gathers=True,
use_orig_params=True,
fsdp_kwargs=fsdp_kwargs,
)
text_encoding_pipeline.text_encoder = text_encoder_fsdp
dist.barrier()
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
# have to pass them to the dataloader.
@@ -1533,8 +1461,6 @@ def main(args):
if train_dataset.custom_instance_prompts:
if args.remote_text_encoder:
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
elif args.fsdp_text_encoder:
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
else:
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
@@ -1695,13 +1621,9 @@ def main(args):
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
device=cond_model_input.device
)
cond_model_input_ids = cond_model_input_ids.view(
cond_model_input.shape[0], -1, model_input_ids.shape[-1]
)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1728,9 +1650,6 @@ def main(args):
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
orig_input_shape = packed_noisy_model_input.shape
orig_input_ids_shape = model_input_ids.shape
# concatenate the model inputs with the cond inputs
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
@@ -1749,8 +1668,7 @@ def main(args):
img_ids=model_input_ids, # B, image_seq_len, 4
return_dict=False,
)[0]
model_pred = model_pred[:, : orig_input_shape[1], :]
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
@@ -1782,7 +1700,7 @@ def main(args):
progress_bar.update(1)
global_step += 1
if accelerator.is_main_process or is_fsdp:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
@@ -1841,41 +1759,15 @@ def main(args):
# Save the lora layers
accelerator.wait_for_everyone()
if is_fsdp:
transformer = unwrap_model(transformer)
state_dict = accelerator.get_state_dict(transformer)
if accelerator.is_main_process:
modules_to_save = {}
if is_fsdp:
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
state_dict = {
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
else:
state_dict = {
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
}
transformer_lora_layers = get_peft_model_state_dict(
transformer,
state_dict=state_dict,
)
transformer_lora_layers = {
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
for k, v in transformer_lora_layers.items()
}
else:
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
transformer = unwrap_model(transformer)
if args.bnb_quantization_config_path is None:
if args.upcast_before_saving:
transformer.to(torch.float32)
else:
transformer = transformer.to(weight_dtype)
transformer_lora_layers = get_peft_model_state_dict(transformer)
modules_to_save["transformer"] = transformer
Flux2Pipeline.save_lora_weights(

View File

@@ -1513,12 +1513,14 @@ def main(args):
height=model_input.shape[3],
width=model_input.shape[4],
)
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
model_pred = transformer(
hidden_states=packed_noisy_model_input,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_mask=prompt_embeds_mask,
timestep=timesteps / 1000,
img_shapes=img_shapes,
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
return_dict=False,
)[0]
model_pred = QwenImagePipeline._unpack_latents(

View File

@@ -1,157 +0,0 @@
# Latent Perceptual Loss (LPL) for Stable Diffusion XL
This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada.
## Overview
LPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to:
- Loss of fine details in generated images
- Inconsistent image quality
- Structural artifacts
- Reduced sharpness and realism
LPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to:
- Improved image quality and consistency (6-20% FID improvement)
- Better preservation of fine details
- More stable training, especially at high noise levels
- Better handling of structural information
- Sharper and more realistic textures
## Implementation Details
The LPL implementation follows the paper's methodology and includes several key features:
1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:
- Middle block features
- Up block features (configurable number of blocks)
- Proper gradient checkpointing for memory efficiency
- Features are extracted only for timesteps below the threshold (high SNR)
2. **Feature Normalization**: Multiple normalization options as validated in the paper:
- `default`: Normalize each feature map independently
- `shared`: Cross-normalize features using target statistics (recommended)
- `batch`: Batch-wise normalization
3. **Outlier Handling**: Optional removal of outliers in feature maps using:
- Quantile-based filtering (2% quantiles)
- Morphological operations (opening/closing)
- Adaptive thresholding based on standard deviation
4. **Loss Types**:
- MSE loss (default)
- L1 loss
- Optional power law weighting (2^(-i) for layer i)
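As a sketch of how the module in this directory can be wired into a training step (the import name `lpl_loss` and the training-loop tensors are assumptions; constructor arguments follow the class defined below):

```python
# Sketch only: add LPL on top of the usual diffusion loss for high-SNR timesteps.
import torch

from lpl_loss import LatentPerceptualLoss  # class shipped in this directory (file name assumed)


def add_lpl(vae, pred_latents: torch.Tensor, target_latents: torch.Tensor,
            diffusion_loss: torch.Tensor, lpl_weight: float = 1.0) -> torch.Tensor:
    lpl = LatentPerceptualLoss(
        vae,
        loss_type="mse",
        norm_type="shared",   # recommended setting in the paper
        pow_law=True,         # 2^(-i) weighting per decoder block
        num_mid_blocks=4,
        remove_outliers=True,
    )
    return diffusion_loss + lpl_weight * lpl.get_loss(pred_latents, target_latents).mean()
```

In practice the `LatentPerceptualLoss` object would be constructed once before the training loop rather than per step.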
## Usage
To use LPL in your training, add the following arguments to your training command:
```bash
python examples/research_projects/lpl/train_sdxl_lpl.py \
--use_lpl \
--lpl_weight 1.0 \ # Weight for LPL loss (1.0-2.0 recommended)
--lpl_t_threshold 200 \ # Apply LPL only for timesteps < threshold (high SNR)
--lpl_loss_type mse \ # Loss type: "mse" or "l1"
--lpl_norm_type shared \ # Normalization type: "default", "shared" (recommended), or "batch"
--lpl_pow_law \ # Use power law weighting for layers
--lpl_num_blocks 4 \ # Number of up blocks to use (1-4)
--lpl_remove_outliers \ # Remove outliers in feature maps
--lpl_scale \ # Scale LPL loss by noise level weights
--lpl_start 0 \ # Step to start applying LPL
# ... other training arguments ...
```
### Key Parameters
- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.
- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) focus on more important timesteps.
- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases.
- `lpl_norm_type`: Feature normalization strategy. "shared" is recommended as it showed best results in the paper.
- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance.
- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.
- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.
- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps.
- `lpl_start`: Training step to start applying LPL. Can be used to warm up training.
## Recommendations
1. **Starting Point** (based on paper results):
```bash
--use_lpl \
--lpl_weight 1.0 \
--lpl_t_threshold 200 \
--lpl_loss_type mse \
--lpl_norm_type shared \
--lpl_pow_law \
--lpl_num_blocks 4 \
--lpl_remove_outliers \
--lpl_scale
```
2. **Memory Efficiency**:
- Use `--gradient_checkpointing` for memory efficiency (enabled by default)
- Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)
- Consider using `--lpl_scale` to focus on more important timesteps
- Features are extracted only for timesteps below threshold to save memory
3. **Quality vs Speed**:
- Higher `lpl_weight` (1.0-2.0) for better quality
- Lower `lpl_t_threshold` (100-200) for faster training
- Use `lpl_remove_outliers` for more stable training
- `lpl_norm_type shared` provides best quality/speed trade-off
## Technical Details
### Feature Extraction
The LPL implementation extracts features from the VAE decoder in the following order:
1. Middle block output
2. Up block outputs (configurable number of blocks)
Each feature map is processed with:
1. Optional outlier removal (2% quantiles, morphological operations)
2. Feature normalization (shared statistics recommended)
3. Loss calculation (MSE or L1)
4. Optional power law weighting (2^(-i) for layer i)
### Loss Calculation
For each feature map:
1. Features are normalized according to the chosen strategy
2. Loss is calculated between normalized features
3. Outliers are masked out (if enabled)
4. Loss is weighted by layer depth (if power law enabled)
5. Final loss is averaged across all layers
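Reading the implementation in this directory, the per-batch loss can be written as follows (notation mine: $\hat{\phi}_i$ and $\phi_i$ are the decoder features of the predicted and target latents, $m_i$ the outlier mask, $d$ the element-wise MSE or L1 distance, $L$ the number of feature levels):

$$
\mathcal{L}_{\mathrm{LPL}}
  = \frac{1}{L}\sum_{i=0}^{L-1} w_i \,
    \operatorname{mean}_{c}\!\left[
      \frac{\sum_{h,w} m_i \, d\!\left(\hat{\phi}_i, \phi_i\right)}
           {\sum_{h,w} m_i}
    \right],
\qquad
w_i = \begin{cases} 2^{-\min(i,3)} & \text{with power-law weighting} \\ 1 & \text{otherwise.} \end{cases}
$$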
### Memory Considerations
- Gradient checkpointing is used by default
- Features are extracted only for timesteps below the threshold
- Outlier removal is done in-place to save memory
- Feature normalization is done efficiently using vectorized operations
- Memory usage scales linearly with number of blocks used
## Results
Based on the paper's findings, LPL provides:
- 6-20% improvement in FID scores
- Better preservation of fine details
- More realistic textures and structures
- Improved consistency across different resolutions
- Better performance on both small and large datasets
## Citation
If you use this implementation in your research, please cite:
```bibtex
@inproceedings{berrada2025boosting,
title={Boosting Latent Diffusion with Perceptual Objectives},
author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=y4DtzADzd1}
}
```

View File

@@ -1,215 +0,0 @@
# Copyright 2025 Berrada et al.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def normalize_tensor(in_feat, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
return in_feat / (norm_factor + eps)
def cross_normalize(input, target, eps=1e-10):
norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
return input / (norm_factor + eps), target / (norm_factor + eps)
def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
opening = int(np.ceil(opening / down_f))
closing = int(np.ceil(closing / down_f))
if opening == 2:
opening = 3
if closing == 2:
closing = 1
# replace quantile with kth value here.
feat_flat = feat.flatten(-2, -1)
k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))
q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]
m = 2 * feat_flat.std(-1)[..., None, None].detach()
mask = (q1 - m < feat) * (feat < q2 + m)
# dilate the mask.
mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float()) # closing
mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool() # opening
feat = feat * mask
return mask, feat
class LatentPerceptualLoss(nn.Module):
def __init__(
self,
vae,
loss_type="mse",
grad_ckpt=True,
pow_law=False,
norm_type="default",
num_mid_blocks=4,
feature_type="feature",
remove_outliers=True,
):
super().__init__()
self.vae = vae
self.decoder = self.vae.decoder
# Store scaling factors as tensors on the correct device
device = next(self.vae.parameters()).device
# Get scaling factors with proper defaults and handle None values
scale_factor = getattr(self.vae.config, "scaling_factor", None)
shift_factor = getattr(self.vae.config, "shift_factor", None)
# Convert to tensors with proper defaults
self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)
self.gradient_checkpointing = grad_ckpt
self.pow_law = pow_law
self.norm_type = norm_type.lower()
self.outlier_mask = remove_outliers
self.last_feature_stats = [] # Store feature statistics for logging
assert feature_type in ["feature", "image"]
self.feature_type = feature_type
assert self.norm_type in ["default", "shared", "batch"]
assert num_mid_blocks >= 0 and num_mid_blocks <= 4
self.n_blocks = num_mid_blocks
assert loss_type in ["mse", "l1"]
if loss_type == "mse":
self.loss_fn = nn.MSELoss(reduction="none")
elif loss_type == "l1":
self.loss_fn = nn.L1Loss(reduction="none")
def get_features(self, z, latent_embeds=None, disable_grads=False):
with torch.set_grad_enabled(not disable_grads):
if self.gradient_checkpointing and not disable_grads:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
features = []
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
sample = z
sample = self.decoder.conv_in(sample)
# middle
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(self.decoder.mid_block),
sample,
latent_embeds,
use_reentrant=False,
)
sample = sample.to(upscale_dtype)
features.append(sample)
# up
for up_block in self.decoder.up_blocks[: self.n_blocks]:
sample = torch.utils.checkpoint.checkpoint(
create_custom_forward(up_block),
sample,
latent_embeds,
use_reentrant=False,
)
features.append(sample)
return features
else:
features = []
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
sample = z
sample = self.decoder.conv_in(sample)
# middle
sample = self.decoder.mid_block(sample, latent_embeds)
sample = sample.to(upscale_dtype)
features.append(sample)
# up
for up_block in self.decoder.up_blocks[: self.n_blocks]:
sample = up_block(sample, latent_embeds)
features.append(sample)
return features
def get_loss(self, input, target, get_hist=False):
if self.feature_type == "feature":
inp_f = self.get_features(self.shift + input / self.scale)
tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
losses = []
self.last_feature_stats = [] # Reset feature stats
for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
my = torch.ones_like(y).bool()
outlier_ratio = 0.0
if self.outlier_mask:
with torch.no_grad():
if i == 2:
my, y = remove_outliers(y, down_f=2)
outlier_ratio = 1.0 - my.float().mean().item()
elif i in [3, 4, 5]:
my, y = remove_outliers(y, down_f=1)
outlier_ratio = 1.0 - my.float().mean().item()
# Store feature statistics before normalization
with torch.no_grad():
stats = {
"mean": y.mean().item(),
"std": y.std().item(),
"outlier_ratio": outlier_ratio,
}
self.last_feature_stats.append(stats)
# normalize feature tensors
if self.norm_type == "default":
x = normalize_tensor(x)
y = normalize_tensor(y)
elif self.norm_type == "shared":
x, y = cross_normalize(x, y, eps=1e-6)
term_loss = self.loss_fn(x, y) * my
# reduce loss term
loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
losses.append(term_loss.mean((1,)))
if get_hist:
return losses
else:
loss = sum(losses)
return loss / len(inp_f)
elif self.feature_type == "image":
inp_f = self.vae.decode(input / self.scale).sample
tar_f = self.vae.decode(target / self.scale).sample
return F.mse_loss(inp_f, tar_f)
def get_first_conv(self, z):
sample = self.decoder.conv_in(z)
return sample
def get_first_block(self, z):
sample = self.decoder.conv_in(z)
sample = self.decoder.mid_block(sample)
for resnet in self.decoder.up_blocks[0].resnets:
sample = resnet(sample, None)
return sample
def get_first_layer(self, input, target, target_layer="conv"):
if target_layer == "conv":
feat_in = self.get_first_conv(input)
with torch.no_grad():
feat_tar = self.get_first_conv(target)
else:
feat_in = self.get_first_block(input)
with torch.no_grad():
feat_tar = self.get_first_block(target)
feat_in, feat_tar = cross_normalize(feat_in, feat_tar)
return F.mse_loss(feat_in, feat_tar, reduction="mean")

File diff suppressed because it is too large Load Diff

View File

@@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode
___Note___:
___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___
## Running locally with PyTorch

View File

@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
NUM_DEVICES = jax.device_count()
# 1. Let's start by downloading the model and loading it into our pipeline class
# Adhering to JAX's functional approach, the model's parameters are returned separately and
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
# will have to be passed to the pipeline during inference
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True

View File

@@ -7,12 +7,16 @@ import torch
from diffusers.utils import logging
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
logger = logging.get_logger(__name__)
def safe_tokenize(tokenizer, *args, lock, **kwargs):
with lock:
return tokenizer(*args, **kwargs)
class RequestScopedPipeline:
DEFAULT_MUTABLE_ATTRS = [
"_all_hooks",
@@ -34,40 +38,23 @@ class RequestScopedPipeline:
wrap_scheduler: bool = True,
):
self._base = pipeline
self.unet = getattr(pipeline, "unet", None)
self.vae = getattr(pipeline, "vae", None)
self.text_encoder = getattr(pipeline, "text_encoder", None)
self.components = getattr(pipeline, "components", None)
self.transformer = getattr(pipeline, "transformer", None)
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
self._vae_lock = threading.Lock()
self._image_lock = threading.Lock()
self._auto_detect_mutables = bool(auto_detect_mutables)
self._tensor_numel_threshold = int(tensor_numel_threshold)
self._auto_detected_attrs: List[str] = []
def _detect_kernel_pipeline(self, pipeline) -> bool:
kernel_indicators = [
"text_encoding_cache",
"memory_manager",
"enable_optimizations",
"_create_request_context",
"get_optimization_stats",
]
return any(hasattr(pipeline, attr) for attr in kernel_indicators)
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
base_sched = getattr(self._base, "scheduler", None)
if base_sched is None:
@@ -83,21 +70,11 @@ class RequestScopedPipeline:
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
)
except Exception as e:
logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
try:
if hasattr(wrapped_scheduler, "scheduler"):
try:
copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
return BaseAsyncScheduler(copied_scheduler)
except Exception:
return wrapped_scheduler
else:
copied_scheduler = copy.copy(wrapped_scheduler)
return BaseAsyncScheduler(copied_scheduler)
except Exception as e2:
logger.warning(
f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
)
return copy.deepcopy(wrapped_scheduler)
except Exception as e:
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
return wrapped_scheduler
def _autodetect_mutables(self, max_attrs: int = 40):
@@ -109,7 +86,6 @@ class RequestScopedPipeline:
candidates: List[str] = []
seen = set()
for name in dir(self._base):
if name.startswith("__"):
continue
@@ -117,7 +93,6 @@ class RequestScopedPipeline:
continue
if name in ("to", "save_pretrained", "from_pretrained"):
continue
try:
val = getattr(self._base, name)
except Exception:
@@ -125,9 +100,11 @@ class RequestScopedPipeline:
import types
# skip callables and modules
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
continue
# containers -> candidate
if isinstance(val, (dict, list, set, tuple, bytearray)):
candidates.append(name)
seen.add(name)
@@ -228,9 +205,6 @@ class RequestScopedPipeline:
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
def _should_wrap_tokenizers(self) -> bool:
return True
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
@@ -240,25 +214,6 @@ class RequestScopedPipeline:
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
local_pipe = copy.deepcopy(self._base)
try:
if (
hasattr(local_pipe, "vae")
and local_pipe.vae is not None
and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
):
local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
if (
hasattr(local_pipe, "image_processor")
and local_pipe.image_processor is not None
and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
):
local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
local_pipe.image_processor, self._image_lock
)
except Exception as e:
logger.debug(f"Could not wrap vae/image_processor: {e}")
if local_scheduler is not None:
try:
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
@@ -276,42 +231,47 @@ class RequestScopedPipeline:
self._clone_mutable_attrs(self._base, local_pipe)
original_tokenizers = {}
# 4) wrap tokenizers on the local pipe with the lock wrapper
tokenizer_wrappers = {} # name -> original_tokenizer
try:
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
for name in dir(local_pipe):
if "tokenizer" in name and not name.startswith("_"):
tok = getattr(local_pipe, name, None)
if tok is not None and self._is_tokenizer_component(tok):
tokenizer_wrappers[name] = tok
setattr(
local_pipe,
name,
lambda *args, tok=tok, **kwargs: safe_tokenize(
tok, *args, lock=self._tokenizer_lock, **kwargs
),
)
if self._should_wrap_tokenizers():
try:
for name in dir(local_pipe):
if "tokenizer" in name and not name.startswith("_"):
tok = getattr(local_pipe, name, None)
if tok is not None and self._is_tokenizer_component(tok):
if not isinstance(tok, ThreadSafeTokenizerWrapper):
original_tokenizers[name] = tok
wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
setattr(local_pipe, name, wrapped_tokenizer)
# b) wrap tokenizers in components dict
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
for key, val in local_pipe.components.items():
if val is None:
continue
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
for key, val in local_pipe.components.items():
if val is None:
continue
if self._is_tokenizer_component(val):
tokenizer_wrappers[f"components[{key}]"] = val
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
)
if self._is_tokenizer_component(val):
if not isinstance(val, ThreadSafeTokenizerWrapper):
original_tokenizers[f"components[{key}]"] = val
wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
local_pipe.components[key] = wrapped_tokenizer
except Exception as e:
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
except Exception as e:
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
result = None
cm = getattr(local_pipe, "model_cpu_offload_context", None)
try:
if callable(cm):
try:
with cm():
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
except TypeError:
# cm might be a context manager instance rather than callable
try:
with cm:
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
@@ -319,18 +279,18 @@ class RequestScopedPipeline:
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
else:
# no offload context available — call directly
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
return result
finally:
try:
for name, tok in original_tokenizers.items():
for name, tok in tokenizer_wrappers.items():
if name.startswith("components["):
key = name[len("components[") : -1]
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
local_pipe.components[key] = tok
local_pipe.components[key] = tok
else:
setattr(local_pipe, name, tok)
except Exception as e:
logger.debug(f"Error restoring original tokenizers: {e}")
logger.debug(f"Error restoring wrapped tokenizers: {e}")

View File

@@ -1,86 +0,0 @@
class ThreadSafeTokenizerWrapper:
def __init__(self, tokenizer, lock):
self._tokenizer = tokenizer
self._lock = lock
self._thread_safe_methods = {
"__call__",
"encode",
"decode",
"tokenize",
"encode_plus",
"batch_encode_plus",
"batch_decode",
}
def __getattr__(self, name):
attr = getattr(self._tokenizer, name)
if name in self._thread_safe_methods and callable(attr):
def wrapped_method(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)
return wrapped_method
return attr
def __call__(self, *args, **kwargs):
with self._lock:
return self._tokenizer(*args, **kwargs)
def __setattr__(self, name, value):
if name.startswith("_"):
super().__setattr__(name, value)
else:
setattr(self._tokenizer, name, value)
def __dir__(self):
return dir(self._tokenizer)
class ThreadSafeVAEWrapper:
def __init__(self, vae, lock):
self._vae = vae
self._lock = lock
def __getattr__(self, name):
attr = getattr(self._vae, name)
if name in {"decode", "encode", "forward"} and callable(attr):
def wrapped(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)
return wrapped
return attr
def __setattr__(self, name, value):
if name.startswith("_"):
super().__setattr__(name, value)
else:
setattr(self._vae, name, value)
class ThreadSafeImageProcessorWrapper:
def __init__(self, proc, lock):
self._proc = proc
self._lock = lock
def __getattr__(self, name):
attr = getattr(self._proc, name)
if name in {"postprocess", "preprocess"} and callable(attr):
def wrapped(*args, **kwargs):
with self._lock:
return attr(*args, **kwargs)
return wrapped
return attr
def __setattr__(self, name, value):
if name.startswith("_"):
super().__setattr__(name, value)
else:
setattr(self._proc, name, value)
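As an aside, the wrappers above all rely on the same delegation idea: `__getattr__` forwards every attribute to the wrapped object, and a small set of named methods is re-bound so it acquires a shared lock before running. A minimal, library-agnostic sketch of that pattern (the `Dummy` class and its method are invented for illustration):

```python
import threading


class ThreadSafeProxy:
    """Forward attributes to the wrapped object; serialize calls to the guarded methods."""

    def __init__(self, wrapped, lock, guarded=("encode", "decode")):
        self._wrapped = wrapped
        self._lock = lock
        self._guarded = set(guarded)

    def __getattr__(self, name):
        attr = getattr(self._wrapped, name)
        if name in self._guarded and callable(attr):
            # Re-bind the method so it acquires the shared lock before running.
            def locked(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)
            return locked
        return attr
        # Note: dunder lookups such as __call__ bypass __getattr__, which is why the
        # wrappers above define __call__ explicitly.


class Dummy:
    def encode(self, text):
        return [ord(c) for c in text]


proxy = ThreadSafeProxy(Dummy(), threading.Lock())
print(proxy.encode("hi"))  # runs under the lock -> [104, 105]
```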

View File

@@ -1,94 +1,11 @@
"""
# Cosmos 2 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2-2B-Text2Image
```
convert checkpoint
```bash
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_ckpt_path $transformer_ckpt_path \
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
--text_encoder_path google-t5/t5-11b \
--tokenizer_path google-t5/t5-11b \
--vae_type wan2.1 \
--output_path converted/cosmos-p2-t2i-2b \
--save_pipeline
```
# Cosmos 2.5 Predict
Download checkpoint
```bash
hf download nvidia/Cosmos-Predict2.5-2B
```
Convert checkpoint
```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
--save_pipeline
# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
--save_pipeline
```
## 14B
```bash
hf download nvidia/Cosmos-Predict2.5-14B
```
```bash
# pre-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-14B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
--save_pipeline
# post-trained
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Predict-Base-14B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
--save_pipeline
```
"""
import argparse
import pathlib
import sys
from typing import Any, Dict
import torch
from accelerate import init_empty_weights
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
from transformers import T5EncoderModel, T5TokenizerFast
from diffusers import (
AutoencoderKLCosmos,
@@ -100,9 +17,7 @@ from diffusers import (
CosmosVideoToWorldPipeline,
EDMEulerScheduler,
FlowMatchEulerDiscreteScheduler,
UniPCMultistepScheduler,
)
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -318,44 +233,6 @@ TRANSFORMER_CONFIGS = {
"concat_padding_mask": True,
"extra_pos_embed_type": None,
},
"Cosmos-2.5-Predict-Base-2B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 16,
"attention_head_dim": 128,
"num_layers": 28,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
"Cosmos-2.5-Predict-Base-14B": {
"in_channels": 16 + 1,
"out_channels": 16,
"num_attention_heads": 40,
"attention_head_dim": 128,
"num_layers": 36,
"mlp_ratio": 4.0,
"text_embed_dim": 1024,
"adaln_lora_dim": 256,
"max_size": (128, 240, 240),
"patch_size": (1, 2, 2),
"rope_scale": (1.0, 3.0, 3.0),
"concat_padding_mask": True,
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
"extra_pos_embed_type": None,
"use_crossattn_projection": True,
"crossattn_proj_in_channels": 100352,
"encoder_hidden_states_channels": 1024,
},
}
VAE_KEYS_RENAME_DICT = {
@@ -457,9 +334,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
elif "Cosmos-2.0" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
elif "Cosmos-2.5" in transformer_type:
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
else:
assert False
@@ -473,7 +347,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
new_key = new_key.removeprefix(PREFIX_KEY)
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)
print(key, "->", new_key, flush=True)
update_state_dict_(original_state_dict, key, new_key)
for key in list(original_state_dict.keys()):
@@ -482,21 +355,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
continue
handler_fn_inplace(key, original_state_dict)
expected_keys = set(transformer.state_dict().keys())
mapped_keys = set(original_state_dict.keys())
missing_keys = expected_keys - mapped_keys
unexpected_keys = mapped_keys - expected_keys
if missing_keys:
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
for k in missing_keys:
print(k)
sys.exit(1)
if unexpected_keys:
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
for k in unexpected_keys:
print(k)
sys.exit(2)
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
return transformer
@@ -586,34 +444,6 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def save_pipeline_cosmos2_5(args, transformer, vae):
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
text_encoder_path, torch_dtype="auto", device_map="cpu"
)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
scheduler = UniPCMultistepScheduler(
use_karras_sigmas=True,
use_flow_sigmas=True,
prediction_type="flow_prediction",
sigma_max=200.0,
sigma_min=0.01,
)
pipe = Cosmos2_5_PredictBasePipeline(
text_encoder=text_encoder,
tokenizer=tokenizer,
transformer=transformer,
vae=vae,
scheduler=scheduler,
safety_checker=lambda *args, **kwargs: None,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
@@ -621,10 +451,10 @@ def get_args():
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
)
parser.add_argument(
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
)
parser.add_argument("--text_encoder_path", type=str, default=None)
parser.add_argument("--tokenizer_path", type=str, default=None)
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
parser.add_argument("--save_pipeline", action="store_true")
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -647,6 +477,8 @@ if __name__ == "__main__":
if args.save_pipeline:
assert args.transformer_ckpt_path is not None
assert args.vae_type is not None
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
if args.transformer_ckpt_path is not None:
weights_only = "Cosmos-1.0" in args.transformer_type
@@ -658,26 +490,17 @@ if __name__ == "__main__":
if args.vae_type is not None:
if "Cosmos-1.0" in args.transformer_type:
vae = convert_vae(args.vae_type)
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
else:
vae = AutoencoderKLWan.from_pretrained(
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
else:
raise AssertionError(f"{args.transformer_type} not supported")
if not args.save_pipeline:
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.save_pipeline:
if "Cosmos-1.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_1_0(args, transformer, vae)
elif "Cosmos-2.0" in args.transformer_type:
assert args.text_encoder_path is not None
assert args.tokenizer_path is not None
save_pipeline_cosmos_2_0(args, transformer, vae)
elif "Cosmos-2.5" in args.transformer_type:
save_pipeline_cosmos2_5(args, transformer, vae)
else:
raise AssertionError(f"{args.transformer_type} not supported")
assert False

View File

@@ -1,886 +0,0 @@
import argparse
import os
from contextlib import nullcontext
from typing import Any, Dict, Optional, Tuple
import safetensors.torch
import torch
from accelerate import init_empty_weights
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
from diffusers import (
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
FlowMatchEulerDiscreteScheduler,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTX2VideoTransformer3DModel,
)
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
from diffusers.utils.import_utils import is_accelerate_available
CTX = init_empty_weights if is_accelerate_available() else nullcontext
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
# Input Patchify Projections
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
# Modulation Parameters
# Handle adaln_single --> time_embed, audio_adaln_single --> audio_time_embed separately as the original keys are
# substrings of the other modulation parameters below
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
# Transformer Blocks
# Per-Block Cross Attention Modulation Parameters
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
# Encoder
"down_blocks.0": "down_blocks.0",
"down_blocks.1": "down_blocks.0.downsamplers.0",
"down_blocks.2": "down_blocks.1",
"down_blocks.3": "down_blocks.1.downsamplers.0",
"down_blocks.4": "down_blocks.2",
"down_blocks.5": "down_blocks.2.downsamplers.0",
"down_blocks.6": "down_blocks.3",
"down_blocks.7": "down_blocks.3.downsamplers.0",
"down_blocks.8": "mid_block",
# Decoder
"up_blocks.0": "mid_block",
"up_blocks.1": "up_blocks.0.upsamplers.0",
"up_blocks.2": "up_blocks.0",
"up_blocks.3": "up_blocks.1.upsamplers.0",
"up_blocks.4": "up_blocks.1",
"up_blocks.5": "up_blocks.2.upsamplers.0",
"up_blocks.6": "up_blocks.2",
# Common
# For all 3D ResNets
"res_blocks": "resnets",
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
"per_channel_statistics.mean-of-means": "latents_mean",
"per_channel_statistics.std-of-means": "latents_std",
}
LTX_2_0_VOCODER_RENAME_DICT = {
"ups": "upsamplers",
"resblocks": "resnets",
"conv_pre": "conv_in",
"conv_post": "conv_out",
}
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
state_dict[new_key] = state_dict.pop(old_key)
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
state_dict.pop(key)
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
# Skip if not a weight or bias
if ".weight" not in key and ".bias" not in key:
return
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
if key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
param = state_dict.pop(key)
state_dict[new_key] = param
return
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
if key.startswith("per_channel_statistics"):
new_key = ".".join(["decoder", key])
param = state_dict.pop(key)
state_dict[new_key] = param
return
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
"video_embeddings_connector": remove_keys_inplace,
"audio_embeddings_connector": remove_keys_inplace,
"adaln_single": convert_ltx2_transformer_adaln_single,
}
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
"connectors.": "",
"video_embeddings_connector": "video_connector",
"audio_embeddings_connector": "audio_connector",
"transformer_1d_blocks": "transformer_blocks",
"text_embedding_projection.aggregate_embed": "text_proj_in",
# Attention QK Norms
"q_norm": "norm_q",
"k_norm": "norm_k",
}
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
"per_channel_statistics.channel": remove_keys_inplace,
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
}
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
connector_prefixes = (
"video_embeddings_connector",
"audio_embeddings_connector",
"transformer_1d_blocks",
"text_embedding_projection.aggregate_embed",
"connectors.",
"video_connector",
"audio_connector",
"text_proj_in",
)
transformer_state_dict, connector_state_dict = {}, {}
for key, value in state_dict.items():
if key.startswith(connector_prefixes):
connector_state_dict[key] = value
else:
transformer_state_dict[key] = value
return transformer_state_dict, connector_state_dict
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 4,
"out_channels": 4,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 2,
"attention_head_dim": 8,
"cross_attention_dim": 16,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 4,
"audio_out_channels": 4,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 2,
"audio_attention_head_dim": 4,
"audio_cross_attention_dim": 8,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 2,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 16,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1,
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"out_channels": 128,
"patch_size": 1,
"patch_size_t": 1,
"num_attention_heads": 32,
"attention_head_dim": 128,
"cross_attention_dim": 4096,
"vae_scale_factors": (8, 32, 32),
"pos_embed_max_pos": 20,
"base_height": 2048,
"base_width": 2048,
"audio_in_channels": 128,
"audio_out_channels": 128,
"audio_patch_size": 1,
"audio_patch_size_t": 1,
"audio_num_attention_heads": 32,
"audio_attention_head_dim": 64,
"audio_cross_attention_dim": 2048,
"audio_scale_factor": 4,
"audio_pos_embed_max_pos": 20,
"audio_sampling_rate": 16000,
"audio_hop_length": 160,
"num_layers": 48,
"activation_fn": "gelu-approximate",
"qk_norm": "rms_norm_across_heads",
"norm_elementwise_affine": False,
"norm_eps": 1e-6,
"caption_channels": 3840,
"attention_bias": True,
"attention_out_bias": True,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_offset": 1,
"timestep_scale_multiplier": 1000,
"cross_attn_timestep_scale_multiplier": 1000,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"caption_channels": 16,
"text_proj_in_factor": 3,
"video_connector_num_attention_heads": 4,
"video_connector_attention_head_dim": 8,
"video_connector_num_layers": 1,
"video_connector_num_learnable_registers": None,
"audio_connector_num_attention_heads": 4,
"audio_connector_attention_head_dim": 8,
"audio_connector_num_layers": 1,
"audio_connector_num_learnable_registers": None,
"connector_rope_base_seq_len": 32,
"rope_theta": 10000.0,
"rope_double_precision": False,
"causal_temporal_positioning": False,
},
}
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"caption_channels": 3840,
"text_proj_in_factor": 49,
"video_connector_num_attention_heads": 30,
"video_connector_attention_head_dim": 128,
"video_connector_num_layers": 2,
"video_connector_num_learnable_registers": 128,
"audio_connector_num_attention_heads": 30,
"audio_connector_attention_head_dim": 128,
"audio_connector_num_layers": 2,
"audio_connector_num_learnable_registers": 128,
"connector_rope_base_seq_len": 4096,
"rope_theta": 10000.0,
"rope_double_precision": True,
"causal_temporal_positioning": False,
"rope_type": "split",
},
}
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
special_keys_remap = {}
return config, rename_dict, special_keys_remap
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
diffusers_config = config["diffusers_config"]
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
with init_empty_weights():
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(transformer_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(transformer_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(transformer_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, transformer_state_dict)
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
return transformer
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
diffusers_config = config["diffusers_config"]
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
if len(connector_state_dict) == 0:
raise ValueError("No connector weights found in the provided state dict.")
with init_empty_weights():
connectors = LTX2TextConnectors.from_config(diffusers_config)
for key in list(connector_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(connector_state_dict, key, new_key)
for key in list(connector_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, connector_state_dict)
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
return connectors
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "test":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
elif version == "2.0":
config = {
"model_id": "diffusers-internal-dev/dummy-ltx2",
"diffusers_config": {
"in_channels": 3,
"out_channels": 3,
"latent_channels": 128,
"block_out_channels": (256, 512, 1024, 2048),
"down_block_types": (
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
"LTX2VideoDownBlock3D",
),
"decoder_block_out_channels": (256, 512, 1024),
"layers_per_block": (4, 6, 6, 2, 2),
"decoder_layers_per_block": (5, 5, 5, 5),
"spatio_temporal_scaling": (True, True, True, True),
"decoder_spatio_temporal_scaling": (True, True, True),
"decoder_inject_noise": (False, False, False, False),
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
"upsample_residual": (True, True, True),
"upsample_factor": (2, 2, 2),
"timestep_conditioning": False,
"patch_size": 4,
"patch_size_t": 1,
"resnet_norm_eps": 1e-6,
"encoder_causal": True,
"decoder_causal": False,
"encoder_spatial_padding_mode": "zeros",
"decoder_spatial_padding_mode": "reflect",
"spatial_compression_ratio": 32,
"temporal_compression_ratio": 8,
},
}
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"base_channels": 128,
"output_channels": 2,
"ch_mult": (1, 2, 4),
"num_res_blocks": 2,
"attn_resolutions": None,
"in_channels": 2,
"resolution": 256,
"latent_channels": 8,
"norm_type": "pixel",
"causality_axis": "height",
"dropout": 0.0,
"mid_block_add_attention": False,
"sample_rate": 16000,
"mel_hop_length": 160,
"is_causal": True,
"mel_bins": 64,
"double_z": True,
},
}
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vae.load_state_dict(original_state_dict, strict=True, assign=True)
return vae
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
if version == "2.0":
config = {
"model_id": "diffusers-internal-dev/new-ltx-model",
"diffusers_config": {
"in_channels": 128,
"hidden_channels": 1024,
"out_channels": 2,
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_factors": [6, 5, 2, 2, 2],
"resnet_kernel_sizes": [3, 7, 11],
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"leaky_relu_negative_slope": 0.1,
"output_sampling_rate": 24000,
},
}
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
return config, rename_dict, special_keys_remap
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
diffusers_config = config["diffusers_config"]
with init_empty_weights():
vocoder = LTX2Vocoder.from_config(diffusers_config)
# Handle official code --> diffusers key remapping via the remap dict
for key in list(original_state_dict.keys()):
new_key = key[:]
for replace_key, rename_key in rename_dict.items():
new_key = new_key.replace(replace_key, rename_key)
update_state_dict_inplace(original_state_dict, key, new_key)
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(original_state_dict.keys()):
for special_key, handler_fn_inplace in special_keys_remap.items():
if special_key not in key:
continue
handler_fn_inplace(key, original_state_dict)
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
return vocoder
def get_ltx2_spatial_latent_upsampler_config(version: str):
if version == "2.0":
config = {
"in_channels": 128,
"mid_channels": 1024,
"num_blocks_per_stage": 4,
"dims": 3,
"spatial_upsample": True,
"temporal_upsample": False,
"rational_spatial_scale": 2.0,
}
else:
raise ValueError(f"Unsupported version: {version}")
return config
def convert_ltx2_spatial_latent_upsampler(
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
with init_empty_weights():
latent_upsampler = LTX2LatentUpsamplerModel(**config)
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
latent_upsampler.to(dtype)
return latent_upsampler
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
if args.original_state_dict_repo_id is not None:
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
elif args.checkpoint_path is not None:
ckpt_path = args.checkpoint_path
else:
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
original_state_dict = safetensors.torch.load_file(ckpt_path)
return original_state_dict
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
if repo_id is None and filename is None:
raise ValueError("Please supply at least one of `repo_id` or `filename`")
if repo_id is not None:
if filename is None:
raise ValueError("If repo_id is specified, filename must also be specified.")
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
else:
ckpt_path = filename
_, ext = os.path.splitext(ckpt_path)
if ext in [".safetensors", ".sft"]:
state_dict = safetensors.torch.load_file(ckpt_path)
else:
state_dict = torch.load(ckpt_path, map_location="cpu")
return state_dict
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
# Ensure that the key prefix ends with a dot (.)
if not prefix.endswith("."):
prefix = prefix + "."
model_state_dict = {}
for param_name, param in combined_ckpt.items():
if param_name.startswith(prefix):
model_state_dict[param_name.replace(prefix, "")] = param
if prefix == "model.diffusion_model.":
# Some checkpoints store the text connector projection outside the diffusion model prefix.
connector_key = "text_embedding_projection.aggregate_embed.weight"
if connector_key in combined_ckpt and connector_key not in model_state_dict:
model_state_dict[connector_key] = combined_ckpt[connector_key]
return model_state_dict
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--original_state_dict_repo_id",
default="Lightricks/LTX-2",
type=str,
help="HF Hub repo id with LTX 2.0 checkpoint",
)
parser.add_argument(
"--checkpoint_path",
default=None,
type=str,
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
)
parser.add_argument(
"--version",
type=str,
default="2.0",
choices=["test", "2.0"],
help="Version of the LTX 2.0 model",
)
parser.add_argument(
"--combined_filename",
default="ltx-2-19b-dev.safetensors",
type=str,
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
)
parser.add_argument("--vae_prefix", default="vae.", type=str)
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
parser.add_argument(
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
)
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
parser.add_argument(
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
)
parser.add_argument(
"--text_encoder_model_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 base text encoder model",
)
parser.add_argument(
"--tokenizer_id",
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
type=str,
help="HF Hub id for the LTX 2.0 text tokenizer",
)
parser.add_argument(
"--latent_upsampler_filename",
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
type=str,
help="Latent upsampler filename",
)
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
parser.add_argument(
"--full_pipeline",
action="store_true",
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
)
parser.add_argument(
"--upsample_pipeline",
action="store_true",
help="Whether to save a latent upsampling pipeline",
)
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
return parser.parse_args()
DTYPE_MAPPING = {
"fp32": torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}
VARIANT_MAPPING = {
"fp32": None,
"fp16": "fp16",
"bf16": "bf16",
}
def main(args):
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
combined_ckpt = None
load_combined_models = any(
[
args.vae,
args.audio_vae,
args.dit,
args.vocoder,
args.text_encoder,
args.full_pipeline,
args.upsample_pipeline,
]
)
if args.combined_filename is not None and load_combined_models:
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
if args.vae or args.full_pipeline or args.upsample_pipeline:
if args.vae_filename is not None:
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
elif combined_ckpt is not None:
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
if not args.full_pipeline and not args.upsample_pipeline:
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
if args.audio_vae or args.full_pipeline:
if args.audio_vae_filename is not None:
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
elif combined_ckpt is not None:
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
if not args.full_pipeline:
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
if args.dit or args.full_pipeline:
if args.dit_filename is not None:
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
if not args.full_pipeline:
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
if args.connectors or args.full_pipeline:
if args.dit_filename is not None:
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
elif combined_ckpt is not None:
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
if not args.full_pipeline:
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
if args.vocoder or args.full_pipeline:
if args.vocoder_filename is not None:
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
elif combined_ckpt is not None:
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
if not args.full_pipeline:
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
if args.text_encoder or args.full_pipeline:
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
if not args.full_pipeline:
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
if not args.full_pipeline:
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
)
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
original_latent_upsampler_ckpt,
latent_upsampler_config,
dtype=vae_dtype,
)
if not args.full_pipeline and not args.upsample_pipeline:
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
if args.full_pipeline:
scheduler = FlowMatchEulerDiscreteScheduler(
use_dynamic_shifting=True,
base_shift=0.95,
max_shift=2.05,
base_image_seq_len=1024,
max_image_seq_len=4096,
shift_terminal=0.1,
)
pipe = LTX2Pipeline(
scheduler=scheduler,
vae=vae,
audio_vae=audio_vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
connectors=connectors,
transformer=transformer,
vocoder=vocoder,
)
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
if args.upsample_pipeline:
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
pipe.save_pretrained(
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
)
if __name__ == "__main__":
args = get_args()
main(args)
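The converters above all share the same remapping recipe: iterate over a snapshot of the keys, rewrite each key by substring replacement against a rename dict, move the tensor under its new name in place, then run any special-case handlers. A toy, self-contained illustration of that loop (the keys and dict here are made up, not taken from an actual checkpoint):

```python
rename_dict = {"patchify_proj": "proj_in", "q_norm": "norm_q"}
state_dict = {
    "patchify_proj.weight": "tensor-0",
    "transformer_blocks.0.attn1.q_norm.weight": "tensor-1",
}

# Iterate over a snapshot of the keys because the dict is mutated while looping.
for key in list(state_dict.keys()):
    new_key = key
    for old, new in rename_dict.items():
        new_key = new_key.replace(old, new)
    state_dict[new_key] = state_dict.pop(key)

print(sorted(state_dict))
# ['proj_in.weight', 'transformer_blocks.0.attn1.norm_q.weight']
```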

View File

@@ -274,7 +274,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
setup(
name="diffusers",
version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
description="State-of-the-art diffusion in PyTorch and JAX.",
long_description=open("README.md", "r", encoding="utf-8").read(),
long_description_content_type="text/markdown",

View File

@@ -23,7 +23,6 @@ from .utils import (
is_torchao_available,
is_torchsde_available,
is_transformers_available,
is_transformers_version,
)
@@ -194,8 +193,6 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
"AutoencoderKLMagvit",
"AutoencoderKLMochi",
@@ -226,7 +223,6 @@ else:
"FluxControlNetModel",
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageTransformer2DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -240,7 +236,6 @@ else:
"Kandinsky5Transformer3DModel",
"LatteTransformer3DModel",
"LongCatImageTransformer2DModel",
"LTX2VideoTransformer3DModel",
"LTXVideoTransformer3DModel",
"Lumina2Transformer2DModel",
"LuminaNextDiT2DModel",
@@ -284,7 +279,6 @@ else:
"WanAnimateTransformer3DModel",
"WanTransformer3DModel",
"WanVACETransformer3DModel",
"ZImageControlNetModel",
"ZImageTransformer2DModel",
"attention_backend",
]
@@ -358,7 +352,6 @@ else:
"KDPM2AncestralDiscreteScheduler",
"KDPM2DiscreteScheduler",
"LCMScheduler",
"LTXEulerAncestralRFScheduler",
"PNDMScheduler",
"RePaintScheduler",
"SASolverScheduler",
@@ -423,8 +416,6 @@ else:
"QwenImageEditModularPipeline",
"QwenImageEditPlusAutoBlocks",
"QwenImageEditPlusModularPipeline",
"QwenImageLayeredAutoBlocks",
"QwenImageLayeredModularPipeline",
"QwenImageModularPipeline",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLModularPipeline",
@@ -471,7 +462,6 @@ else:
"CogView4ControlPipeline",
"CogView4Pipeline",
"ConsisIDPipeline",
"Cosmos2_5_PredictBasePipeline",
"Cosmos2TextToImagePipeline",
"Cosmos2VideoToWorldPipeline",
"CosmosTextToWorldPipeline",
@@ -494,7 +484,6 @@ else:
"FluxKontextPipeline",
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImagePipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -546,11 +535,7 @@ else:
"LEditsPPPipelineStableDiffusionXL",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
"LTXI2VLongMultiPromptPipeline",
"LTXImageToVideoPipeline",
"LTXLatentUpsamplePipeline",
"LTXPipeline",
@@ -685,10 +670,7 @@ else:
"WuerstchenCombinedPipeline",
"WuerstchenDecoderPipeline",
"WuerstchenPriorPipeline",
"ZImageControlNetInpaintPipeline",
"ZImageControlNetPipeline",
"ZImageImg2ImgPipeline",
"ZImageOmniPipeline",
"ZImagePipeline",
]
)
@@ -950,8 +932,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -982,7 +962,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxControlNetModel,
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -996,7 +975,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,
@@ -1039,7 +1017,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WanAnimateTransformer3DModel,
WanTransformer3DModel,
WanVACETransformer3DModel,
ZImageControlNetModel,
ZImageTransformer2DModel,
attention_backend,
)
@@ -1105,7 +1082,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
KDPM2AncestralDiscreteScheduler,
KDPM2DiscreteScheduler,
LCMScheduler,
LTXEulerAncestralRFScheduler,
PNDMScheduler,
RePaintScheduler,
SASolverScheduler,
@@ -1153,8 +1129,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
QwenImageLayeredAutoBlocks,
QwenImageLayeredModularPipeline,
QwenImageModularPipeline,
StableDiffusionXLAutoBlocks,
StableDiffusionXLModularPipeline,
@@ -1197,7 +1171,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4ControlPipeline,
CogView4Pipeline,
ConsisIDPipeline,
Cosmos2_5_PredictBasePipeline,
Cosmos2TextToImagePipeline,
Cosmos2VideoToWorldPipeline,
CosmosTextToWorldPipeline,
@@ -1220,7 +1193,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxKontextPipeline,
FluxPipeline,
FluxPriorReduxPipeline,
GlmImagePipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
@@ -1272,11 +1244,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusionXL,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ImageToVideoPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
LTXI2VLongMultiPromptPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
LTXPipeline,
@@ -1409,10 +1377,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
WuerstchenCombinedPipeline,
WuerstchenDecoderPipeline,
WuerstchenPriorPipeline,
ZImageControlNetInpaintPipeline,
ZImageControlNetPipeline,
ZImageImg2ImgPipeline,
ZImageOmniPipeline,
ZImagePipeline,
)

View File

@@ -25,7 +25,6 @@ if is_torch_available():
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
from .guider_utils import BaseGuidance
from .magnitude_aware_guidance import MagnitudeAwareGuidance
from .perturbed_attention_guidance import PerturbedAttentionGuidance
from .skip_layer_guidance import SkipLayerGuidance
from .smoothed_energy_guidance import SmoothedEnergyGuidance

View File

@@ -1,159 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch
from ..configuration_utils import register_to_config
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
if TYPE_CHECKING:
from ..modular_pipelines.modular_pipeline import BlockState
class MagnitudeAwareGuidance(BaseGuidance):
"""
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
Args:
guidance_scale (`float`, defaults to `10.0`):
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
deterioration of image quality.
alpha (`float`, defaults to `8.0`):
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive suppression of
the guidance scale when the magnitude of the guidance update is large.
guidance_rescale (`float`, defaults to `0.0`):
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
Flawed](https://huggingface.co/papers/2305.08891).
use_original_formulation (`bool`, defaults to `False`):
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
we use the diffusers-native implementation that has been in the codebase for a long time. See
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
start (`float`, defaults to `0.0`):
The fraction of the total number of denoising steps after which guidance starts.
stop (`float`, defaults to `1.0`):
The fraction of the total number of denoising steps after which guidance stops.
"""
_input_predictions = ["pred_cond", "pred_uncond"]
@register_to_config
def __init__(
self,
guidance_scale: float = 10.0,
alpha: float = 8.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
enabled: bool = True,
):
super().__init__(start, stop, enabled)
self.guidance_scale = guidance_scale
self.alpha = alpha
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def prepare_inputs_from_block_state(
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
) -> List["BlockState"]:
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
data_batches.append(data_batch)
return data_batches
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
pred = None
if not self._is_mambo_g_enabled():
pred = pred_cond
else:
pred = mambo_guidance(
pred_cond,
pred_uncond,
self.guidance_scale,
self.alpha,
self.use_original_formulation,
)
if self.guidance_rescale > 0.0:
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
@property
def is_conditional(self) -> bool:
return self._count_prepared == 1
@property
def num_conditions(self) -> int:
num_conditions = 1
if self._is_mambo_g_enabled():
num_conditions += 1
return num_conditions
def _is_mambo_g_enabled(self) -> bool:
if not self._enabled:
return False
is_within_range = True
if self._num_inference_steps is not None:
skip_start_step = int(self._start * self._num_inference_steps)
skip_stop_step = int(self._stop * self._num_inference_steps)
is_within_range = skip_start_step <= self._step < skip_stop_step
is_close = False
if self.use_original_formulation:
is_close = math.isclose(self.guidance_scale, 0.0)
else:
is_close = math.isclose(self.guidance_scale, 1.0)
return is_within_range and not is_close
def mambo_guidance(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
alpha: float = 8.0,
use_original_formulation: bool = False,
):
dim = list(range(1, len(pred_cond.shape)))
diff = pred_cond - pred_uncond
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
guidance_scale_final = (
guidance_scale * torch.exp(-alpha * ratio)
if use_original_formulation
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
)
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale_final * diff
return pred
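
As an aside (not part of the diff): the effective guidance scale computed by `mambo_guidance` above decays exponentially with the relative magnitude of the guidance update, ratio = ||pred_cond - pred_uncond|| / ||pred_uncond||. A minimal sketch of that arithmetic for both formulations:

import math

guidance_scale, alpha = 10.0, 8.0
# ratio = ||pred_cond - pred_uncond|| / ||pred_uncond||, computed per sample inside mambo_guidance.
for ratio in (0.05, 0.2, 0.5, 1.0):
    scale_native = 1.0 + (guidance_scale - 1.0) * math.exp(-alpha * ratio)  # use_original_formulation=False
    scale_original = guidance_scale * math.exp(-alpha * ratio)              # use_original_formulation=True
    print(f"ratio={ratio:.2f}  native={scale_native:.3f}  original={scale_original:.3f}")

Large updates (high ratio) are damped toward an effective scale of 1.0 in the diffusers-native form and toward 0.0 in the original form, which is the magnitude-aware mitigation the class implements.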

View File

@@ -67,7 +67,6 @@ if is_torch_available():
"SD3LoraLoaderMixin",
"AuraFlowLoraLoaderMixin",
"StableDiffusionXLLoraLoaderMixin",
"LTX2LoraLoaderMixin",
"LTXVideoLoraLoaderMixin",
"LoraLoaderMixin",
"FluxLoraLoaderMixin",
@@ -122,7 +121,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,
LoraLoaderMixin,
LTX2LoraLoaderMixin,
LTXVideoLoraLoaderMixin,
Lumina2LoraLoaderMixin,
Mochi1LoraLoaderMixin,

View File

@@ -2140,54 +2140,6 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
return converted_state_dict
def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
# Remove the prefix
state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
if non_diffusers_prefix == "diffusion_model":
rename_dict = {
"patchify_proj": "proj_in",
"audio_patchify_proj": "audio_proj_in",
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
"q_norm": "norm_q",
"k_norm": "norm_k",
}
else:
rename_dict = {"aggregate_embed": "text_proj_in"}
# Apply renaming
renamed_state_dict = {}
for key, value in converted_state_dict.items():
new_key = key[:]
for old_pattern, new_pattern in rename_dict.items():
new_key = new_key.replace(old_pattern, new_pattern)
renamed_state_dict[new_key] = value
# Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
final_state_dict = {}
for key, value in renamed_state_dict.items():
if key.startswith("adaln_single."):
new_key = key.replace("adaln_single.", "time_embed.")
final_state_dict[new_key] = value
elif key.startswith("audio_adaln_single."):
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
final_state_dict[new_key] = value
else:
final_state_dict[key] = value
# Add transformer prefix
prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}
return final_state_dict
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
if has_diffusion_model:

View File

@@ -48,7 +48,6 @@ from .lora_conversion_utils import (
_convert_non_diffusers_flux2_lora_to_diffusers,
_convert_non_diffusers_hidream_lora_to_diffusers,
_convert_non_diffusers_lora_to_diffusers,
_convert_non_diffusers_ltx2_lora_to_diffusers,
_convert_non_diffusers_ltxv_lora_to_diffusers,
_convert_non_diffusers_lumina2_lora_to_diffusers,
_convert_non_diffusers_qwen_lora_to_diffusers,
@@ -75,7 +74,6 @@ logger = logging.get_logger(__name__)
TEXT_ENCODER_NAME = "text_encoder"
UNET_NAME = "unet"
TRANSFORMER_NAME = "transformer"
LTX2_CONNECTOR_NAME = "connectors"
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
@@ -214,7 +212,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
@@ -641,7 +639,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_unet(
state_dict,
@@ -1081,7 +1079,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -1377,7 +1375,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -1659,7 +1657,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
)
if not (has_lora_keys or has_norm_keys):
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
transformer_lora_state_dict = {
k: state_dict.get(k)
@@ -2506,7 +2504,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -2703,7 +2701,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -2906,7 +2904,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3013,233 +3011,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class LTX2LoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
"""
_lora_loadable_modules = ["transformer", "connectors"]
transformer_name = TRANSFORMER_NAME
connectors_name = LTX2_CONNECTOR_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
final_state_dict = state_dict
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
if is_non_diffusers_format:
final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
if has_connector:
connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
state_dict, "text_embedding_projection"
)
final_state_dict.update(connectors_state_dict)
out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
return out
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
adapter_name: Optional[str] = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
transformer_peft_state_dict = {
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
}
connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}
self.load_lora_into_transformer(
transformer_peft_state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
if connectors_peft_state_dict:
self.load_lora_into_transformer(
connectors_peft_state_dict,
transformer=getattr(self, self.connectors_name)
if not hasattr(self, "connectors")
else self.connectors,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
prefix=self.connectors_name,
)
@classmethod
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
prefix: str = "transformer",
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {prefix}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
prefix=prefix,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: Union[str, os.PathLike],
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: Optional[dict] = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: List[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: Optional[List[str]] = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
class SanaLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
@@ -3333,7 +3104,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3536,7 +3307,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3740,7 +3511,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -3940,7 +3711,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -4194,7 +3965,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4471,7 +4242,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
if load_into_transformer_2:
@@ -4691,7 +4462,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -4894,7 +4665,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -5100,7 +4871,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -5306,7 +5077,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,
@@ -5509,7 +5280,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
raise ValueError("Invalid LoRA checkpoint.")
self.load_lora_into_transformer(
state_dict,

View File

@@ -63,12 +63,9 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}

View File

@@ -49,7 +49,6 @@ from .single_file_utils import (
convert_stable_cascade_unet_single_file_to_diffusers,
convert_wan_transformer_to_diffusers,
convert_wan_vae_to_diffusers,
convert_z_image_controlnet_checkpoint_to_diffusers,
convert_z_image_transformer_checkpoint_to_diffusers,
create_controlnet_diffusers_config_from_ldm,
create_unet_diffusers_config_from_ldm,
@@ -162,7 +161,7 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"default_subfolder": "transformer",
},
"QwenImageTransformer2DModel": {
"checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
"checkpoint_mapping_fn": lambda x: x,
"default_subfolder": "transformer",
},
"Flux2Transformer2DModel": {
@@ -173,18 +172,11 @@ SINGLE_FILE_LOADABLE_CLASSES = {
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
"default_subfolder": "transformer",
},
"ZImageControlNetModel": {
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
},
}
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
model_state_dict_keys = set(model_state_dict.keys())
checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
is_match = model_state_dict_keys == checkpoint_state_dict_keys
return not (is_subset and is_match)
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
def _get_single_file_loadable_mapping_class(cls):
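
Side note on the `_should_convert_state_dict_to_diffusers` hunk above (illustrative key sets, not from the diff): the two variants are not equivalent. The multi-line form reduces to `not is_match`, because key equality implies the subset relation, while the one-liner only tests the subset relation:

model_keys = {"proj_in.weight"}
checkpoint_keys = {"proj_in.weight", "proj_out.weight"}  # checkpoint carries extra keys

is_subset = model_keys.issubset(checkpoint_keys)  # True
is_match = model_keys == checkpoint_keys          # False

print(not (is_subset and is_match))  # True  -> multi-line variant would trigger conversion
print(not is_subset)                 # False -> one-line variant would skip conversion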

View File

@@ -120,12 +120,7 @@ CHECKPOINT_KEY_NAMES = {
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
"z-image-turbo": [
"model.diffusion_model.layers.0.adaLN_modulation.0.weight",
"layers.0.adaLN_modulation.0.weight",
],
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
"z-image-turbo": "cap_embedder.0.weight",
"sana": [
"blocks.0.cross_attn.q_linear.weight",
"blocks.0.cross_attn.q_linear.bias",
@@ -225,9 +220,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
}
# Use to configure model sample size when original config is provided
@@ -731,7 +723,10 @@ def infer_diffusers_model_type(checkpoint):
):
model_type = "instruct-pix2pix"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
elif (
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
):
model_type = "z-image-turbo"
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
@@ -784,18 +779,6 @@ def infer_diffusers_model_type(checkpoint):
else:
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
before_proj_weight = checkpoint.get("control_noise_refiner.0.before_proj.weight", None)
if before_proj_weight is None:
model_type = "z-image-turbo-controlnet-2.0"
elif before_proj_weight is not None and torch.all(before_proj_weight == 0.0):
model_type = "z-image-turbo-controlnet-2.0"
else:
model_type = "z-image-turbo-controlnet-2.1"
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
model_type = "z-image-turbo-controlnet"
else:
model_type = "v1"
@@ -3859,7 +3842,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
".attention.k_norm.weight": ".attention.norm_k.weight",
".attention.q_norm.weight": ".attention.norm_q.weight",
".attention.out.weight": ".attention.to_out.0.weight",
"model.diffusion_model.": "",
}
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
@@ -3894,9 +3876,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
update_state_dict(converted_state_dict, key, new_key)
if "norm_final.weight" in converted_state_dict.keys():
_ = converted_state_dict.pop("norm_final.weight")
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
# special_keys_remap
for key in list(converted_state_dict.keys()):
@@ -3906,17 +3885,3 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
handler_fn_inplace(key, converted_state_dict)
return converted_state_dict
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
if config["add_control_noise_refiner"] is None:
return checkpoint
elif config["add_control_noise_refiner"] == "control_noise_refiner":
return checkpoint
elif config["add_control_noise_refiner"] == "control_layers":
converted_state_dict = {
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
}
return converted_state_dict
else:
raise ValueError("Unknown Z-Image Turbo ControlNet type.")

View File

@@ -41,8 +41,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
@@ -68,7 +66,6 @@ if is_torch_available():
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
_import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"]
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
_import_structure["embeddings"] = ["ImageProjection"]
@@ -98,7 +95,6 @@ if is_torch_available():
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -107,7 +103,6 @@ if is_torch_available():
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
@@ -157,8 +152,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLHunyuanImageRefiner,
AutoencoderKLHunyuanVideo,
AutoencoderKLHunyuanVideo15,
AutoencoderKLLTX2Audio,
AutoencoderKLLTX2Video,
AutoencoderKLLTXVideo,
AutoencoderKLMagvit,
AutoencoderKLMochi,
@@ -188,7 +181,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
SD3MultiControlNetModel,
SparseControlNetModel,
UNetControlNetXSModel,
ZImageControlNetModel,
)
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
@@ -209,7 +201,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
EasyAnimateTransformer3DModel,
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,
@@ -219,7 +210,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Kandinsky5Transformer3DModel,
LatteTransformer3DModel,
LongCatImageTransformer2DModel,
LTX2VideoTransformer3DModel,
LTXVideoTransformer3DModel,
Lumina2Transformer2DModel,
LuminaNextDiT2DModel,

View File

@@ -90,6 +90,10 @@ class ContextParallelConfig:
)
if self.ring_degree < 1 or self.ulysses_degree < 1:
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
if self.ring_degree > 1 and self.ulysses_degree > 1:
raise ValueError(
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
)
if self.rotate_method != "allgather":
raise NotImplementedError(
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
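
With the check added above, a `ContextParallelConfig` can enable only one of the two schemes at a time. A minimal sketch, assuming `ContextParallelConfig` is importable from the top-level `diffusers` namespace (adjust the import for your version):

from diffusers import ContextParallelConfig  # assumed import path

cp_ulysses = ContextParallelConfig(ulysses_degree=2)  # Ulysses-only: accepted
cp_ring = ContextParallelConfig(ring_degree=2)        # ring-only: accepted

# Combining both degrees now raises the ValueError introduced above, since unified
# Ulysses-Ring attention is removed in this diff; pick one scheme per config.
# ContextParallelConfig(ring_degree=2, ulysses_degree=2)  # -> ValueError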

View File

@@ -1106,51 +1106,6 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: Optional["ParallelConfig"] = None,
):
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
return out
# Not implemented yet.
def _npu_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
# ===== Context parallel =====
@@ -1177,103 +1132,6 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
return x
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
"""
Perform dimension sharding / reassembly across processes using _all_to_all_single.
This utility reshapes and redistributes tensor `x` across the given process group, along either the sequence
dimension or the head dimension, as selected by `scatter_idx` and `gather_idx`.
Args:
x (torch.Tensor):
Input tensor. Expected shapes:
- When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
- When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
scatter_idx (int) :
Dimension along which the tensor is partitioned before all-to-all.
gather_idx (int):
Dimension along which the output is reassembled after all-to-all.
group :
Distributed process group for the Ulysses group.
Returns:
torch.Tensor: Tensor with globally exchanged dimensions.
- For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
- For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
"""
group_world_size = torch.distributed.get_world_size(group)
if scatter_idx == 2 and gather_idx == 1:
# Used before Ulysses sequence parallel (SP) attention. Gathers the sequence
# dimension and scatters the head dimension.
batch_size, seq_len_local, num_heads, head_dim = x.shape
seq_len = seq_len_local * group_world_size
num_heads_local = num_heads // group_world_size
# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
x_temp = (
x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
.transpose(0, 2)
.contiguous()
)
if group_world_size > 1:
out = _all_to_all_single(x_temp, group=group)
else:
out = x_temp
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
return out
elif scatter_idx == 1 and gather_idx == 2:
# Used after Ulysses sequence parallel attention in unified SP. Gathers the head dimension
# and scatters back the sequence dimension.
batch_size, seq_len, num_heads_local, head_dim = x.shape
num_heads = num_heads_local * group_world_size
seq_len_local = seq_len // group_world_size
# B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
x_temp = (
x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
.permute(1, 3, 2, 0, 4)
.reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
)
if group_world_size > 1:
output = _all_to_all_single(x_temp, group)
else:
output = x_temp
output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
return output
else:
raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
class SeqAllToAllDim(torch.autograd.Function):
"""
All-to-all operation for unified sequence parallelism. Uses `_all_to_all_dim_exchange`; see that function
for more info.
"""
@staticmethod
def forward(ctx, group, input, scatter_id=2, gather_id=1):
ctx.group = group
ctx.scatter_id = scatter_id
ctx.gather_id = gather_id
return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
@staticmethod
def backward(ctx, grad_outputs):
grad_input = SeqAllToAllDim.apply(
ctx.group,
grad_outputs,
ctx.gather_id, # reversed
ctx.scatter_id, # reversed
)
return (None, grad_input, None, None)
class TemplatedRingAttention(torch.autograd.Function):
@staticmethod
def forward(
@@ -1334,10 +1192,7 @@ class TemplatedRingAttention(torch.autograd.Function):
out = out.to(torch.float32)
lse = lse.to(torch.float32)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1)
lse = lse.unsqueeze(-1)
if prev_out is not None:
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
@@ -1398,7 +1253,7 @@ class TemplatedRingAttention(torch.autograd.Function):
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
class TemplatedUlyssesAttention(torch.autograd.Function):
@@ -1493,69 +1348,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
)
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
def _templated_unified_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor],
dropout_p: float,
is_causal: bool,
scale: Optional[float],
enable_gqa: bool,
return_lse: bool,
forward_op,
backward_op,
_parallel_config: Optional["ParallelConfig"] = None,
scatter_idx: int = 2,
gather_idx: int = 1,
):
"""
Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
"""
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
ulysses_group = ulysses_mesh.get_group()
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
out = TemplatedRingAttention.apply(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
if return_lse:
context_layer, lse, *_ = out
else:
context_layer = out
# context_layer is of shape (B, S, H_LOCAL, D)
output = SeqAllToAllDim.apply(
ulysses_group,
context_layer,
gather_idx,
scatter_idx,
)
if return_lse:
# lse is of shape (B, S, H_LOCAL, 1)
# Refer to:
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
if is_torch_version("<", "2.9.0"):
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
lse = lse.squeeze(-1)
return (output, lse)
return output
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
def _templated_context_parallel_attention(
@@ -1581,25 +1374,7 @@ def _templated_context_parallel_attention(
raise ValueError("GQA is not yet supported for templated attention.")
# TODO: add support for unified attention with ring/ulysses degree both being > 1
if (
_parallel_config.context_parallel_config.ring_degree > 1
and _parallel_config.context_parallel_config.ulysses_degree > 1
):
return _templated_unified_attention(
query,
key,
value,
attn_mask,
dropout_p,
is_causal,
scale,
enable_gqa,
return_lse,
forward_op,
backward_op,
_parallel_config,
)
elif _parallel_config.context_parallel_config.ring_degree > 1:
if _parallel_config.context_parallel_config.ring_degree > 1:
return TemplatedRingAttention.apply(
query,
key,
@@ -1645,7 +1420,6 @@ def _flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1653,9 +1427,6 @@ def _flash_attention(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
if _parallel_config is None:
out = flash_attn_func(
q=query,
@@ -1698,7 +1469,6 @@ def _flash_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -1706,9 +1476,6 @@ def _flash_attention_hub(
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
lse = None
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
out = func(
q=query,
@@ -1845,15 +1612,11 @@ def _flash_attention_3(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
out, lse = _wrapped_flash_attn_3(
q=query,
k=key,
@@ -1873,7 +1636,6 @@ def _flash_attention_3_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[float] = None,
is_causal: bool = False,
window_size: Tuple[int, int] = (-1, -1),
@@ -1884,8 +1646,6 @@ def _flash_attention_3_hub(
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
out = func(
@@ -2025,16 +1785,12 @@ def _aiter_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for aiter attention")
if not return_lse and torch.is_grad_enabled():
# aiter requires return_lse=True by assertion when gradients are enabled.
out, lse, *_ = aiter_flash_attn_func(
@@ -2125,43 +1881,6 @@ def _native_flex_attention(
return out
def _prepare_additive_attn_mask(
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
) -> torch.Tensor:
"""
Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
Args:
attn_mask: 2D tensor [batch_size, seq_len_k]
- Boolean: True means attend, False means mask out
- Additive: 0.0 means attend, -inf means mask out
target_dtype: The dtype to convert the mask to (usually query.dtype)
reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
Returns:
Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
"""
# Check if the mask is boolean or already additive
if attn_mask.dtype == torch.bool:
# Convert boolean to additive: True -> 0.0, False -> -inf
attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
# Convert to target dtype
attn_mask = attn_mask.to(dtype=target_dtype)
else:
# Already additive mask - just ensure correct dtype
attn_mask = attn_mask.to(dtype=target_dtype)
# Optionally reshape to 4D for broadcasting in attention mechanisms
if reshape_4d:
batch_size, seq_len_k = attn_mask.shape
attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
return attn_mask
@_AttentionBackendRegistry.register(
AttentionBackendName.NATIVE,
constraints=[_check_device, _check_shape],
@@ -2181,19 +1900,6 @@ def _native_attention(
) -> torch.Tensor:
if return_lse:
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
# Reshape 2D mask to 4D for SDPA
# SDPA accepts both boolean masks (torch.bool) and additive masks (float)
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
# Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
# SDPA handles both boolean and additive masks correctly
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
if _parallel_config is None:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
out = torch.nn.functional.scaled_dot_product_attention(
@@ -2322,7 +2028,6 @@ def _native_flash_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
@@ -2330,9 +2035,6 @@ def _native_flash_attention(
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for aiter attention")
lse = None
if _parallel_config is None and not return_lse:
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2406,52 +2108,34 @@ def _native_math_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName._NATIVE_NPU,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
)
def _native_npu_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
dropout_p: float = 0.0,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
None,
scale,
None,
return_lse,
forward_op=_npu_attention_forward_op,
backward_op=_npu_attention_backward_op,
_parallel_config=_parallel_config,
)
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
out = npu_fusion_attention(
query,
key,
value,
query.size(1), # num_heads
input_layout="BNSD",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
pre_tockens=65536,
next_tockens=65536,
keep_prob=1.0 - dropout_p,
sync=False,
inner_precise=0,
)[0]
out = out.transpose(1, 2).contiguous()
return out
@@ -2464,13 +2148,10 @@ def _native_xla_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for XLA attention")
if return_lse:
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
@@ -2494,14 +2175,11 @@ def _sage_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
if _parallel_config is None:
out = sageattn(
@@ -2545,14 +2223,11 @@ def _sage_attention_hub(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
lse = None
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
if _parallel_config is None:
@@ -2634,14 +2309,11 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda(
q=query,
k=key,
@@ -2661,14 +2333,11 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp8_cuda_sm90(
q=query,
k=key,
@@ -2688,14 +2357,11 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_cuda(
q=query,
k=key,
@@ -2715,14 +2381,11 @@ def _sage_qk_int8_pv_fp16_triton_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
is_causal: bool = False,
scale: Optional[float] = None,
return_lse: bool = False,
_parallel_config: Optional["ParallelConfig"] = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for sage attention")
return sageattn_qk_int8_pv_fp16_triton(
q=query,
k=key,
@@ -2760,34 +2423,10 @@ def _xformers_attention(
attn_mask = xops.LowerTriangularMask()
elif attn_mask is not None:
if attn_mask.ndim == 2:
# Convert 2D mask to 4D for xformers
# Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
# xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
# Need memory alignment - create larger tensor and slice for alignment
original_seq_len = attn_mask.size(1)
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
# Create aligned 4D tensor and slice to ensure proper memory layout
aligned_mask = torch.zeros(
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
dtype=query.dtype,
device=query.device,
)
# Convert to 4D additive mask (handles both boolean and additive inputs)
mask_additive = _prepare_additive_attn_mask(
attn_mask, target_dtype=query.dtype
) # [batch, 1, 1, seq_len_k]
# Broadcast to [batch, heads, seq_q, seq_len_k]
aligned_mask[:, :, :, :original_seq_len] = mask_additive
# Mask out the padding (already -inf from zeros -> where with default)
aligned_mask[:, :, :, original_seq_len:] = float("-inf")
# Slice to actual size with proper alignment
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
elif attn_mask.ndim != 4:
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
elif attn_mask.ndim == 4:
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
if enable_gqa:
if num_heads_q % num_heads_kv != 0:
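
For reference, the 2D mask handling removed from this file (the `_prepare_additive_attn_mask` helper and the related reshaping) amounts to converting a boolean or additive mask into an additive one and broadcasting it to 4D. A standalone sketch of that conversion, not the library's API:

import torch

def to_additive_4d(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # Boolean masks: True -> attend (0.0), False -> mask out (-inf); additive masks pass through.
    if attn_mask.dtype == torch.bool:
        attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
    attn_mask = attn_mask.to(dtype)
    batch_size, seq_len_k = attn_mask.shape
    return attn_mask.view(batch_size, 1, 1, seq_len_k)  # broadcastable over heads and query positions

mask = torch.tensor([[True, True, False]])
print(to_additive_4d(mask, torch.float32))  # tensor([[[[0., 0., -inf]]]])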

View File

@@ -10,8 +10,6 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
from .autoencoder_kl_magvit import AutoencoderKLMagvit
from .autoencoder_kl_mochi import AutoencoderKLMochi
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage

View File

@@ -102,14 +102,14 @@ def get_block(
attention_head_dim: int,
norm_type: str,
act_fn: str,
qkv_multiscales: Tuple[int, ...] = (),
qkv_mutliscales: Tuple[int, ...] = (),
):
if block_type == "ResBlock":
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
elif block_type == "EfficientViTBlock":
block = EfficientViTBlock(
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
)
else:
@@ -247,7 +247,7 @@ class Encoder(nn.Module):
attention_head_dim=attention_head_dim,
norm_type="rms_norm",
act_fn="silu",
qkv_multiscales=qkv_multiscales[i],
qkv_mutliscales=qkv_multiscales[i],
)
down_block_list.append(block)
@@ -339,7 +339,7 @@ class Decoder(nn.Module):
attention_head_dim=attention_head_dim,
norm_type=norm_type[i],
act_fn=act_fn[i],
qkv_multiscales=qkv_multiscales[i],
qkv_mutliscales=qkv_multiscales[i],
)
up_block_list.append(block)

View File

@@ -27,7 +27,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -410,7 +410,7 @@ class HunyuanImageDecoder2D(nn.Module):
return h
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model for 2D images with spatial tiling support.
@@ -486,6 +486,27 @@ class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromO
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor):
batch_size, num_channels, height, width = x.shape

View File

@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -584,7 +584,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanImage-2.1 Refiner.
@@ -685,6 +685,27 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

View File

@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -625,7 +625,7 @@ class HunyuanVideo15Decoder3D(nn.Module):
return hidden_states
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
HunyuanVideo-1.5.
@@ -723,6 +723,27 @@ class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
_, _, _, height, width = x.shape

File diff suppressed because it is too large

View File

@@ -1,804 +0,0 @@
# Copyright 2025 The Lightricks team and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
LATENT_DOWNSAMPLE_FACTOR = 4
class LTX2AudioCausalConv2d(nn.Module):
"""
A causal 2D convolution that pads asymmetrically along the causal axis.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: int = 1,
dilation: Union[int, Tuple[int, int]] = 1,
groups: int = 1,
bias: bool = True,
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
pad_h = (kernel_size[0] - 1) * dilation[0]
pad_w = (kernel_size[1] - 1) * dilation[1]
if self.causality_axis == "none":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis in {"width", "width-compatibility"}:
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
elif self.causality_axis == "height":
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
else:
raise ValueError(f"Invalid causality_axis: {causality_axis}")
self.padding = padding
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size,
stride=stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, self.padding)
return self.conv(x)
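The asymmetric padding above is what makes the convolution causal along one axis; a standalone sketch of the same rule for kernel_size=3, dilation=1, causality_axis="height" (re-derived here rather than importing the class, so the pad tuple is an assumption that follows from the formula above):

import torch
import torch.nn as nn
import torch.nn.functional as F

# For kernel 3 / dilation 1 and causality_axis="height":
# pad (left, right, top, bottom) = (1, 1, 2, 0), i.e. all height padding sits before the causal (time) axis.
conv = nn.Conv2d(1, 1, kernel_size=3, padding=0, bias=False)
nn.init.constant_(conv.weight, 1.0)

x = torch.zeros(1, 1, 5, 4)
x[0, 0, 2, :] = 1.0  # an "event" at time step 2

y = conv(F.pad(x, (1, 1, 2, 0)))
print(y.shape)        # torch.Size([1, 1, 5, 4]): time length preserved
print(y[0, 0, :, 1])  # zero before step 2, nonzero from step 2 on -> causal along the height axis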
class LTX2AudioPixelNorm(nn.Module):
"""
Per-pixel (per-location) RMS normalization layer.
"""
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
super().__init__()
self.dim = dim
self.eps = eps
def forward(self, x: torch.Tensor) -> torch.Tensor:
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
rms = torch.sqrt(mean_sq + self.eps)
return x / rms
class LTX2AudioAttnBlock(nn.Module):
def __init__(
self,
in_channels: int,
norm_type: str = "group",
) -> None:
super().__init__()
self.in_channels = in_channels
if norm_type == "group":
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
h_ = self.norm(x)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
batch, channels, height, width = q.shape
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
k = k.reshape(batch, channels, height * width).contiguous()
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
attn = torch.nn.functional.softmax(attn, dim=2)
v = v.reshape(batch, channels, height * width)
attn = attn.permute(0, 2, 1).contiguous()
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
h_ = self.proj_out(h_)
return x + h_
class LTX2AudioResnetBlock(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
norm_type: str = "group",
causality_axis: str = "height",
) -> None:
super().__init__()
self.causality_axis = causality_axis
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if norm_type == "group":
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.non_linearity = nn.SiLU()
if causality_axis is not None:
self.conv1 = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if norm_type == "group":
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
elif norm_type == "pixel":
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {norm_type}")
self.dropout = nn.Dropout(dropout)
if causality_axis is not None:
self.conv2 = LTX2AudioCausalConv2d(
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
if causality_axis is not None:
self.conv_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
if causality_axis is not None:
self.nin_shortcut = LTX2AudioCausalConv2d(
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
)
else:
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
h = self.norm1(x)
h = self.non_linearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
h = self.norm2(h)
h = self.non_linearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
return x + h
class LTX2AudioDownsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.with_conv:
# Padding tuple is in the order: (left, right, top, bottom).
if self.causality_axis == "none":
pad = (0, 1, 0, 1)
elif self.causality_axis == "width":
pad = (2, 0, 0, 1)
elif self.causality_axis == "height":
pad = (0, 1, 2, 0)
elif self.causality_axis == "width-compatibility":
pad = (1, 0, 0, 1)
else:
raise ValueError(
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
f" and `width-compatibility`."
)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
# with_conv=False implies that causality_axis is "none"
x = F.avg_pool2d(x, kernel_size=2, stride=2)
return x
class LTX2AudioUpsample(nn.Module):
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
super().__init__()
self.with_conv = with_conv
self.causality_axis = causality_axis
if self.with_conv:
if causality_axis is not None:
self.conv = LTX2AudioCausalConv2d(
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
)
else:
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
if self.causality_axis is None or self.causality_axis == "none":
pass
elif self.causality_axis == "height":
x = x[:, :, 1:, :]
elif self.causality_axis == "width":
x = x[:, :, :, 1:]
elif self.causality_axis == "width-compatibility":
pass
else:
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
return x
class LTX2AudioAudioPatchifier:
"""
Patchifier for spectrogram/audio latents.
"""
def __init__(
self,
patch_size: int,
sample_rate: int = 16000,
hop_length: int = 160,
audio_latent_downsample_factor: int = 4,
is_causal: bool = True,
):
self.hop_length = hop_length
self.sample_rate = sample_rate
self.audio_latent_downsample_factor = audio_latent_downsample_factor
self.is_causal = is_causal
self._patch_size = (1, patch_size, patch_size)
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
batch, channels, time, freq = audio_latents.shape
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
batch, time, _ = audio_latents.shape
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
@property
def patch_size(self) -> Tuple[int, int, int]:
return self._patch_size
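patchify flattens (batch, channels, time, mel_bins) latents into a (batch, time, channels * mel_bins) token sequence and unpatchify inverts it; a quick shape check using the same permute/reshape convention as above:

import torch

batch, channels, time, mel_bins = 2, 8, 16, 64
latents = torch.randn(batch, channels, time, mel_bins)

# patchify: (B, C, T, F) -> (B, T, C*F)
tokens = latents.permute(0, 2, 1, 3).reshape(batch, time, channels * mel_bins)

# unpatchify: (B, T, C*F) -> (B, C, T, F)
restored = tokens.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)

assert torch.equal(restored, latents)  # the round trip is lossless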
class LTX2AudioEncoder(nn.Module):
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
):
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels
base_resolution = resolution
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.down = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution
for level in range(self.num_resolutions):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != self.num_resolutions - 1:
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
curr_res = curr_res // 2
self.down.append(stage)
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
final_block_channels = block_in
z_channels = 2 * latent_channels if double_z else latent_channels
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
self.non_linearity = nn.SiLU()
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
hidden_states = self.conv_in(hidden_states)
for level in range(self.num_resolutions):
stage = self.down[level]
for block_idx, block in enumerate(stage.block):
hidden_states = block(hidden_states, temb=None)
if stage.attn:
hidden_states = stage.attn[block_idx](hidden_states)
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
hidden_states = stage.downsample(hidden_states)
hidden_states = self.mid.block_1(hidden_states, temb=None)
hidden_states = self.mid.attn_1(hidden_states)
hidden_states = self.mid.block_2(hidden_states, temb=None)
hidden_states = self.norm_out(hidden_states)
hidden_states = self.non_linearity(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class LTX2AudioDecoder(nn.Module):
"""
Symmetric decoder that reconstructs audio spectrograms from latent features.
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
convolutions.
"""
def __init__(
self,
base_channels: int = 128,
output_channels: int = 1,
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
ch_mult: Tuple[int, ...] = (1, 2, 4),
norm_type: str = "group",
causality_axis: Optional[str] = "width",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
) -> None:
super().__init__()
self.sample_rate = sample_rate
self.mel_hop_length = mel_hop_length
self.is_causal = is_causal
self.mel_bins = mel_bins
self.patchifier = LTX2AudioAudioPatchifier(
patch_size=1,
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
sample_rate=sample_rate,
hop_length=mel_hop_length,
is_causal=is_causal,
)
self.base_channels = base_channels
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.out_ch = output_channels
self.give_pre_end = False
self.tanh_out = False
self.norm_type = norm_type
self.latent_channels = latent_channels
self.channel_multipliers = ch_mult
self.attn_resolutions = attn_resolutions
self.causality_axis = causality_axis
base_block_channels = base_channels * self.channel_multipliers[-1]
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
if self.causality_axis is not None:
self.conv_in = LTX2AudioCausalConv2d(
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
self.non_linearity = nn.SiLU()
self.mid = nn.Module()
self.mid.block_1 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
if mid_block_add_attention:
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
else:
self.mid.attn_1 = nn.Identity()
self.mid.block_2 = LTX2AudioResnetBlock(
in_channels=base_block_channels,
out_channels=base_block_channels,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
self.up = nn.ModuleList()
block_in = base_block_channels
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
for level in reversed(range(self.num_resolutions)):
stage = nn.Module()
stage.block = nn.ModuleList()
stage.attn = nn.ModuleList()
block_out = self.base_channels * self.channel_multipliers[level]
for _ in range(self.num_res_blocks + 1):
stage.block.append(
LTX2AudioResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
norm_type=self.norm_type,
causality_axis=self.causality_axis,
)
)
block_in = block_out
if self.attn_resolutions:
if curr_res in self.attn_resolutions:
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
if level != 0:
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
curr_res *= 2
self.up.insert(0, stage)
final_block_channels = block_in
if self.norm_type == "group":
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
elif self.norm_type == "pixel":
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
else:
raise ValueError(f"Invalid normalization type: {self.norm_type}")
if self.causality_axis is not None:
self.conv_out = LTX2AudioCausalConv2d(
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
)
else:
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
def forward(
self,
sample: torch.Tensor,
) -> torch.Tensor:
_, _, frames, mel_bins = sample.shape
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
if self.causality_axis is not None:
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
target_channels = self.out_ch
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
hidden_features = self.conv_in(sample)
hidden_features = self.mid.block_1(hidden_features, temb=None)
hidden_features = self.mid.attn_1(hidden_features)
hidden_features = self.mid.block_2(hidden_features, temb=None)
for level in reversed(range(self.num_resolutions)):
stage = self.up[level]
for block_idx, block in enumerate(stage.block):
hidden_features = block(hidden_features, temb=None)
if stage.attn:
hidden_features = stage.attn[block_idx](hidden_features)
if level != 0 and hasattr(stage, "upsample"):
hidden_features = stage.upsample(hidden_features)
if self.give_pre_end:
return hidden_features
hidden = self.norm_out(hidden_features)
hidden = self.non_linearity(hidden)
decoded_output = self.conv_out(hidden)
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
_, _, current_time, current_freq = decoded_output.shape
target_time = target_frames
target_freq = target_mel_bins
decoded_output = decoded_output[
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
]
time_padding_needed = target_time - decoded_output.shape[2]
freq_padding_needed = target_freq - decoded_output.shape[3]
if time_padding_needed > 0 or freq_padding_needed > 0:
padding = (
0,
max(freq_padding_needed, 0),
0,
max(time_padding_needed, 0),
)
decoded_output = F.pad(decoded_output, padding)
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
return decoded_output
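The length bookkeeping above maps T latent frames to 4 * T mel frames, trimmed to 4 * T - 3 in the causal case; a small helper restating that arithmetic (the function name is illustrative):

LATENT_DOWNSAMPLE_FACTOR = 4

def target_mel_frames(latent_frames: int, causal: bool = True) -> int:
    # Mirrors the target_frames computation in LTX2AudioDecoder.forward above.
    frames = latent_frames * LATENT_DOWNSAMPLE_FACTOR
    return max(frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1) if causal else frames

print(target_mel_frames(16))                # 61
print(target_mel_frames(16, causal=False))  # 64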
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
LTX2 audio VAE for encoding and decoding audio latent representations.
"""
_supports_gradient_checkpointing = False
@register_to_config
def __init__(
self,
base_channels: int = 128,
output_channels: int = 2,
ch_mult: Tuple[int, ...] = (1, 2, 4),
num_res_blocks: int = 2,
attn_resolutions: Optional[Tuple[int, ...]] = None,
in_channels: int = 2,
resolution: int = 256,
latent_channels: int = 8,
norm_type: str = "pixel",
causality_axis: Optional[str] = "height",
dropout: float = 0.0,
mid_block_add_attention: bool = False,
sample_rate: int = 16000,
mel_hop_length: int = 160,
is_causal: bool = True,
mel_bins: Optional[int] = 64,
double_z: bool = True,
) -> None:
super().__init__()
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
if causality_axis not in supported_causality_axes:
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
self.encoder = LTX2AudioEncoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
double_z=double_z,
)
self.decoder = LTX2AudioDecoder(
base_channels=base_channels,
output_channels=output_channels,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
attn_resolutions=attn_resolution_set,
in_channels=in_channels,
resolution=resolution,
latent_channels=latent_channels,
norm_type=norm_type,
causality_axis=causality_axis,
dropout=dropout,
mid_block_add_attention=mid_block_add_attention,
sample_rate=sample_rate,
mel_hop_length=mel_hop_length,
is_causal=is_causal,
mel_bins=mel_bins,
)
# Per-channel statistics for normalizing and denormalizing the latent representation. These statistics are computed
# over the entire dataset and stored in the model's checkpoint under the AudioVAE state_dict.
latents_std = torch.zeros((base_channels,))
latents_mean = torch.ones((base_channels,))
self.register_buffer("latents_mean", latents_mean, persistent=True)
self.register_buffer("latents_std", latents_std, persistent=True)
# TODO: calculate programmatically instead of hardcoding
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
# TODO: confirm whether the mel compression ratio below is correct
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
return self.encoder(x)
@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True):
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
return self.decoder(z)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
posterior = self.encode(sample).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
if not return_dict:
return (dec.sample,)
return dec
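An end-to-end sketch of the class above on a random mel spectrogram, assuming AutoencoderKLLTX2Audio from this listing is in scope; weights are random and the shapes simply follow from the constructor defaults (2 input channels, 8 latent channels, ch_mult=(1, 2, 4), causal along the time axis):

import torch

vae = AutoencoderKLLTX2Audio()  # random weights; real use would load a checkpoint instead
vae.eval()

# (batch, channels, mel_frames, mel_bins), with time along the causal "height" axis.
mel = torch.randn(1, 2, 61, 64)

with torch.no_grad():
    posterior = vae.encode(mel).latent_dist
    latents = posterior.mode()          # (1, 8, 16, 16): 4x compression in time and mel bins
    audio = vae.decode(latents).sample  # (1, 2, 61, 64): back to the input resolution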

View File

@@ -41,11 +41,9 @@ class CacheMixin:
Enable caching techniques on the model.
Args:
config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
config (`Union[PyramidAttentionBroadcastConfig]`):
The configuration for applying the caching technique. Currently supported caching techniques are:
- [`~hooks.PyramidAttentionBroadcastConfig`]
- [`~hooks.FasterCacheConfig`]
- [`~hooks.FirstBlockCacheConfig`]
Example:
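A usage sketch in the spirit of this docstring, using PyramidAttentionBroadcastConfig (the checkpoint id and skip ranges are illustrative, adapted from the upstream example):

import torch
from diffusers import CogVideoXPipeline, PyramidAttentionBroadcastConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
pipe.transformer.enable_cache(config)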

View File

@@ -0,0 +1,115 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from ..utils import deprecate
from .controlnets.controlnet import ( # noqa
ControlNetConditioningEmbedding,
ControlNetModel,
ControlNetOutput,
zero_module,
)
class ControlNetOutput(ControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class ControlNetModel(ControlNetModel):
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 3,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
encoder_hid_dim: Optional[int] = None,
encoder_hid_dim_type: Optional[str] = None,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
addition_embed_type: Optional[str] = None,
addition_time_embed_dim: Optional[int] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
projection_class_embeddings_input_dim: Optional[int] = None,
controlnet_conditioning_channel_order: str = "rgb",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
addition_embed_type_num_heads: int = 64,
):
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
super().__init__(
in_channels=in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
down_block_types=down_block_types,
mid_block_type=mid_block_type,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
encoder_hid_dim=encoder_hid_dim,
encoder_hid_dim_type=encoder_hid_dim_type,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
class_embed_type=class_embed_type,
addition_embed_type=addition_embed_type,
addition_time_embed_dim=addition_time_embed_dim,
num_class_embeds=num_class_embeds,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
global_pool_conditions=global_pool_conditions,
addition_embed_type_num_heads=addition_embed_type_num_heads,
)
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
super().__init__(*args, **kwargs)

View File

@@ -0,0 +1,70 @@
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from ..utils import deprecate, logging
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FluxControlNetOutput(FluxControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class FluxControlNetModel(FluxControlNetModel):
def __init__(
self,
patch_size: int = 1,
in_channels: int = 64,
num_layers: int = 19,
num_single_layers: int = 38,
attention_head_dim: int = 128,
num_attention_heads: int = 24,
joint_attention_dim: int = 4096,
pooled_projection_dim: int = 768,
guidance_embeds: bool = False,
axes_dims_rope: List[int] = [16, 56, 56],
num_mode: int = None,
conditioning_embedding_channels: int = None,
):
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
super().__init__(
patch_size=patch_size,
in_channels=in_channels,
num_layers=num_layers,
num_single_layers=num_single_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
pooled_projection_dim=pooled_projection_dim,
guidance_embeds=guidance_embeds,
axes_dims_rope=axes_dims_rope,
num_mode=num_mode,
conditioning_embedding_channels=conditioning_embedding_channels,
)
class FluxMultiControlNetModel(FluxMultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)

View File

@@ -0,0 +1,68 @@
# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..utils import deprecate, logging
from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SD3ControlNetOutput(SD3ControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class SD3ControlNetModel(SD3ControlNetModel):
def __init__(
self,
sample_size: int = 128,
patch_size: int = 2,
in_channels: int = 16,
num_layers: int = 18,
attention_head_dim: int = 64,
num_attention_heads: int = 18,
joint_attention_dim: int = 4096,
caption_projection_dim: int = 1152,
pooled_projection_dim: int = 2048,
out_channels: int = 16,
pos_embed_max_size: int = 96,
extra_conditioning_channels: int = 0,
):
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
super().__init__(
sample_size=sample_size,
patch_size=patch_size,
in_channels=in_channels,
num_layers=num_layers,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
joint_attention_dim=joint_attention_dim,
caption_projection_dim=caption_projection_dim,
pooled_projection_dim=pooled_projection_dim,
out_channels=out_channels,
pos_embed_max_size=pos_embed_max_size,
extra_conditioning_channels=extra_conditioning_channels,
)
class SD3MultiControlNetModel(SD3MultiControlNetModel):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
super().__init__(*args, **kwargs)

View File

@@ -0,0 +1,116 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
from ..utils import deprecate, logging
from .controlnets.controlnet_sparsectrl import ( # noqa
SparseControlNetConditioningEmbedding,
SparseControlNetModel,
SparseControlNetOutput,
zero_module,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class SparseControlNetOutput(SparseControlNetOutput):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
super().__init__(*args, **kwargs)
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
def __init__(self, *args, **kwargs):
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
deprecate(
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
)
super().__init__(*args, **kwargs)
class SparseControlNetModel(SparseControlNetModel):
def __init__(
self,
in_channels: int = 4,
conditioning_channels: int = 4,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str, ...] = (
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"CrossAttnDownBlockMotion",
"DownBlockMotion",
),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: Optional[int] = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 768,
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
use_linear_projection: bool = False,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
global_pool_conditions: bool = False,
controlnet_conditioning_channel_order: str = "rgb",
motion_max_seq_length: int = 32,
motion_num_attention_heads: int = 8,
concat_conditioning_mask: bool = True,
use_simplified_condition_embedding: bool = True,
):
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
super().__init__(
in_channels=in_channels,
conditioning_channels=conditioning_channels,
flip_sin_to_cos=flip_sin_to_cos,
freq_shift=freq_shift,
down_block_types=down_block_types,
only_cross_attention=only_cross_attention,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
downsample_padding=downsample_padding,
mid_block_scale_factor=mid_block_scale_factor,
act_fn=act_fn,
norm_num_groups=norm_num_groups,
norm_eps=norm_eps,
cross_attention_dim=cross_attention_dim,
transformer_layers_per_block=transformer_layers_per_block,
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
global_pool_conditions=global_pool_conditions,
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
motion_max_seq_length=motion_max_seq_length,
motion_num_attention_heads=motion_num_attention_heads,
concat_conditioning_mask=concat_conditioning_mask,
use_simplified_condition_embedding=use_simplified_condition_embedding,
)

View File

@@ -19,7 +19,6 @@ if is_torch_available():
)
from .controlnet_union import ControlNetUnionModel
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
from .controlnet_z_image import ZImageControlNetModel
from .multicontrolnet import MultiControlNetModel
from .multicontrolnet_union import MultiControlNetUnionModel

View File

@@ -20,7 +20,7 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention import AttentionMixin
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
@@ -31,7 +31,6 @@ from ..transformers.transformer_qwenimage import (
QwenImageTransformerBlock,
QwenTimestepProjEmbeddings,
RMSNorm,
compute_text_seq_len_from_mask,
)
@@ -137,7 +136,7 @@ class QwenImageControlNetModel(
return_dict: bool = True,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`QwenImageControlNetModel`] forward method.
The [`FluxTransformer2DModel`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
@@ -148,39 +147,24 @@ class QwenImageControlNetModel(
The scale factor for ControlNet outputs.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
Used to indicate denoising step.
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
Image shapes for RoPE computation.
txt_seq_lens (`List[int]`, *optional*):
**Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
length.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
the first element is the controlnet block samples.
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# Handle deprecated txt_seq_lens parameter
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
"version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
"and `encoder_hidden_states_mask`.",
standard_warn=False,
)
if joint_attention_kwargs is not None:
joint_attention_kwargs = joint_attention_kwargs.copy()
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
@@ -202,47 +186,32 @@ class QwenImageControlNetModel(
temb = self.time_text_embed(timestep, hidden_states)
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
encoder_hidden_states, encoder_hidden_states_mask
)
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
timestep = timestep.to(hidden_states.dtype)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
encoder_hidden_states = self.txt_in(encoder_hidden_states)
# Construct joint attention mask once to avoid reconstructing in every block
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
if encoder_hidden_states_mask is not None:
# Build joint mask: [text_mask, all_ones_for_image]
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
block_attention_kwargs["attention_mask"] = joint_attention_mask
block_samples = ()
for block in self.transformer_blocks:
for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
encoder_hidden_states_mask,
temb,
image_rotary_emb,
block_attention_kwargs,
)
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
encoder_hidden_states_mask=encoder_hidden_states_mask,
temb=temb,
image_rotary_emb=image_rotary_emb,
joint_attention_kwargs=block_attention_kwargs,
joint_attention_kwargs=joint_attention_kwargs,
)
block_samples = block_samples + (hidden_states,)
@@ -298,15 +267,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[QwenImageControlNetOutput, Tuple]:
if txt_seq_lens is not None:
deprecate(
"txt_seq_lens",
"0.39.0",
"Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
"removed in version 0.39.0. The text sequence length is now automatically inferred from "
"`encoder_hidden_states` and `encoder_hidden_states_mask`.",
standard_warn=False,
)
# ControlNet-Union with multiple conditions
# only load one ControlNet to save memory
if len(self.nets) == 1:
@@ -321,6 +281,7 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
encoder_hidden_states_mask=encoder_hidden_states_mask,
timestep=timestep,
img_shapes=img_shapes,
txt_seq_lens=txt_seq_lens,
joint_attention_kwargs=joint_attention_kwargs,
return_dict=return_dict,
)
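The joint-mask construction shown in this hunk concatenates the text padding mask with an all-ones image mask, so padding text tokens are masked out once up front rather than in every block; a standalone sketch of that construction (tensor sizes are illustrative):

import torch

batch_size, image_seq_len = 2, 16

# 1 for valid text tokens, 0 for padding; any pattern is allowed.
encoder_hidden_states_mask = torch.tensor(
    [[1, 1, 1, 1, 0, 0, 0],
     [1, 1, 1, 1, 1, 1, 1]], dtype=torch.bool
)

# Image tokens are never padded, so their mask is all ones.
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool)

# Joint mask over the concatenated [text, image] sequence, built once and passed
# to every block as `attention_mask`.
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
print(joint_attention_mask.shape)  # torch.Size([2, 23])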

View File

@@ -1,845 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Literal, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...loaders.single_file_model import FromOriginalModelMixin
from ...models.attention_processor import Attention
from ...models.normalization import RMSNorm
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention_dispatch import dispatch_attention_fn
from ..controlnets.controlnet import zero_module
from ..modeling_utils import ModelMixin
ADALN_EMBED_DIM = 256
SEQ_MULTI_OF = 32
# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder
class TimestepEmbedder(nn.Module):
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
super().__init__()
if mid_size is None:
mid_size = out_size
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, mid_size, bias=True),
nn.SiLU(),
nn.Linear(mid_size, out_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
with torch.amp.autocast("cuda", enabled=False):
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
weight_dtype = self.mlp[0].weight.dtype
compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
if weight_dtype.is_floating_point:
t_freq = t_freq.to(weight_dtype)
elif compute_dtype is not None:
t_freq = t_freq.to(compute_dtype)
t_emb = self.mlp(t_freq)
return t_emb
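timestep_embedding above is the standard sinusoidal construction (cosines first, then sines); a quick standalone check of the formula for an even embedding dimension:

import math
import torch

def sinusoidal_embedding(t, dim, max_period=10000):
    # Same construction as TimestepEmbedder.timestep_embedding above (odd-dim zero padding omitted).
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half)
    args = t[:, None].float() * freqs[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

emb = sinusoidal_embedding(torch.tensor([0.0, 500.0, 999.0]), dim=256)
print(emb.shape)                     # torch.Size([3, 256])
print(emb[0, :2], emb[0, 128:130])   # at t=0: cosine terms are 1, sine terms are 0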
# Copied from diffusers.models.transformers.transformer_z_image.ZSingleStreamAttnProcessor
class ZSingleStreamAttnProcessor:
"""
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
original Z-ImageAttention module.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
freqs_cis: Optional[torch.Tensor] = None,
) -> torch.Tensor:
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
query = query.unflatten(-1, (attn.heads, -1))
key = key.unflatten(-1, (attn.heads, -1))
value = value.unflatten(-1, (attn.heads, -1))
# Apply Norms
if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)
# Apply RoPE
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
with torch.amp.autocast("cuda", enabled=False):
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
freqs_cis = freqs_cis.unsqueeze(2)
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
return x_out.type_as(x_in) # todo
if freqs_cis is not None:
query = apply_rotary_emb(query, freqs_cis)
key = apply_rotary_emb(key, freqs_cis)
# Cast to correct dtype
dtype = query.dtype
query, key = query.to(dtype), key.to(dtype)
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
if attention_mask is not None and attention_mask.ndim == 2:
attention_mask = attention_mask[:, None, None, :]
# Compute joint attention
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape back
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.to(dtype)
output = attn.to_out[0](hidden_states)
if len(attn.to_out) > 1: # dropout
output = attn.to_out[1](output)
return output
# Copied from diffusers.models.transformers.transformer_z_image.FeedForward
class FeedForward(nn.Module):
def __init__(self, dim: int, hidden_dim: int):
super().__init__()
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
def _forward_silu_gating(self, x1, x3):
return F.silu(x1) * x3
def forward(self, x):
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
# Copied from diffusers.models.transformers.transformer_z_image.select_per_token
def select_per_token(
value_noisy: torch.Tensor,
value_clean: torch.Tensor,
noise_mask: torch.Tensor,
seq_len: int,
) -> torch.Tensor:
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
return torch.where(
noise_mask_expanded == 1,
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
@maybe_allow_in_graph
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
class ZImageTransformerBlock(nn.Module):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
qk_norm: bool,
modulation=True,
):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
# Refactored to use diffusers Attention with custom processor
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
self.attention = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // n_heads,
heads=n_heads,
qk_norm="rms_norm" if qk_norm else None,
eps=1e-5,
bias=False,
out_bias=False,
processor=ZSingleStreamAttnProcessor(),
)
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
noise_mask: Optional[torch.Tensor] = None,
adaln_noisy: Optional[torch.Tensor] = None,
adaln_clean: Optional[torch.Tensor] = None,
):
if self.modulation:
seq_len = x.shape[1]
if noise_mask is not None:
# Per-token modulation: different modulation for noisy/clean tokens
mod_noisy = self.adaLN_modulation(adaln_noisy)
mod_clean = self.adaLN_modulation(adaln_clean)
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
else:
                # Global modulation: one modulation shared by all tokens (no per-token selection needed)
mod = self.adaLN_modulation(adaln_input)
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
)
x = x + gate_msa * self.attention_norm2(attn_out)
# FFN block
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
else:
# Attention block
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
x = x + self.attention_norm2(attn_out)
# FFN block
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
return x
# Copied from diffusers.models.transformers.transformer_z_image.RopeEmbedder
class RopeEmbedder:
def __init__(
self,
theta: float = 256.0,
axes_dims: List[int] = (16, 56, 56),
axes_lens: List[int] = (64, 128, 128),
):
self.theta = theta
self.axes_dims = axes_dims
self.axes_lens = axes_lens
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
self.freqs_cis = None
@staticmethod
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
with torch.device("cpu"):
freqs_cis = []
for i, (d, e) in enumerate(zip(dim, end)):
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
freqs = torch.outer(timestep, freqs).float()
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
freqs_cis.append(freqs_cis_i)
return freqs_cis
def __call__(self, ids: torch.Tensor):
assert ids.ndim == 2
assert ids.shape[-1] == len(self.axes_dims)
device = ids.device
if self.freqs_cis is None:
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
else:
# Ensure freqs_cis are on the same device as ids
if self.freqs_cis[0].device != device:
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
result = []
for i in range(len(self.axes_dims)):
index = ids[:, i]
result.append(self.freqs_cis[i][index])
return torch.cat(result, dim=-1)
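# Minimal usage sketch (the helper below and its position ids are illustrative only): RopeEmbedder
# turns integer position ids with one column per axis into concatenated complex rotary frequencies,
# one row per token.
def _example_rope_embedder():
    import torch

    rope = RopeEmbedder(theta=256.0, axes_dims=[16, 56, 56], axes_lens=[64, 128, 128])
    ids = torch.tensor([[3, 10, 20], [3, 11, 20]], dtype=torch.long)  # (num_tokens, num_axes)
    freqs_cis = rope(ids)
    # each axis contributes axes_dims[i] // 2 complex entries: 8 + 28 + 28 = 64 per token
    assert freqs_cis.shape == (2, 64) and freqs_cis.dtype == torch.complex64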
@maybe_allow_in_graph
class ZImageControlTransformerBlock(nn.Module):
def __init__(
self,
layer_id: int,
dim: int,
n_heads: int,
n_kv_heads: int,
norm_eps: float,
qk_norm: bool,
modulation=True,
block_id=0,
):
super().__init__()
self.dim = dim
self.head_dim = dim // n_heads
# Refactored to use diffusers Attention with custom processor
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
self.attention = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // n_heads,
heads=n_heads,
qk_norm="rms_norm" if qk_norm else None,
eps=1e-5,
bias=False,
out_bias=False,
processor=ZSingleStreamAttnProcessor(),
)
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
self.layer_id = layer_id
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
self.modulation = modulation
if modulation:
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
# Control variant start
self.block_id = block_id
if block_id == 0:
self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
self.after_proj = zero_module(nn.Linear(self.dim, self.dim))
def forward(
self,
c: torch.Tensor,
x: torch.Tensor,
attn_mask: torch.Tensor,
freqs_cis: torch.Tensor,
adaln_input: Optional[torch.Tensor] = None,
):
        # Control stream bookkeeping: block 0 injects the control latents into the model stream;
        # later blocks unpack the accumulated [hints..., running stream] stack
if self.block_id == 0:
c = self.before_proj(c) + x
all_c = []
else:
all_c = list(torch.unbind(c))
c = all_c.pop(-1)
        # Same computation as `ZImageTransformerBlock.forward`, applied to the control stream `c` instead of `x`
if self.modulation:
assert adaln_input is not None
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
# Attention block
attn_out = self.attention(
self.attention_norm1(c) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
)
c = c + gate_msa * self.attention_norm2(attn_out)
# FFN block
c = c + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(c) * scale_mlp))
else:
# Attention block
attn_out = self.attention(self.attention_norm1(c), attention_mask=attn_mask, freqs_cis=freqs_cis)
c = c + self.attention_norm2(attn_out)
# FFN block
c = c + self.ffn_norm2(self.feed_forward(self.ffn_norm1(c)))
        # Control: emit this block's hint and re-stack it together with the running stream
c_skip = self.after_proj(c)
all_c += [c_skip, c]
c = torch.stack(all_c)
return c
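# Note on the control protocol implemented above: block 0 first maps the control latents into the
# model stream via `before_proj(c) + x`; every block then appends `[after_proj(c), c]`, so the
# stacked output carries one "hint" per block plus the running control stream in its last slice.
# ZImageControlNetModel below unbinds that stack, scales the hints by `conditioning_scale`, and
# returns them keyed by the target layer indices in `control_layers_places` /
# `control_refiner_layers_places`.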
class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
control_layers_places: List[int] = None,
control_refiner_layers_places: List[int] = None,
control_in_dim=None,
add_control_noise_refiner: Optional[Literal["control_layers", "control_noise_refiner"]] = None,
all_patch_size=(2,),
all_f_patch_size=(1,),
dim=3840,
n_refiner_layers=2,
n_heads=30,
n_kv_heads=30,
norm_eps=1e-5,
qk_norm=True,
):
super().__init__()
self.control_layers_places = control_layers_places
self.control_in_dim = control_in_dim
self.control_refiner_layers_places = control_refiner_layers_places
self.add_control_noise_refiner = add_control_noise_refiner
assert 0 in self.control_layers_places
# control blocks
self.control_layers = nn.ModuleList(
[
ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
for i in self.control_layers_places
]
)
# control patch embeddings
all_x_embedder = {}
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
if self.add_control_noise_refiner == "control_layers":
self.control_noise_refiner = None
elif self.add_control_noise_refiner == "control_noise_refiner":
self.control_noise_refiner = nn.ModuleList(
[
ZImageControlTransformerBlock(
1000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=True,
block_id=layer_id,
)
for layer_id in range(n_refiner_layers)
]
)
else:
self.control_noise_refiner = nn.ModuleList(
[
ZImageTransformerBlock(
1000 + layer_id,
dim,
n_heads,
n_kv_heads,
norm_eps,
qk_norm,
modulation=True,
)
for layer_id in range(n_refiner_layers)
]
)
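        # Modules shared with the base ZImageTransformer2DModel: kept as `None` here and populated
        # by `from_transformer`; `forward` checks for this before running.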
self.t_scale: Optional[float] = None
self.t_embedder: Optional[TimestepEmbedder] = None
self.all_x_embedder: Optional[nn.ModuleDict] = None
self.cap_embedder: Optional[nn.Sequential] = None
self.rope_embedder: Optional[RopeEmbedder] = None
self.noise_refiner: Optional[nn.ModuleList] = None
self.context_refiner: Optional[nn.ModuleList] = None
self.x_pad_token: Optional[nn.Parameter] = None
self.cap_pad_token: Optional[nn.Parameter] = None
@classmethod
def from_transformer(cls, controlnet, transformer):
controlnet.t_scale = transformer.t_scale
controlnet.t_embedder = transformer.t_embedder
controlnet.all_x_embedder = transformer.all_x_embedder
controlnet.cap_embedder = transformer.cap_embedder
controlnet.rope_embedder = transformer.rope_embedder
controlnet.noise_refiner = transformer.noise_refiner
controlnet.context_refiner = transformer.context_refiner
controlnet.x_pad_token = transformer.x_pad_token
controlnet.cap_pad_token = transformer.cap_pad_token
return controlnet
@staticmethod
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.create_coordinate_grid
def create_coordinate_grid(size, start=None, device=None):
if start is None:
start = (0 for _ in size)
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
grids = torch.meshgrid(axes, indexing="ij")
return torch.stack(grids, dim=-1)
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
pH, pW, pF = patch_size, patch_size, f_patch_size
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
def _pad_with_ids(
self,
feat: torch.Tensor,
pos_grid_size: Tuple,
pos_start: Tuple,
device: torch.device,
noise_mask_val: Optional[int] = None,
):
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
ori_len = len(feat)
pad_len = (-ori_len) % SEQ_MULTI_OF
total_len = ori_len + pad_len
# Pos IDs
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
if pad_len > 0:
pad_pos_ids = (
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
.flatten(0, 2)
.repeat(pad_len, 1)
)
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
pad_mask = torch.cat(
[
torch.zeros(ori_len, dtype=torch.bool, device=device),
torch.ones(pad_len, dtype=torch.bool, device=device),
]
)
else:
pos_ids = ori_pos_ids
padded_feat = feat
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
def patchify_and_embed(
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
):
"""Patchify for basic mode: single image per batch item."""
device = all_image[0].device
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
for image, cap_feat in zip(all_image, all_cap_feats):
# Caption
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
)
all_cap_out.append(cap_out)
all_cap_pos_ids.append(cap_pos_ids)
all_cap_pad_mask.append(cap_pad_mask)
# Image
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
)
all_img_out.append(img_out)
all_img_size.append(size)
all_img_pos_ids.append(img_pos_ids)
all_img_pad_mask.append(img_pad_mask)
return (
all_img_out,
all_cap_out,
all_img_size,
all_img_pos_ids,
all_cap_pos_ids,
all_img_pad_mask,
all_cap_pad_mask,
)
def patchify(
self,
all_image: List[torch.Tensor],
patch_size: int,
f_patch_size: int,
):
pH = pW = patch_size
pF = f_patch_size
all_image_out = []
for i, image in enumerate(all_image):
### Process Image
C, F, H, W = image.size()
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
image_ori_len = len(image)
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
# padded feature
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
all_image_out.append(image_padded_feat)
return all_image_out
def forward(
self,
x: List[torch.Tensor],
t,
cap_feats: List[torch.Tensor],
control_context: List[torch.Tensor],
conditioning_scale: float = 1.0,
patch_size=2,
f_patch_size=1,
):
if (
self.t_scale is None
or self.t_embedder is None
or self.all_x_embedder is None
or self.cap_embedder is None
or self.rope_embedder is None
or self.noise_refiner is None
or self.context_refiner is None
or self.x_pad_token is None
or self.cap_pad_token is None
):
raise ValueError(
"Required modules are `None`, use `from_transformer` to share required modules from `transformer`."
)
assert patch_size in self.config.all_patch_size
assert f_patch_size in self.config.all_f_patch_size
bsz = len(x)
device = x[0].device
t = t * self.t_scale
t = self.t_embedder(t)
(
x,
cap_feats,
x_size,
x_pos_ids,
cap_pos_ids,
x_inner_pad_mask,
cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
x_item_seqlens = [len(_) for _ in x]
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
x_max_item_seqlen = max(x_item_seqlens)
control_context = self.patchify(control_context, patch_size, f_patch_size)
control_context = torch.cat(control_context, dim=0)
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
control_context = list(control_context.split(x_item_seqlens, dim=0))
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
# x embed & refine
x = torch.cat(x, dim=0)
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
# Match t_embedder output dtype to x for layerwise casting compatibility
adaln_input = t.type_as(x)
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
x = list(x.split(x_item_seqlens, dim=0))
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
x = pad_sequence(x, batch_first=True, padding_value=0.0)
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
        # Truncate to x's padded sequence length so Dynamo's symbolic shape inference sees matching lengths during compilation
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(x_item_seqlens):
x_attn_mask[i, :seq_len] = 1
if self.add_control_noise_refiner is not None:
if self.add_control_noise_refiner == "control_layers":
layers = self.control_layers
elif self.add_control_noise_refiner == "control_noise_refiner":
layers = self.control_noise_refiner
else:
raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.")
for layer in layers:
if torch.is_grad_enabled() and self.gradient_checkpointing:
control_context = self._gradient_checkpointing_func(
layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input
)
else:
control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input)
hints = torch.unbind(control_context)[:-1]
control_context = torch.unbind(control_context)[-1]
noise_refiner_block_samples = {
layer_idx: hints[idx] * conditioning_scale
for idx, layer_idx in enumerate(self.control_refiner_layers_places)
}
else:
noise_refiner_block_samples = None
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer_idx, layer in enumerate(self.noise_refiner):
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
if noise_refiner_block_samples is not None:
if layer_idx in noise_refiner_block_samples:
x = x + noise_refiner_block_samples[layer_idx]
else:
for layer_idx, layer in enumerate(self.noise_refiner):
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
if noise_refiner_block_samples is not None:
if layer_idx in noise_refiner_block_samples:
x = x + noise_refiner_block_samples[layer_idx]
# cap embed & refine
cap_item_seqlens = [len(_) for _ in cap_feats]
cap_max_item_seqlen = max(cap_item_seqlens)
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
)
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
        # Truncate to cap_feats' padded sequence length so Dynamo's symbolic shape inference sees matching lengths during compilation
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
cap_attn_mask[i, :seq_len] = 1
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.context_refiner:
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
else:
for layer in self.context_refiner:
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
# unified
unified = []
unified_freqs_cis = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
assert unified_item_seqlens == [len(_) for _ in unified]
unified_max_item_seqlen = max(unified_item_seqlens)
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_item_seqlens):
unified_attn_mask[i, :seq_len] = 1
## ControlNet start
if not self.add_control_noise_refiner:
if torch.is_grad_enabled() and self.gradient_checkpointing:
for layer in self.control_noise_refiner:
control_context = self._gradient_checkpointing_func(
layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
)
else:
for layer in self.control_noise_refiner:
control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)
# unified
control_context_unified = []
for i in range(bsz):
x_len = x_item_seqlens[i]
cap_len = cap_item_seqlens[i]
control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
for layer in self.control_layers:
if torch.is_grad_enabled() and self.gradient_checkpointing:
control_context_unified = self._gradient_checkpointing_func(
layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
)
else:
control_context_unified = layer(
control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
)
hints = torch.unbind(control_context_unified)[:-1]
controlnet_block_samples = {
layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)
}
return controlnet_block_samples
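# Minimal wiring sketch (illustrative only): the tiny config below does not correspond to any
# released checkpoint, and `transformer` is assumed to be an already-built
# ZImageTransformer2DModel supplied by the caller.
def _example_controlnet_wiring(transformer):
    controlnet = ZImageControlNetModel(
        control_layers_places=[0, 2],
        control_refiner_layers_places=[0],
        control_in_dim=16,
        add_control_noise_refiner=None,
        dim=64,
        n_refiner_layers=1,
        n_heads=4,
        n_kv_heads=4,
    )
    assert controlnet.t_embedder is None  # shared modules are not built by __init__
    controlnet = ZImageControlNetModel.from_transformer(controlnet, transformer)
    assert controlnet.t_embedder is transformer.t_embedder  # modules are shared by reference
    return controlnet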

View File

@@ -47,7 +47,6 @@ from ..utils import (
is_torch_version,
logging,
)
from ..utils.distributed_utils import is_torch_dist_rank_zero
logger = logging.get_logger(__name__)
@@ -355,9 +354,8 @@ def _load_shard_file(
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
disable_mmap=False,
):
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
mismatched_keys = _find_mismatched_keys(
state_dict,
model_state_dict,
@@ -403,7 +401,6 @@ def _load_shard_files_with_threadpool(
state_dict_folder=None,
ignore_mismatched_sizes=False,
low_cpu_mem_usage=False,
disable_mmap=False,
):
# Do not spawn anymore workers than you need
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
@@ -430,15 +427,10 @@ def _load_shard_files_with_threadpool(
state_dict_folder=state_dict_folder,
ignore_mismatched_sizes=ignore_mismatched_sizes,
low_cpu_mem_usage=low_cpu_mem_usage,
disable_mmap=disable_mmap,
)
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
if not is_torch_dist_rank_zero():
tqdm_kwargs["disable"] = True
with ThreadPoolExecutor(max_workers=num_workers) as executor:
with logging.tqdm(**tqdm_kwargs) as pbar:
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
for future in as_completed(futures):
result = future.result()

Some files were not shown because too many files have changed in this diff.