mirror of
https://github.com/huggingface/diffusers.git
synced 2026-02-07 03:15:16 +08:00
Compare commits
6 Commits
lighter-gl
...
naykun-mai
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
49e39f8a4e | ||
|
|
02851fade5 | ||
|
|
c0b58138e5 | ||
|
|
f3c62427cb | ||
|
|
e630e6eca4 | ||
|
|
a89f13eff4 |
4
.github/workflows/benchmark.yml
vendored
4
.github/workflows/benchmark.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -58,7 +58,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark_test_reports
|
||||
path: benchmarks/${{ env.BASE_PATH }}
|
||||
|
||||
4
.github/workflows/build_docker_images.yml
vendored
4
.github/workflows/build_docker_images.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Find Changed Dockerfiles
|
||||
id: file_changes
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Login to Docker Hub
|
||||
|
||||
4
.github/workflows/build_pr_documentation.yml
vendored
4
.github/workflows/build_pr_documentation.yml
vendored
@@ -17,10 +17,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
|
||||
22
.github/workflows/codeql.yml
vendored
22
.github/workflows/codeql.yml
vendored
@@ -1,22 +0,0 @@
|
||||
---
|
||||
name: CodeQL Security Analysis For Github Actions
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
# pull_request:
|
||||
|
||||
jobs:
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
|
||||
permissions:
|
||||
security-events: write
|
||||
packages: read
|
||||
actions: read
|
||||
contents: read
|
||||
with:
|
||||
languages: '["actions","python"]'
|
||||
queries: 'security-extended,security-and-quality'
|
||||
runner: 'ubuntu-latest' #optional if need custom runner
|
||||
29
.github/workflows/mirror_community_pipeline.yml
vendored
29
.github/workflows/mirror_community_pipeline.yml
vendored
@@ -24,6 +24,7 @@ jobs:
|
||||
mirror_community_pipeline:
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_COMMUNITY_MIRROR }}
|
||||
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
# Checkout to correct ref
|
||||
@@ -38,41 +39,37 @@ jobs:
|
||||
# If ref is 'refs/heads/main' => set 'main'
|
||||
# Else it must be a tag => set {tag}
|
||||
- name: Set checkout_ref and path_in_repo
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
|
||||
GITHUB_REF: ${{ github.ref }}
|
||||
run: |
|
||||
if [ "$EVENT_NAME" == "workflow_dispatch" ]; then
|
||||
if [ -z "$EVENT_INPUT_REF" ]; then
|
||||
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
|
||||
if [ -z "${{ github.event.inputs.ref }}" ]; then
|
||||
echo "Error: Missing ref input"
|
||||
exit 1
|
||||
elif [ "$EVENT_INPUT_REF" == "main" ]; then
|
||||
elif [ "${{ github.event.inputs.ref }}" == "main" ]; then
|
||||
echo "CHECKOUT_REF=refs/heads/main" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
|
||||
else
|
||||
echo "CHECKOUT_REF=refs/tags/$EVENT_INPUT_REF" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=$EVENT_INPUT_REF" >> $GITHUB_ENV
|
||||
echo "CHECKOUT_REF=refs/tags/${{ github.event.inputs.ref }}" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=${{ github.event.inputs.ref }}" >> $GITHUB_ENV
|
||||
fi
|
||||
elif [ "$GITHUB_REF" == "refs/heads/main" ]; then
|
||||
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
|
||||
elif [ "${{ github.ref }}" == "refs/heads/main" ]; then
|
||||
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=main" >> $GITHUB_ENV
|
||||
else
|
||||
# e.g. refs/tags/v0.28.1 -> v0.28.1
|
||||
echo "CHECKOUT_REF=$GITHUB_REF" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=$(echo $GITHUB_REF | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
|
||||
echo "CHECKOUT_REF=${{ github.ref }}" >> $GITHUB_ENV
|
||||
echo "PATH_IN_REPO=$(echo ${{ github.ref }} | sed 's/^refs\/tags\///')" >> $GITHUB_ENV
|
||||
fi
|
||||
- name: Print env vars
|
||||
run: |
|
||||
echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
|
||||
echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
with:
|
||||
ref: ${{ env.CHECKOUT_REF }}
|
||||
|
||||
# Setup + install dependencies
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
@@ -102,4 +99,4 @@ jobs:
|
||||
- name: Report failure status
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
|
||||
pip install requests && python utils/notify_community_pipelines_mirror.py --status=failure
|
||||
46
.github/workflows/nightly_tests.yml
vendored
46
.github/workflows/nightly_tests.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -97,7 +97,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -119,7 +119,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others, single_file, examples]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -167,7 +167,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_${{ matrix.module }}_cuda_test_reports
|
||||
path: reports
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -211,7 +211,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -263,7 +263,7 @@ jobs:
|
||||
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_big_gpu_test_reports
|
||||
path: reports
|
||||
@@ -280,7 +280,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -321,7 +321,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
@@ -355,7 +355,7 @@ jobs:
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -391,7 +391,7 @@ jobs:
|
||||
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_${{ matrix.config.backend }}_reports
|
||||
path: reports
|
||||
@@ -408,7 +408,7 @@ jobs:
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -441,7 +441,7 @@ jobs:
|
||||
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_pipeline_level_quant_reports
|
||||
path: reports
|
||||
@@ -466,7 +466,7 @@ jobs:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -474,7 +474,7 @@ jobs:
|
||||
run: mkdir -p combined_reports
|
||||
|
||||
- name: Download all test reports
|
||||
uses: actions/download-artifact@v7
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
@@ -500,7 +500,7 @@ jobs:
|
||||
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Upload consolidated report
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: consolidated_test_report
|
||||
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
|
||||
@@ -514,7 +514,7 @@ jobs:
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v6
|
||||
# uses: actions/checkout@v3
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
@@ -554,7 +554,7 @@ jobs:
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v6
|
||||
# uses: actions/upload-artifact@v4
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
@@ -570,7 +570,7 @@ jobs:
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v6
|
||||
# uses: actions/checkout@v3
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
@@ -610,7 +610,7 @@ jobs:
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v6
|
||||
# uses: actions/upload-artifact@v4
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
|
||||
@@ -10,10 +10,10 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
|
||||
|
||||
4
.github/workflows/pr_dependency_test.yml
vendored
4
.github/workflows/pr_dependency_test.yml
vendored
@@ -18,9 +18,9 @@ jobs:
|
||||
check_dependencies:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
|
||||
38
.github/workflows/pr_modular_tests.yml
vendored
38
.github/workflows/pr_modular_tests.yml
vendored
@@ -1,4 +1,3 @@
|
||||
|
||||
name: Fast PR tests for Modular
|
||||
|
||||
on:
|
||||
@@ -36,9 +35,9 @@ jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
@@ -56,9 +55,9 @@ jobs:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
@@ -78,13 +77,23 @@ jobs:
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
name: Fast PyTorch Modular Pipeline CPU tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on:
|
||||
group: aws-highmemory-32-plus
|
||||
group: ${{ matrix.config.runner }}
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
image: ${{ matrix.config.image }}
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
@@ -93,7 +102,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -109,19 +118,22 @@ jobs:
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_torch_cpu_modular_pipelines \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
|
||||
12
.github/workflows/pr_test_fetcher.yml
vendored
12
.github/workflows/pr_test_fetcher.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
test_map: ${{ steps.set_matrix.outputs.test_map }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install dependencies
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
run: |
|
||||
python utils/tests_fetcher.py | tee test_preparation.txt
|
||||
- name: Report fetched tests
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: test_fetched
|
||||
path: test_preparation.txt
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v3
|
||||
with:
|
||||
name: ${{ matrix.modules }}_test_reports
|
||||
path: reports
|
||||
@@ -138,7 +138,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -164,7 +164,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
20
.github/workflows/pr_tests.yml
vendored
20
.github/workflows/pr_tests.yml
vendored
@@ -31,9 +31,9 @@ jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -51,9 +51,9 @@ jobs:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -108,7 +108,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -153,7 +153,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
@@ -185,7 +185,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -211,7 +211,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
@@ -236,7 +236,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -273,7 +273,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_main_test_reports
|
||||
path: reports
|
||||
|
||||
24
.github/workflows/pr_tests_gpu.yml
vendored
24
.github/workflows/pr_tests_gpu.yml
vendored
@@ -32,9 +32,9 @@ jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -52,9 +52,9 @@ jobs:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -100,7 +100,7 @@ jobs:
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -170,7 +170,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -193,7 +193,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -239,7 +239,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
path: reports
|
||||
@@ -255,7 +255,7 @@ jobs:
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -287,7 +287,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
|
||||
@@ -18,9 +18,9 @@ jobs:
|
||||
check_torch_dependencies:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
|
||||
24
.github/workflows/push_tests.yml
vendored
24
.github/workflows/push_tests.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -66,7 +66,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others, single_file]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -155,7 +155,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
path: reports
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -199,7 +199,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -240,7 +240,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_xformers_test_reports
|
||||
path: reports
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -286,7 +286,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
|
||||
4
.github/workflows/push_tests_fast.yml
vendored
4
.github/workflows/push_tests_fast.yml
vendored
@@ -54,7 +54,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
4
.github/workflows/push_tests_mps.yml
vendored
4
.github/workflows/push_tests_mps.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pr_torch_mps_test_reports
|
||||
path: reports
|
||||
|
||||
8
.github/workflows/pypi_publish.yaml
vendored
8
.github/workflows/pypi_publish.yaml
vendored
@@ -15,10 +15,10 @@ jobs:
|
||||
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.8'
|
||||
|
||||
@@ -40,12 +40,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.8"
|
||||
|
||||
|
||||
28
.github/workflows/release_tests_fast.yml
vendored
28
.github/workflows/release_tests_fast.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -116,7 +116,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others, single_file]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -149,7 +149,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_cuda_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -166,7 +166,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -205,7 +205,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
@@ -222,7 +222,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
@@ -264,7 +264,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: torch_xformers_test_reports
|
||||
path: reports
|
||||
@@ -305,7 +305,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -336,7 +336,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
|
||||
2
.github/workflows/run_tests_from_a_pr.yml
vendored
2
.github/workflows/run_tests_from_a_pr.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
shell: bash -e {0}
|
||||
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
ref: refs/pull/${{ inputs.pr_number }}/head
|
||||
|
||||
|
||||
2
.github/workflows/ssh-pr-runner.yml
vendored
2
.github/workflows/ssh-pr-runner.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
|
||||
2
.github/workflows/ssh-runner.yml
vendored
2
.github/workflows/ssh-runner.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
|
||||
4
.github/workflows/stale.yml
vendored
4
.github/workflows/stale.yml
vendored
@@ -15,10 +15,10 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v6
|
||||
uses: actions/setup-python@v1
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
|
||||
2
.github/workflows/trufflehog.yml
vendored
2
.github/workflows/trufflehog.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v6
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
|
||||
2
.github/workflows/typos.yml
vendored
2
.github/workflows/typos.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.12.4
|
||||
|
||||
2
.github/workflows/update_metadata.yml
vendored
2
.github/workflows/update_metadata.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
- uses: actions/checkout@v3
|
||||
|
||||
- name: Setup environment
|
||||
run: |
|
||||
|
||||
@@ -54,8 +54,6 @@
|
||||
title: Batch inference
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference
|
||||
- local: hybrid_inference/overview
|
||||
title: Remote inference
|
||||
title: Inference
|
||||
- isExpanded: false
|
||||
sections:
|
||||
@@ -90,6 +88,17 @@
|
||||
title: FreeU
|
||||
title: Community optimizations
|
||||
title: Inference optimization
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: hybrid_inference/overview
|
||||
title: Overview
|
||||
- local: hybrid_inference/vae_decode
|
||||
title: VAE Decode
|
||||
- local: hybrid_inference/vae_encode
|
||||
title: VAE Encode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
title: Hybrid Inference
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: modular_diffusers/overview
|
||||
@@ -261,8 +270,6 @@
|
||||
title: Outputs
|
||||
- local: api/quantization
|
||||
title: Quantization
|
||||
- local: hybrid_inference/api_reference
|
||||
title: Remote inference
|
||||
- local: api/parallel
|
||||
title: Parallel inference
|
||||
title: Main Classes
|
||||
@@ -346,8 +353,6 @@
|
||||
title: Flux2Transformer2DModel
|
||||
- local: api/models/flux_transformer
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/glm_image_transformer2d
|
||||
title: GlmImageTransformer2DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
title: HiDreamImageTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
@@ -362,8 +367,6 @@
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/longcat_image_transformer2d
|
||||
title: LongCatImageTransformer2DModel
|
||||
- local: api/models/ltx2_video_transformer3d
|
||||
title: LTX2VideoTransformer3DModel
|
||||
- local: api/models/ltx_video_transformer3d
|
||||
title: LTXVideoTransformer3DModel
|
||||
- local: api/models/lumina2_transformer2d
|
||||
@@ -440,10 +443,6 @@
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoder_kl_hunyuan_video15
|
||||
title: AutoencoderKLHunyuanVideo15
|
||||
- local: api/models/autoencoderkl_audio_ltx_2
|
||||
title: AutoencoderKLLTX2Audio
|
||||
- local: api/models/autoencoderkl_ltx_2
|
||||
title: AutoencoderKLLTX2Video
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
title: AutoencoderKLLTXVideo
|
||||
- local: api/models/autoencoderkl_magvit
|
||||
@@ -542,8 +541,6 @@
|
||||
title: Flux2
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/glm_image
|
||||
title: GLM-Image
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
@@ -681,8 +678,6 @@
|
||||
title: Kandinsky 5.0 Video
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx2
|
||||
title: LTX-2
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/mochi
|
||||
|
||||
@@ -29,7 +29,7 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
## FirstBlockCacheConfig
|
||||
### FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] FirstBlockCacheConfig
|
||||
|
||||
|
||||
@@ -33,7 +33,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`QwenImageLoraLoaderMixin`] provides similar functions for [Qwen Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/qwen).
|
||||
- [`ZImageLoraLoaderMixin`] provides similar functions for [Z-Image](https://huggingface.co/docs/diffusers/main/en/api/pipelines/zimage).
|
||||
- [`Flux2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/flux2).
|
||||
- [`LTX2LoraLoaderMixin`] provides similar functions for [Flux2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx2).
|
||||
- [`LoraBaseMixin`] provides a base class with several utility methods to fuse, unfuse, unload, LoRAs and more.
|
||||
|
||||
> [!TIP]
|
||||
@@ -63,10 +62,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.Flux2LoraLoaderMixin
|
||||
|
||||
## LTX2LoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.LTX2LoraLoaderMixin
|
||||
|
||||
## CogVideoXLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.CogVideoXLoraLoaderMixin
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLLTX2Audio
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTX2Audio
|
||||
|
||||
vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTX2Audio
|
||||
|
||||
[[autodoc]] AutoencoderKLLTX2Audio
|
||||
- encode
|
||||
- decode
|
||||
- all
|
||||
@@ -1,29 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLLTX2Video
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTX2Video
|
||||
|
||||
vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTX2Video
|
||||
|
||||
[[autodoc]] AutoencoderKLLTX2Video
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
@@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co
|
||||
|
||||
## FluxControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput
|
||||
[[autodoc]] models.controlnet_flux.FluxControlNetOutput
|
||||
@@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr
|
||||
|
||||
## SparseControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput
|
||||
[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# GlmImageTransformer2DModel
|
||||
|
||||
A Diffusion Transformer model for 2D data from [GlmImageTransformer2DModel] (TODO).
|
||||
|
||||
## GlmImageTransformer2DModel
|
||||
|
||||
[[autodoc]] GlmImageTransformer2DModel
|
||||
@@ -1,26 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# LTX2VideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import LTX2VideoTransformer3DModel
|
||||
|
||||
transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTX2VideoTransformer3DModel
|
||||
|
||||
[[autodoc]] LTX2VideoTransformer3DModel
|
||||
@@ -30,10 +30,6 @@
|
||||
|
||||
The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.
|
||||
|
||||
Available Models/LoRAs:
|
||||
- [nvidia/ChronoEdit-14B-Diffusers](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers)
|
||||
- [nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora)
|
||||
- [nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora)
|
||||
|
||||
### Image Editing
|
||||
|
||||
@@ -104,7 +100,6 @@ Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.pn
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
from diffusers.schedulers import UniPCMultistepScheduler
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
@@ -114,8 +109,9 @@ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encod
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
|
||||
pipe.fuse_lora(adapter_names=["distill"], lora_scale=1.0)
|
||||
lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
|
||||
pipe.load_lora_weights(lora_path)
|
||||
pipe.fuse_lora(lora_scale=1.0)
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
|
||||
pipe.to("cuda")
|
||||
|
||||
@@ -149,57 +145,6 @@ export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
|
||||
```
|
||||
|
||||
### Inference with Multiple LoRAs
|
||||
|
||||
```py
|
||||
import torch
|
||||
import numpy as np
|
||||
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
|
||||
from diffusers.schedulers import UniPCMultistepScheduler
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
from transformers import CLIPVisionModel
|
||||
from PIL import Image
|
||||
|
||||
model_id = "nvidia/ChronoEdit-14B-Diffusers"
|
||||
image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
|
||||
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
|
||||
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
|
||||
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora", weight_name="paintbrush_lora_diffusers.safetensors", adapter_name="paintbrush")
|
||||
pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
|
||||
pipe.fuse_lora(adapter_names=["paintbrush", "distill"], lora_scale=1.0)
|
||||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
|
||||
pipe.to("cuda")
|
||||
|
||||
image = load_image(
|
||||
"https://raw.githubusercontent.com/nv-tlabs/ChronoEdit/refs/heads/main/assets/images/input_paintbrush.png"
|
||||
)
|
||||
max_area = 720 * 1280
|
||||
aspect_ratio = image.height / image.width
|
||||
mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
|
||||
height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
|
||||
width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
|
||||
print("width", width, "height", height)
|
||||
image = image.resize((width, height))
|
||||
prompt = (
|
||||
"Turn the pencil sketch in the image into an actual object that is consistent with the image’s content. The user wants to change the sketch to a crown and a hat."
|
||||
)
|
||||
|
||||
output = pipe(
|
||||
image=image,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=5,
|
||||
num_inference_steps=8,
|
||||
guidance_scale=1.0,
|
||||
enable_temporal_reasoning=False,
|
||||
num_temporal_reasoning_steps=0,
|
||||
).frames[0]
|
||||
export_to_video(output, "output.mp4", fps=16)
|
||||
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output_1.png")
|
||||
```
|
||||
|
||||
## ChronoEditPipeline
|
||||
|
||||
[[autodoc]] ChronoEditPipeline
|
||||
|
||||
@@ -70,12 +70,6 @@ output.save("output.png")
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Cosmos2_5_PredictBasePipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||
|
||||
@@ -21,7 +21,7 @@ The abstract from the paper is:
|
||||
|
||||
*Image generation has recently seen tremendous advances, with diffusion models allowing to synthesize convincing images for a large variety of text prompts. In this article, we propose DiffEdit, a method to take advantage of text-conditioned diffusion models for the task of semantic image editing, where the goal is to edit an image based on a text query. Semantic image editing is an extension of image generation, with the additional constraint that the generated image should be as similar as possible to a given input image. Current editing methods based on diffusion models usually require to provide a mask, making the task much easier by treating it as a conditional inpainting task. In contrast, our main contribution is able to automatically generate a mask highlighting regions of the input image that need to be edited, by contrasting predictions of a diffusion model conditioned on different text prompts. Moreover, we rely on latent inference to preserve content in those regions of interest and show excellent synergies with mask-based diffusion. DiffEdit achieves state-of-the-art editing performance on ImageNet. In addition, we evaluate semantic image editing in more challenging settings, using images from the COCO dataset as well as text-based generated images.*
|
||||
|
||||
The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/posts/2022-11-02-diffedit-implementation.html).
|
||||
The original codebase can be found at [Xiang-cd/DiffEdit-stable-diffusion](https://github.com/Xiang-cd/DiffEdit-stable-diffusion), and you can try it out in this [demo](https://blog.problemsolversguild.com/technical/research/2022/11/02/DiffEdit-Implementation.html).
|
||||
|
||||
This pipeline was contributed by [clarencechen](https://github.com/clarencechen). ❤️
|
||||
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
-->
|
||||
|
||||
# GLM-Image
|
||||
|
||||
## Overview
|
||||
|
||||
GLM-Image is an image generation model adopts a hybrid autoregressive + diffusion decoder architecture, effectively pushing the upper bound of visual fidelity and fine-grained details. In general image generation quality, it aligns with industry-standard LDM-based approaches, while demonstrating significant advantages in knowledge-intensive image generation scenarios.
|
||||
|
||||
Model architecture: a hybrid autoregressive + diffusion decoder design、
|
||||
|
||||
+ Autoregressive generator: a 9B-parameter model initialized from [GLM-4-9B-0414](https://huggingface.co/zai-org/GLM-4-9B-0414), with an expanded vocabulary to incorporate visual tokens. The model first generates a compact encoding of approximately 256 tokens, then expands to 1K–4K tokens, corresponding to 1K–2K high-resolution image outputs. You can check AR model in class `GlmImageForConditionalGeneration` of `transformers` library.
|
||||
+ Diffusion Decoder: a 7B-parameter decoder based on a single-stream DiT architecture for latent-space image decoding. It is equipped with a Glyph Encoder text module, significantly improving accurate text rendering within images.
|
||||
|
||||
Post-training with decoupled reinforcement learning: the model introduces a fine-grained, modular feedback strategy using the GRPO algorithm, substantially enhancing both semantic understanding and visual detail quality.
|
||||
|
||||
+ Autoregressive module: provides low-frequency feedback signals focused on aesthetics and semantic alignment, improving instruction following and artistic expressiveness.
|
||||
+ Decoder module: delivers high-frequency feedback targeting detail fidelity and text accuracy, resulting in highly realistic textures, lighting, and color reproduction, as well as more precise text rendering.
|
||||
|
||||
GLM-Image supports both text-to-image and image-to-image generation within a single model
|
||||
|
||||
+ Text-to-image: generates high-detail images from textual descriptions, with particularly strong performance in information-dense scenarios.
|
||||
+ Image-to-image: supports a wide range of tasks, including image editing, style transfer, multi-subject consistency, and identity-preserving generation for people and objects.
|
||||
|
||||
This pipeline was contributed by [zRzRzRzRzRzRzR](https://github.com/zRzRzRzRzRzRzR). The codebase can be found [here](https://huggingface.co/zai-org/GLM-Image).
|
||||
|
||||
## Usage examples
|
||||
|
||||
### Text to Image Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.pipelines.glm_image import GlmImagePipeline
|
||||
|
||||
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
|
||||
prompt = "A beautifully designed modern food magazine style dessert recipe illustration, themed around a raspberry mousse cake. The overall layout is clean and bright, divided into four main areas: the top left features a bold black title 'Raspberry Mousse Cake Recipe Guide', with a soft-lit close-up photo of the finished cake on the right, showcasing a light pink cake adorned with fresh raspberries and mint leaves; the bottom left contains an ingredient list section, titled 'Ingredients' in a simple font, listing 'Flour 150g', 'Eggs 3', 'Sugar 120g', 'Raspberry puree 200g', 'Gelatin sheets 10g', 'Whipping cream 300ml', and 'Fresh raspberries', each accompanied by minimalist line icons (like a flour bag, eggs, sugar jar, etc.); the bottom right displays four equally sized step boxes, each containing high-definition macro photos and corresponding instructions, arranged from top to bottom as follows: Step 1 shows a whisk whipping white foam (with the instruction 'Whip egg whites to stiff peaks'), Step 2 shows a red-and-white mixture being folded with a spatula (with the instruction 'Gently fold in the puree and batter'), Step 3 shows pink liquid being poured into a round mold (with the instruction 'Pour into mold and chill for 4 hours'), Step 4 shows the finished cake decorated with raspberries and mint leaves (with the instruction 'Decorate with raspberries and mint'); a light brown information bar runs along the bottom edge, with icons on the left representing 'Preparation time: 30 minutes', 'Cooking time: 20 minutes', and 'Servings: 8'. The overall color scheme is dominated by creamy white and light pink, with a subtle paper texture in the background, featuring compact and orderly text and image layout with clear information hierarchy."
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
height=32 * 32,
|
||||
width=36 * 32,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=1.5,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).images[0]
|
||||
|
||||
image.save("output_t2i.png")
|
||||
```
|
||||
|
||||
### Image to Image Generation
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers.pipelines.glm_image import GlmImagePipeline
|
||||
from PIL import Image
|
||||
|
||||
pipe = GlmImagePipeline.from_pretrained("zai-org/GLM-Image",torch_dtype=torch.bfloat16,device_map="cuda")
|
||||
image_path = "cond.jpg"
|
||||
prompt = "Replace the background of the snow forest with an underground station featuring an automatic escalator."
|
||||
image = Image.open(image_path).convert("RGB")
|
||||
image = pipe(
|
||||
prompt=prompt,
|
||||
image=[image], # can input multiple images for multi-image-to-image generation such as [image, image1]
|
||||
height=33 * 32,
|
||||
width=32 * 32,
|
||||
num_inference_steps=30,
|
||||
guidance_scale=1.5,
|
||||
generator=torch.Generator(device="cuda").manual_seed(42),
|
||||
).images[0]
|
||||
|
||||
image.save("output_i2i.png")
|
||||
```
|
||||
|
||||
+ Since the AR model used in GLM-Image is configured with `do_sample=True` and a temperature of `0.95` by default, the generated images can vary significantly across runs. We do not recommend setting do_sample=False, as this may lead to incorrect or degenerate outputs from the AR model.
|
||||
|
||||
## GlmImagePipeline
|
||||
|
||||
[[autodoc]] pipelines.glm_image.pipeline_glm_image.GlmImagePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## GlmImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.glm_image.pipeline_output.GlmImagePipelineOutput
|
||||
@@ -1,47 +0,0 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX-2
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
|
||||
|
||||
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
|
||||
|
||||
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
|
||||
|
||||
## LTX2Pipeline
|
||||
|
||||
[[autodoc]] LTX2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTX2ImageToVideoPipeline
|
||||
|
||||
[[autodoc]] LTX2ImageToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTX2LatentUpsamplePipeline
|
||||
|
||||
[[autodoc]] LTX2LatentUpsamplePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTX2PipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput
|
||||
@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
|
||||
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
|
||||
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
|
||||
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
|
||||
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.
|
||||
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
|
||||
|
||||
- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
|
||||
|
||||
@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)
|
||||
|
||||
<details>
|
||||
<summary>Show example code</summary>
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
|
||||
@@ -474,12 +474,6 @@ export_to_video(video, "output.mp4", fps=24)
|
||||
|
||||
</details>
|
||||
|
||||
## LTXI2VLongMultiPromptPipeline
|
||||
|
||||
[[autodoc]] LTXI2VLongMultiPromptPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
[[autodoc]] LTXPipeline
|
||||
|
||||
@@ -95,7 +95,7 @@ image.save("qwen_fewsteps.png")
|
||||
|
||||
With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.
|
||||
|
||||
```py
|
||||
```
|
||||
import torch
|
||||
from PIL import Image
|
||||
from diffusers import QwenImageEditPlusPipeline
|
||||
@@ -108,46 +108,12 @@ pipe = QwenImageEditPlusPipeline.from_pretrained(
|
||||
image_1 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/grumpy.jpg")
|
||||
image_2 = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/peng.png")
|
||||
image = pipe(
|
||||
image=[image_1, image_2],
|
||||
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
|
||||
image=[image_1, image_2],
|
||||
prompt='''put the penguin and the cat at a game show called "Qwen Edit Plus Games"''',
|
||||
num_inference_steps=50
|
||||
).images[0]
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
### torch.compile
|
||||
|
||||
Using `torch.compile` on the transformer provides ~2.4x speedup (A100 80GB: 4.70s → 1.93s):
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import QwenImagePipeline
|
||||
|
||||
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
|
||||
pipe.transformer = torch.compile(pipe.transformer)
|
||||
|
||||
# First call triggers compilation (~7s overhead)
|
||||
# Subsequent calls run at ~2.4x faster
|
||||
image = pipe("a cat", num_inference_steps=50).images[0]
|
||||
```
|
||||
|
||||
### Batched Inference with Variable-Length Prompts
|
||||
|
||||
When using classifier-free guidance (CFG) with prompts of different lengths, the pipeline properly handles padding through attention masking. This ensures padding tokens do not influence the generated output.
|
||||
|
||||
```python
|
||||
# CFG with different prompt lengths works correctly
|
||||
image = pipe(
|
||||
prompt="A cat",
|
||||
negative_prompt="blurry, low quality, distorted",
|
||||
true_cfg_scale=3.5,
|
||||
num_inference_steps=50,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
For detailed benchmark scripts and results, see [this gist](https://gist.github.com/cdutr/bea337e4680268168550292d7819dc2f).
|
||||
|
||||
## QwenImagePipeline
|
||||
|
||||
[[autodoc]] QwenImagePipeline
|
||||
|
||||
@@ -37,8 +37,7 @@ The following SkyReels-V2 models are supported in Diffusers:
|
||||
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
|
||||
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
|
||||
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
|
||||
|
||||
This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
||||
- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
|
||||
|
||||
> [!TIP]
|
||||
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
|
||||
|
||||
@@ -250,6 +250,9 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
|
||||
|
||||
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
|
||||
|
||||
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
# Remote inference
|
||||
# Hybrid Inference API Reference
|
||||
|
||||
Remote inference provides access to an [Inference Endpoint](https://huggingface.co/docs/inference-endpoints/index) to offload local generation requirements for decoding and encoding.
|
||||
|
||||
## remote_decode
|
||||
## Remote Decode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_decode
|
||||
|
||||
## remote_encode
|
||||
## Remote Encode
|
||||
|
||||
[[autodoc]] utils.remote_utils.remote_encode
|
||||
|
||||
@@ -10,296 +10,51 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Remote inference
|
||||
# Hybrid Inference
|
||||
|
||||
**Empowering local AI builders with Hybrid Inference**
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> This is currently an experimental feature, and if you have any feedback, please feel free to leave it [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
Remote inference offloads the decoding and encoding process to a remote endpoint to relax the memory requirements for local inference with large models. This feature is powered by [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index). Refer to the table below for the supported models and endpoint.
|
||||
|
||||
| Model | Endpoint | Checkpoint | Support |
|
||||
|---|---|---|---|
|
||||
| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) | encode/decode |
|
||||
| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) | encode/decode |
|
||||
| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | encode/decode |
|
||||
| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | decode |
|
||||
|
||||
This guide will show you how to encode and decode latents with remote inference.
|
||||
|
||||
## Encoding
|
||||
|
||||
Encoding converts images and videos into latent representations. Refer to the table below for the supported VAEs.
|
||||
|
||||
Pass an image to [`~utils.remote_encode`] to encode it. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_encode
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.float16,
|
||||
vae=None,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
init_latent = remote_encode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
|
||||
image=init_image,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159
|
||||
)
|
||||
```
|
||||
|
||||
## Decoding
|
||||
|
||||
Decoding converts latent representations back into images or videos. Refer to the table below for the available and supported VAEs.
|
||||
|
||||
Set the output type to `"latent"` in the pipeline and set the `vae` to `None`. Pass the latents to the [`~utils.remote_decode`] function. For Flux, the latents are packed so the `height` and `width` also need to be passed. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
|
||||
|
||||
<hfoptions id="decode">
|
||||
<hfoption id="Flux">
|
||||
|
||||
```py
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16,
|
||||
vae=None,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
prompt = """
|
||||
A photorealistic Apollo-era photograph of a cat in a small astronaut suit with a bubble helmet, standing on the Moon and holding a flagpole planted in the dusty lunar soil. The flag shows a colorful paw-print emblem. Earth glows in the black sky above the stark gray surface, with sharp shadows and high-contrast lighting like vintage NASA photos.
|
||||
"""
|
||||
|
||||
latent = pipeline(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.0,
|
||||
num_inference_steps=4,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="HunyuanVideo">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16, device_map="cuda"
|
||||
)
|
||||
|
||||
latent = pipeline(
|
||||
prompt="A cat walks on the grass, realistic",
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=61,
|
||||
num_inference_steps=30,
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
output_type="mp4",
|
||||
)
|
||||
|
||||
if isinstance(video, bytes):
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Queuing
|
||||
|
||||
Remote inference supports queuing to process multiple generation requests. While the current latent is being decoded, you can queue the next prompt.
|
||||
|
||||
```py
|
||||
import queue
|
||||
import threading
|
||||
from IPython.display import display
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
|
||||
def decode_worker(q: queue.Queue):
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
break
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=item,
|
||||
scaling_factor=0.13025,
|
||||
)
|
||||
display(image)
|
||||
q.task_done()
|
||||
|
||||
q = queue.Queue()
|
||||
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def decode(latent: torch.Tensor):
|
||||
q.put(latent)
|
||||
|
||||
prompts = [
|
||||
"A grainy Apollo-era style photograph of a cat in a snug astronaut suit with a bubble helmet, standing on the lunar surface and gripping a flag with a paw-print emblem. The gray Moon landscape stretches behind it, Earth glowing vividly in the black sky, shadows crisp and high-contrast.",
|
||||
"A vintage 1960s sci-fi pulp magazine cover illustration of a heroic cat astronaut planting a flag on the Moon. Bold, saturated colors, exaggerated space gear, playful typography floating in the background, Earth painted in bright blues and greens.",
|
||||
"A hyper-detailed cinematic shot of a cat astronaut on the Moon holding a fluttering flag, fur visible through the helmet glass, lunar dust scattering under its feet. The vastness of space and Earth in the distance create an epic, awe-inspiring tone.",
|
||||
"A colorful cartoon drawing of a happy cat wearing a chunky, oversized spacesuit, proudly holding a flag with a big paw print on it. The Moon’s surface is simplified with craters drawn like doodles, and Earth in the sky has a smiling face.",
|
||||
"A monochrome 1969-style press photo of a “first cat on the Moon” moment. The cat, in a tiny astronaut suit, stands by a planted flag, with grainy textures, scratches, and a blurred Earth in the background, mimicking old archival space photos."
|
||||
]
|
||||
> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
|
||||
> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
vae=None,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
pipeline.unet = pipeline.unet.to(memory_format=torch.channels_last)
|
||||
pipeline.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
## Why use Hybrid Inference?
|
||||
|
||||
_ = pipeline(
|
||||
prompt=prompts[0],
|
||||
output_type="latent",
|
||||
)
|
||||
Hybrid Inference offers a fast and simple way to offload local generation requirements.
|
||||
|
||||
for prompt in prompts:
|
||||
latent = pipeline(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
decode(latent)
|
||||
- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
|
||||
- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
|
||||
- 💰 **Cost Effective:** It's free! 🤑
|
||||
- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
|
||||
- 🔧 **Developer-Friendly:** Simple requests, fast responses.
|
||||
|
||||
q.put(None)
|
||||
thread.join()
|
||||
```
|
||||
---
|
||||
|
||||
## Benchmarks
|
||||
## Available Models
|
||||
|
||||
The tables demonstrate the memory requirements for encoding and decoding with Stable Diffusion v1.5 and SDXL on different GPUs.
|
||||
* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
|
||||
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
|
||||
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.
|
||||
|
||||
For the majority of these GPUs, the memory usage dictates whether other models (text encoders, UNet/transformer) need to be offloaded or required tiled encoding. The latter two techniques increases inference time and impacts quality.
|
||||
---
|
||||
|
||||
<details><summary>Encoding - Stable Diffusion v1.5</summary>
|
||||
## Integrations
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
|
||||
</details>
|
||||
## Changelog
|
||||
|
||||
<details><summary>Encoding SDXL</summary>
|
||||
- March 10 2025: Added VAE encode
|
||||
- March 2 2025: Initial release with VAE decoding
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
|
||||
## Contents
|
||||
|
||||
</details>
|
||||
The documentation is organized into three sections:
|
||||
|
||||
<details><summary>Decoding - Stable Diffusion v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>Decoding SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
- Remote inference is also supported in [SD.Next](https://github.com/vladmandic/sdnext) and [ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae).
|
||||
- Refer to the [Remote VAEs for decoding with Inference Endpoints](https://huggingface.co/blog/remote_vae) blog post to learn more.
|
||||
* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
|
||||
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
|
||||
* **API Reference** Dive into task-specific settings and parameters.
|
||||
|
||||
345
docs/source/en/hybrid_inference/vae_decode.md
Normal file
345
docs/source/en/hybrid_inference/vae_decode.md
Normal file
@@ -0,0 +1,345 @@
|
||||
# Getting Started: VAE Decode with Hybrid Inference
|
||||
|
||||
VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled decoding has to be used which increases time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Here, we show how to use the remote VAE on random tensors.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/output.png"/>
|
||||
</figure>
|
||||
|
||||
Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/flux_random_latent.png"/>
|
||||
</figure>
|
||||
|
||||
Finally, an example for HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
|
||||
output_type="mp4",
|
||||
)
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video_1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s another example with Flux.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.0,
|
||||
num_inference_steps=4,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test_1.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s an example with HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
|
||||
model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
latent = pipe(
|
||||
prompt="A cat walks on the grass, realistic",
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=61,
|
||||
num_inference_steps=30,
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
output_type="mp4",
|
||||
)
|
||||
|
||||
if isinstance(video, bytes):
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Queueing
|
||||
|
||||
One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
|
||||
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import queue
|
||||
import threading
|
||||
from IPython.display import display
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
def decode_worker(q: queue.Queue):
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
break
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=item,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
display(image)
|
||||
q.task_done()
|
||||
|
||||
q = queue.Queue()
|
||||
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def decode(latent: torch.Tensor):
|
||||
q.put(latent)
|
||||
|
||||
prompts = [
|
||||
"Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
|
||||
"Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
|
||||
"Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
|
||||
"Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
|
||||
"A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
|
||||
"Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
|
||||
]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"Lykon/dreamshaper-8",
|
||||
torch_dtype=torch.float16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
_ = pipe(
|
||||
prompt=prompts[0],
|
||||
output_type="latent",
|
||||
)
|
||||
|
||||
for prompt in prompts:
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
decode(latent)
|
||||
|
||||
q.put(None)
|
||||
thread.join()
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/queue.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
183
docs/source/en/hybrid_inference/vae_encode.md
Normal file
183
docs/source/en/hybrid_inference/vae_encode.md
Normal file
@@ -0,0 +1,183 @@
|
||||
# Getting Started: VAE Encode with Hybrid Inference
|
||||
|
||||
VAE encode is used for training, image-to-image and image-to-video - turning into images or videos into latent representations.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs the memory usage % dictates other models (text encoders, UNet/Transformer) must be offloaded, or tiled encoding has to be used which increases time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_encode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Let's encode an image, then decode it to demonstrate.
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
|
||||
</figure>
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
|
||||
|
||||
latent = remote_encode(
|
||||
endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
decoded = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
Now let's look at a generation example, we'll encode the image, generate then remotely decode too!
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode, remote_encode
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
init_image = load_image(
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
init_latent = remote_encode(
|
||||
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
image=init_image,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
image=init_latent,
|
||||
strength=0.75,
|
||||
output_type="latent",
|
||||
).images
|
||||
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("fantasy_landscape.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct supports Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -140,7 +140,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Available options are
|
||||
description="""Output type from annotation predictions. Availabe options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
@@ -256,7 +256,7 @@ class Florence2ImageAnnotatorBlock(ModularPipelineBlocks):
|
||||
type_hint=str,
|
||||
required=True,
|
||||
default="mask_image",
|
||||
description="""Output type from annotation predictions. Available options are
|
||||
description="""Output type from annotation predictions. Availabe options are
|
||||
mask_image:
|
||||
-black and white mask image for the given image based on the task type
|
||||
mask_overlay:
|
||||
|
||||
@@ -53,7 +53,7 @@ The loop wrapper can pass additional arguments, like current iteration index, to
|
||||
|
||||
A loop block is a [`~modular_pipelines.ModularPipelineBlocks`], but the `__call__` method behaves differently.
|
||||
|
||||
- It receives the iteration variable from the loop wrapper.
|
||||
- It recieves the iteration variable from the loop wrapper.
|
||||
- It works directly with the [`~modular_pipelines.BlockState`] instead of the [`~modular_pipelines.PipelineState`].
|
||||
- It doesn't require retrieving or updating the [`~modular_pipelines.BlockState`].
|
||||
|
||||
|
||||
@@ -68,20 +68,6 @@ config = FasterCacheConfig(
|
||||
pipeline.transformer.enable_cache(config)
|
||||
```
|
||||
|
||||
## FirstBlockCache
|
||||
|
||||
[FirstBlock Cache](https://huggingface.co/docs/diffusers/main/en/api/cache#diffusers.FirstBlockCacheConfig) checks how much the early layers of the denoiser changes from one timestep to the next. If the change is small, the model skips the expensive later layers and reuses the previous output.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.hooks import apply_first_block_cache, FirstBlockCacheConfig
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"Qwen/Qwen-Image", torch_dtype=torch.bfloat16
|
||||
)
|
||||
apply_first_block_cache(pipeline.transformer, FirstBlockCacheConfig(threshold=0.2))
|
||||
```
|
||||
## TaylorSeer Cache
|
||||
|
||||
[TaylorSeer Cache](https://huggingface.co/papers/2403.06923) accelerates diffusion inference by using Taylor series expansions to approximate and cache intermediate activations across denoising steps. The method predicts future outputs based on past computations, reusing them at specified intervals to reduce redundant calculations.
|
||||
@@ -101,7 +87,8 @@ from diffusers import FluxPipeline, TaylorSeerCacheConfig
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
)
|
||||
pipe.to("cuda")
|
||||
|
||||
config = TaylorSeerCacheConfig(
|
||||
cache_interval=5,
|
||||
@@ -110,4 +97,4 @@ config = TaylorSeerCacheConfig(
|
||||
taylor_factors_dtype=torch.bfloat16,
|
||||
)
|
||||
pipe.transformer.enable_cache(config)
|
||||
```
|
||||
```
|
||||
@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
quantzation_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
quantzation_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantization_config=pipeline_quant_config,
|
||||
quantzation_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
@@ -263,8 +263,8 @@ def main():
|
||||
world_size = dist.get_world_size()
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to(device)
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device_map=device
|
||||
)
|
||||
pipeline.transformer.set_attention_backend("_native_cudnn")
|
||||
|
||||
cp_config = ContextParallelConfig(ring_degree=world_size)
|
||||
@@ -314,35 +314,6 @@ Pass the [`ContextParallelConfig`] to [`~ModelMixin.enable_parallelism`].
|
||||
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
|
||||
```
|
||||
|
||||
### Unified Attention
|
||||
|
||||
[Unified Sequence Parallelism](https://huggingface.co/papers/2405.07719) combines Ring Attention and Ulysses Attention into a single approach for efficient long-sequence processing. It applies Ulysses's *all-to-all* communication first to redistribute heads and sequence tokens, then uses Ring Attention to process the redistributed data, and finally reverses the *all-to-all* to restore the original layout.
|
||||
|
||||
This hybrid approach leverages the strengths of both methods:
|
||||
- **Ulysses Attention** efficiently parallelizes across attention heads
|
||||
- **Ring Attention** handles very long sequences with minimal memory overhead
|
||||
- Together, they enable 2D parallelization across both heads and sequence dimensions
|
||||
|
||||
[`ContextParallelConfig`] supports Unified Attention by specifying both `ulysses_degree` and `ring_degree`. The total number of devices used is `ulysses_degree * ring_degree`, arranged in a 2D grid where Ulysses and Ring groups are orthogonal (non-overlapping).
|
||||
Pass the [`ContextParallelConfig`] with both `ulysses_degree` and `ring_degree` set to bigger than 1 to [`~ModelMixin.enable_parallelism`].
|
||||
|
||||
```py
|
||||
pipeline.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2, ring_degree=2))
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> Unified Attention is to be used when there are enough devices to arrange in a 2D grid (at least 4 devices).
|
||||
|
||||
We ran a benchmark with Ulysess, Ring, and Unified Attention with [this script](https://github.com/huggingface/diffusers/pull/12693#issuecomment-3694727532) on a node of 4 H100 GPUs. The results are summarized as follows:
|
||||
|
||||
| CP Backend | Time / Iter (ms) | Steps / Sec | Peak Memory (GB) |
|
||||
|--------------------|------------------|-------------|------------------|
|
||||
| ulysses | 6670.789 | 7.50 | 33.85 |
|
||||
| ring | 13076.492 | 3.82 | 56.02 |
|
||||
| unified_balanced | 11068.705 | 4.52 | 33.85 |
|
||||
|
||||
From the above table, it's clear that Ulysses provides better throughput, but the number of devices it can use remains limited to the number of attention heads, a limitation that is solved by unified attention.
|
||||
|
||||
### parallel_config
|
||||
|
||||
Pass `parallel_config` during model initialization to enable context parallelism.
|
||||
|
||||
@@ -1929,8 +1929,6 @@ def main(args):
|
||||
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
# Store vae config before potential deletion
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
@@ -1942,8 +1940,6 @@ def main(args):
|
||||
del vae
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
@@ -2113,13 +2109,13 @@ def main(args):
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
|
||||
if latents_mean is None and latents_std is None:
|
||||
model_input = model_input * vae_scaling_factor
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
model_input = model_input.to(weight_dtype)
|
||||
else:
|
||||
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
|
||||
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
|
||||
model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
|
||||
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
@@ -149,13 +149,13 @@ def get_args():
|
||||
"--validation_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
|
||||
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_images",
|
||||
type=str,
|
||||
default=None,
|
||||
help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_separator' string. These should correspond to the order of the validation prompts.",
|
||||
help="One or more image path(s) that is used during validation to verify that the model is learning. Multiple validation paths should be separated by the '--validation_prompt_seperator' string. These should correspond to the order of the validation prompts.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompt_separator",
|
||||
|
||||
@@ -140,7 +140,7 @@ def get_args():
|
||||
"--validation_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_separator' string.",
|
||||
help="One or more prompt(s) that is used during validation to verify that the model is learning. Multiple validation prompts should be separated by the '--validation_prompt_seperator' string.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompt_separator",
|
||||
|
||||
@@ -21,8 +21,8 @@ from transformers import (
|
||||
BertModel,
|
||||
BertTokenizer,
|
||||
CLIPImageProcessor,
|
||||
MT5Tokenizer,
|
||||
T5EncoderModel,
|
||||
T5Tokenizer,
|
||||
)
|
||||
|
||||
from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
The HunyuanDiT model designed by Tencent Hunyuan.
|
||||
text_encoder_2 (`T5EncoderModel`):
|
||||
The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
|
||||
tokenizer_2 (`T5Tokenizer`):
|
||||
tokenizer_2 (`MT5Tokenizer`):
|
||||
The tokenizer for the mT5 embedder.
|
||||
scheduler ([`DDPMScheduler`]):
|
||||
A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
|
||||
@@ -295,7 +295,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
requires_safety_checker: bool = True,
|
||||
text_encoder_2=T5EncoderModel,
|
||||
tokenizer_2=T5Tokenizer,
|
||||
tokenizer_2=MT5Tokenizer,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
@@ -1,844 +0,0 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from transformers import AutoTokenizer, PreTrainedModel
|
||||
|
||||
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
|
||||
from diffusers.models.autoencoders import AutoencoderKL
|
||||
from diffusers.models.transformers import ZImageTransformer2DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.z_image.pipeline_output import ZImagePipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import logging, replace_example_docstring
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from pipeline_z_image_differential_img2img import ZImageDifferentialImg2ImgPipeline
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = ZImageDifferentialImg2ImgPipeline.from_pretrained("Z-a-o/Z-Image-Turbo", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> init_image = load_image(
|
||||
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/input.jpg?raw=true",
|
||||
>>> )
|
||||
|
||||
>>> mask = load_image(
|
||||
>>> "https://github.com/exx8/differential-diffusion/blob/main/assets/map.jpg?raw=true",
|
||||
>>> )
|
||||
|
||||
>>> prompt = "painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art"
|
||||
|
||||
>>> image = pipe(
|
||||
... prompt,
|
||||
... image=init_image,
|
||||
... mask_image=mask,
|
||||
... strength=0.75,
|
||||
... num_inference_steps=9,
|
||||
... guidance_scale=0.0,
|
||||
... generator=torch.Generator("cuda").manual_seed(41),
|
||||
... ).images[0]
|
||||
>>> image.save("image.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
|
||||
def calculate_shift(
|
||||
image_seq_len,
|
||||
base_seq_len: int = 256,
|
||||
max_seq_len: int = 4096,
|
||||
base_shift: float = 0.5,
|
||||
max_shift: float = 1.15,
|
||||
):
|
||||
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
||||
b = base_shift - m * base_seq_len
|
||||
mu = image_seq_len * m + b
|
||||
return mu
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
|
||||
def retrieve_timesteps(
|
||||
scheduler,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
timesteps: Optional[List[int]] = None,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
||||
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
||||
|
||||
Args:
|
||||
scheduler (`SchedulerMixin`):
|
||||
The scheduler to get timesteps from.
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
|
||||
must be `None`.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
|
||||
`num_inference_steps` and `sigmas` must be `None`.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
|
||||
`num_inference_steps` and `timesteps` must be `None`.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
|
||||
second element is the number of inference steps.
|
||||
"""
|
||||
if timesteps is not None and sigmas is not None:
|
||||
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
|
||||
if timesteps is not None:
|
||||
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accepts_timesteps:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" timestep schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
elif sigmas is not None:
|
||||
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
|
||||
if not accept_sigmas:
|
||||
raise ValueError(
|
||||
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
|
||||
f" sigmas schedules. Please check whether you are using the correct scheduler."
|
||||
)
|
||||
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
|
||||
timesteps = scheduler.timesteps
|
||||
return timesteps, num_inference_steps
|
||||
|
||||
|
||||
class ZImageDifferentialImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
|
||||
r"""
|
||||
The ZImage pipeline for image-to-image generation.
|
||||
|
||||
Args:
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKL`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
text_encoder ([`PreTrainedModel`]):
|
||||
A text encoder model to encode text prompts.
|
||||
tokenizer ([`AutoTokenizer`]):
|
||||
A tokenizer to tokenize text prompts.
|
||||
transformer ([`ZImageTransformer2DModel`]):
|
||||
A ZImage transformer model to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_optional_components = []
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
vae: AutoencoderKL,
|
||||
text_encoder: PreTrainedModel,
|
||||
tokenizer: AutoTokenizer,
|
||||
transformer: ZImageTransformer2DModel,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vae=vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
transformer=transformer,
|
||||
)
|
||||
self.vae_scale_factor = (
|
||||
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
|
||||
)
|
||||
latent_channels = self.vae.config.latent_channels if getattr(self, "vae", None) else 16
|
||||
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
|
||||
|
||||
self.mask_processor = VaeImageProcessor(
|
||||
vae_scale_factor=self.vae_scale_factor,
|
||||
vae_latent_channels=latent_channels,
|
||||
do_normalize=False,
|
||||
do_binarize=False,
|
||||
do_convert_grayscale=True,
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline.encode_prompt
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if negative_prompt is None:
|
||||
negative_prompt = ["" for _ in prompt]
|
||||
else:
|
||||
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
|
||||
assert len(prompt) == len(negative_prompt)
|
||||
negative_prompt_embeds = self._encode_prompt(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
prompt_embeds=negative_prompt_embeds,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
else:
|
||||
negative_prompt_embeds = []
|
||||
return prompt_embeds, negative_prompt_embeds
|
||||
|
||||
# Copied from diffusers.pipelines.z_image.pipeline_z_image.ZImagePipeline._encode_prompt
|
||||
def _encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
max_sequence_length: int = 512,
|
||||
) -> List[torch.FloatTensor]:
|
||||
device = device or self._execution_device
|
||||
|
||||
if prompt_embeds is not None:
|
||||
return prompt_embeds
|
||||
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
|
||||
for i, prompt_item in enumerate(prompt):
|
||||
messages = [
|
||||
{"role": "user", "content": prompt_item},
|
||||
]
|
||||
prompt_item = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
enable_thinking=True,
|
||||
)
|
||||
prompt[i] = prompt_item
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
text_input_ids = text_inputs.input_ids.to(device)
|
||||
prompt_masks = text_inputs.attention_mask.to(device).bool()
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
input_ids=text_input_ids,
|
||||
attention_mask=prompt_masks,
|
||||
output_hidden_states=True,
|
||||
).hidden_states[-2]
|
||||
|
||||
embeddings_list = []
|
||||
|
||||
for i in range(len(prompt_embeds)):
|
||||
embeddings_list.append(prompt_embeds[i][prompt_masks[i]])
|
||||
|
||||
return embeddings_list
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
|
||||
def get_timesteps(self, num_inference_steps, strength, device):
|
||||
# get the original timestep using init_timestep
|
||||
init_timestep = min(num_inference_steps * strength, num_inference_steps)
|
||||
|
||||
t_start = int(max(num_inference_steps - init_timestep, 0))
|
||||
timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(t_start * self.scheduler.order)
|
||||
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
@staticmethod
|
||||
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
|
||||
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
|
||||
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
|
||||
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
|
||||
|
||||
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
|
||||
|
||||
latent_image_ids = latent_image_ids.reshape(
|
||||
latent_image_id_height * latent_image_id_width, latent_image_id_channels
|
||||
)
|
||||
|
||||
return latent_image_ids.to(device=device, dtype=dtype)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
image,
|
||||
timestep,
|
||||
batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
latents=None,
|
||||
):
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, num_channels_latents, height, width)
|
||||
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
|
||||
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
# Encode the input image
|
||||
image = image.to(device=device, dtype=dtype)
|
||||
if image.shape[1] != num_channels_latents:
|
||||
if isinstance(generator, list):
|
||||
image_latents = [
|
||||
retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
|
||||
for i in range(image.shape[0])
|
||||
]
|
||||
image_latents = torch.cat(image_latents, dim=0)
|
||||
else:
|
||||
image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
|
||||
|
||||
# Apply scaling (inverse of decoding: decode does latents/scaling_factor + shift_factor)
|
||||
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
else:
|
||||
image_latents = image
|
||||
|
||||
# Handle batch size expansion
|
||||
if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
|
||||
additional_image_per_prompt = batch_size // image_latents.shape[0]
|
||||
image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0)
|
||||
elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0:
|
||||
raise ValueError(
|
||||
f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
|
||||
)
|
||||
|
||||
# Add noise using flow matching scale_noise
|
||||
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
|
||||
|
||||
return latents, noise, image_latents, latent_image_ids
|
||||
|
||||
def prepare_mask_latents(
|
||||
self,
|
||||
mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
dtype,
|
||||
device,
|
||||
generator,
|
||||
):
|
||||
height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
# resize the mask to latents shape as we concatenate the mask to the latents
|
||||
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
|
||||
# and half precision
|
||||
mask = torch.nn.functional.interpolate(mask, size=(height, width))
|
||||
mask = mask.to(device=device, dtype=dtype)
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
masked_image = masked_image.to(device=device, dtype=dtype)
|
||||
|
||||
if masked_image.shape[1] == 16:
|
||||
masked_image_latents = masked_image
|
||||
else:
|
||||
masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)
|
||||
|
||||
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
|
||||
|
||||
# duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
|
||||
if mask.shape[0] < batch_size:
|
||||
if not batch_size % mask.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
|
||||
f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
|
||||
" of masks that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
|
||||
if masked_image_latents.shape[0] < batch_size:
|
||||
if not batch_size % masked_image_latents.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"The passed images and the required batch size don't match. Images are supposed to be duplicated"
|
||||
f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
|
||||
" Make sure the number of images that you pass is divisible by the total requested batch size."
|
||||
)
|
||||
masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
|
||||
|
||||
# aligning device to prevent device errors when concating it with the latent model input
|
||||
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
|
||||
|
||||
return mask, masked_image_latents
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1
|
||||
|
||||
@property
|
||||
def joint_attention_kwargs(self):
|
||||
return self._joint_attention_kwargs
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
@property
|
||||
def interrupt(self):
|
||||
return self._interrupt
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
image: PipelineImageInput = None,
|
||||
mask_image: PipelineImageInput = None,
|
||||
strength: float = 0.6,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
sigmas: Optional[List[float]] = None,
|
||||
guidance_scale: float = 5.0,
|
||||
cfg_normalization: bool = False,
|
||||
cfg_truncation: float = 1.0,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
negative_prompt_embeds: Optional[List[torch.FloatTensor]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for image-to-image generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
|
||||
instead.
|
||||
image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
|
||||
numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
|
||||
list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
|
||||
a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
|
||||
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
|
||||
`Image`, numpy array or tensor representing an image batch to mask `image`. Black pixels in the mask
|
||||
are repainted while white pixels are preserved. If `mask_image` is a PIL image, it is converted to a
|
||||
single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
|
||||
color channel (L) instead of 3, so the expected shape for pytorch tensor would be `(B, 1, H, W)`, `(B,
|
||||
H, W)`, `(1, H, W)`, `(H, W)`. And for numpy array would be for `(B, H, W, 1)`, `(B, H, W)`, `(H, W,
|
||||
1)`, or `(H, W)`.
|
||||
strength (`float`, *optional*, defaults to 0.6):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
starting point and more noise is added the higher the `strength`. The number of denoising steps depends
|
||||
on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
|
||||
process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
|
||||
essentially ignores `image`.
|
||||
height (`int`, *optional*, defaults to 1024):
|
||||
The height in pixels of the generated image. If not provided, uses the input image height.
|
||||
width (`int`, *optional*, defaults to 1024):
|
||||
The width in pixels of the generated image. If not provided, uses the input image width.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 5.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
cfg_normalization (`bool`, *optional*, defaults to False):
|
||||
Whether to apply configuration normalization.
|
||||
cfg_truncation (`float`, *optional*, defaults to 1.0):
|
||||
The truncation value for configuration.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`List[torch.FloatTensor]`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.ZImagePipelineOutput`] instead of a plain
|
||||
tuple.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that calls at the end of each denoising steps during the inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
||||
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
||||
`._callback_tensor_inputs` attribute of your pipeline class.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
Maximum sequence length to use with the `prompt`.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.z_image.ZImagePipelineOutput`] or `tuple`: [`~pipelines.z_image.ZImagePipelineOutput`] if
|
||||
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
|
||||
generated images.
|
||||
"""
|
||||
# 1. Check inputs and validate strength
|
||||
if strength < 0 or strength > 1:
|
||||
raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
|
||||
|
||||
# 2. Preprocess image
|
||||
init_image = self.image_processor.preprocess(image)
|
||||
init_image = init_image.to(dtype=torch.float32)
|
||||
|
||||
# Get dimensions from the preprocessed image if not specified
|
||||
if height is None:
|
||||
height = init_image.shape[-2]
|
||||
if width is None:
|
||||
width = init_image.shape[-1]
|
||||
|
||||
vae_scale = self.vae_scale_factor * 2
|
||||
if height % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Height must be divisible by {vae_scale} (got {height}). "
|
||||
f"Please adjust the height to a multiple of {vae_scale}."
|
||||
)
|
||||
if width % vae_scale != 0:
|
||||
raise ValueError(
|
||||
f"Width must be divisible by {vae_scale} (got {width}). "
|
||||
f"Please adjust the width to a multiple of {vae_scale}."
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._joint_attention_kwargs = joint_attention_kwargs
|
||||
self._interrupt = False
|
||||
self._cfg_normalization = cfg_normalization
|
||||
self._cfg_truncation = cfg_truncation
|
||||
|
||||
# 3. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = len(prompt_embeds)
|
||||
|
||||
# If prompt_embeds is provided and prompt is None, skip encoding
|
||||
if prompt_embeds is not None and prompt is None:
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided without `prompt`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
else:
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
) = self.encode_prompt(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
device=device,
|
||||
max_sequence_length=max_sequence_length,
|
||||
)
|
||||
|
||||
# 4. Prepare latent variables
|
||||
num_channels_latents = self.transformer.in_channels
|
||||
|
||||
# Repeat prompt_embeds for num_images_per_prompt
|
||||
if num_images_per_prompt > 1:
|
||||
prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
if self.do_classifier_free_guidance and negative_prompt_embeds:
|
||||
negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
|
||||
|
||||
actual_batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
# Calculate latent dimensions for image_seq_len
|
||||
latent_height = 2 * (int(height) // (self.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(width) // (self.vae_scale_factor * 2))
|
||||
image_seq_len = (latent_height // 2) * (latent_width // 2)
|
||||
|
||||
# 5. Prepare timesteps
|
||||
mu = calculate_shift(
|
||||
image_seq_len,
|
||||
self.scheduler.config.get("base_image_seq_len", 256),
|
||||
self.scheduler.config.get("max_image_seq_len", 4096),
|
||||
self.scheduler.config.get("base_shift", 0.5),
|
||||
self.scheduler.config.get("max_shift", 1.15),
|
||||
)
|
||||
self.scheduler.sigma_min = 0.0
|
||||
scheduler_kwargs = {"mu": mu}
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler,
|
||||
num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
**scheduler_kwargs,
|
||||
)
|
||||
|
||||
# 6. Adjust timesteps based on strength
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
|
||||
if num_inference_steps < 1:
|
||||
raise ValueError(
|
||||
f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
|
||||
f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
|
||||
)
|
||||
latent_timestep = timesteps[:1].repeat(actual_batch_size)
|
||||
|
||||
# 7. Prepare latents from image
|
||||
latents, noise, original_image_latents, latent_image_ids = self.prepare_latents(
|
||||
init_image,
|
||||
latent_timestep,
|
||||
actual_batch_size,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
resize_mode = "default"
|
||||
crops_coords = None
|
||||
|
||||
# start diff diff preparation
|
||||
original_mask = self.mask_processor.preprocess(
|
||||
mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
|
||||
)
|
||||
|
||||
masked_image = init_image * original_mask
|
||||
original_mask, _ = self.prepare_mask_latents(
|
||||
original_mask,
|
||||
masked_image,
|
||||
batch_size,
|
||||
num_images_per_prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds[0].dtype,
|
||||
device,
|
||||
generator,
|
||||
)
|
||||
mask_thresholds = torch.arange(num_inference_steps, dtype=original_mask.dtype) / num_inference_steps
|
||||
mask_thresholds = mask_thresholds.reshape(-1, 1, 1, 1).to(device)
|
||||
masks = original_mask > mask_thresholds
|
||||
# end diff diff preparation
|
||||
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 8. Denoising loop
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0])
|
||||
timestep = (1000 - timestep) / 1000
|
||||
# Normalized time for time-aware config (0 at start, 1 at end)
|
||||
t_norm = timestep[0].item()
|
||||
|
||||
# Handle cfg truncation
|
||||
current_guidance_scale = self.guidance_scale
|
||||
if (
|
||||
self.do_classifier_free_guidance
|
||||
and self._cfg_truncation is not None
|
||||
and float(self._cfg_truncation) <= 1
|
||||
):
|
||||
if t_norm > self._cfg_truncation:
|
||||
current_guidance_scale = 0.0
|
||||
|
||||
# Run CFG only if configured AND scale is non-zero
|
||||
apply_cfg = self.do_classifier_free_guidance and current_guidance_scale > 0
|
||||
|
||||
if apply_cfg:
|
||||
latents_typed = latents.to(self.transformer.dtype)
|
||||
latent_model_input = latents_typed.repeat(2, 1, 1, 1)
|
||||
prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
|
||||
timestep_model_input = timestep.repeat(2)
|
||||
else:
|
||||
latent_model_input = latents.to(self.transformer.dtype)
|
||||
prompt_embeds_model_input = prompt_embeds
|
||||
timestep_model_input = timestep
|
||||
|
||||
latent_model_input = latent_model_input.unsqueeze(2)
|
||||
latent_model_input_list = list(latent_model_input.unbind(dim=0))
|
||||
|
||||
model_out_list = self.transformer(
|
||||
latent_model_input_list,
|
||||
timestep_model_input,
|
||||
prompt_embeds_model_input,
|
||||
)[0]
|
||||
|
||||
if apply_cfg:
|
||||
# Perform CFG
|
||||
pos_out = model_out_list[:actual_batch_size]
|
||||
neg_out = model_out_list[actual_batch_size:]
|
||||
|
||||
noise_pred = []
|
||||
for j in range(actual_batch_size):
|
||||
pos = pos_out[j].float()
|
||||
neg = neg_out[j].float()
|
||||
|
||||
pred = pos + current_guidance_scale * (pos - neg)
|
||||
|
||||
# Renormalization
|
||||
if self._cfg_normalization and float(self._cfg_normalization) > 0.0:
|
||||
ori_pos_norm = torch.linalg.vector_norm(pos)
|
||||
new_pos_norm = torch.linalg.vector_norm(pred)
|
||||
max_new_norm = ori_pos_norm * float(self._cfg_normalization)
|
||||
if new_pos_norm > max_new_norm:
|
||||
pred = pred * (max_new_norm / new_pos_norm)
|
||||
|
||||
noise_pred.append(pred)
|
||||
|
||||
noise_pred = torch.stack(noise_pred, dim=0)
|
||||
else:
|
||||
noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
|
||||
|
||||
noise_pred = noise_pred.squeeze(2)
|
||||
noise_pred = -noise_pred
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
|
||||
assert latents.dtype == torch.float32
|
||||
|
||||
# start diff diff
|
||||
image_latent = original_image_latents
|
||||
latents_dtype = latents.dtype
|
||||
if i < len(timesteps) - 1:
|
||||
noise_timestep = timesteps[i + 1]
|
||||
image_latent = self.scheduler.scale_noise(
|
||||
original_image_latents, torch.tensor([noise_timestep]), noise
|
||||
)
|
||||
|
||||
mask = masks[i].to(latents_dtype)
|
||||
latents = image_latent * mask + latents * (1 - mask)
|
||||
# end diff diff
|
||||
|
||||
if latents.dtype != latents_dtype:
|
||||
if torch.backends.mps.is_available():
|
||||
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
|
||||
latents = latents.to(latents_dtype)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
else:
|
||||
latents = latents.to(self.vae.dtype)
|
||||
latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
|
||||
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ZImagePipelineOutput(images=image)
|
||||
@@ -98,9 +98,6 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
|
||||
This way, the text encoder model is not loaded into memory during training.
|
||||
> [!NOTE]
|
||||
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
|
||||
### FSDP Text Encoder
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
|
||||
This way, it distributes the memory cost across multiple nodes.
|
||||
### CPU Offloading
|
||||
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
|
||||
### Latent Caching
|
||||
@@ -169,26 +166,6 @@ To better track our training experiments, we're using the following flags in the
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
### FSDP on the transformer
|
||||
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
|
||||
|
||||
```shell
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_sharding_strategy: HYBRID_SHARD
|
||||
fsdp_auto_wrap_policy: TRANSFOMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
|
||||
fsdp_forward_prefetch: true
|
||||
fsdp_sync_module_states: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_activation_checkpointing: true
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
```
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
@@ -111,25 +111,6 @@ To better track our training experiments, we're using the following flags in the
|
||||
|
||||
## Notes
|
||||
|
||||
### LoRA Rank and Alpha
|
||||
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
|
||||
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
|
||||
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
|
||||
- lora_alpha vs. rank:
|
||||
This ratio dictates the LoRA's effective strength:
|
||||
lora_alpha == rank: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
|
||||
lora_alpha < rank: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
|
||||
lora_alpha > rank: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
|
||||
|
||||
> [!TIP]
|
||||
> A common starting point is to set `lora_alpha` equal to `rank`.
|
||||
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
|
||||
> to give the LoRA updates more influence without increasing parameter count.
|
||||
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
|
||||
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
|
||||
|
||||
### Additional CLI arguments
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers in a comma separated. E.g. - "to_k,to_q,to_v" will result in lora training of attention layers only.
|
||||
|
||||
@@ -44,7 +44,6 @@ import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -76,16 +75,13 @@ from diffusers import (
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_to_cpu_contiguous,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -97,9 +93,6 @@ from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@@ -729,7 +722,6 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -1227,11 +1219,7 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
|
||||
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1275,42 +1263,17 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
transformer_cls = type(unwrap_model(transformer))
|
||||
|
||||
# 1) Validate and pick the transformer model
|
||||
modules_to_save: dict[str, Any] = {}
|
||||
transformer_model = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(unwrap_model(model), transformer_cls):
|
||||
transformer_model = model
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
if transformer_model is None:
|
||||
raise ValueError("No transformer model found in 'models'")
|
||||
|
||||
# 2) Optionally gather FSDP state dict once
|
||||
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
|
||||
|
||||
# 3) Only main process materializes the LoRA state dict
|
||||
transformer_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
peft_kwargs = {}
|
||||
if is_fsdp:
|
||||
peft_kwargs["state_dict"] = state_dict
|
||||
transformer_lora_layers_to_save = None
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(
|
||||
unwrap_model(transformer_model) if is_fsdp else transformer_model,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
if is_fsdp:
|
||||
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
if weights:
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
@@ -1322,20 +1285,13 @@ def main(args):
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
|
||||
if not is_fsdp:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
else:
|
||||
transformer_ = Flux2Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1551,21 +1507,6 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1595,8 +1536,6 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1838,7 +1777,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process or is_fsdp:
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
@@ -1897,41 +1836,15 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
@@ -43,7 +43,6 @@ import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -75,16 +74,13 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_to_cpu_contiguous,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -97,9 +93,6 @@ from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@@ -346,7 +339,7 @@ def parse_args(input_args=None):
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
required=True,
|
||||
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -698,7 +691,6 @@ def parse_args(input_args=None):
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -835,28 +827,15 @@ class DreamBoothDataset(Dataset):
|
||||
dest_image = self.cond_images[i]
|
||||
image_width, image_height = dest_image.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
|
||||
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
|
||||
image_width, image_height = dest_image.size
|
||||
|
||||
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
image_processor = Flux2ImageProcessor()
|
||||
dest_image = image_processor.preprocess(
|
||||
dest_image = Flux2ImageProcessor.image_processor.preprocess(
|
||||
dest_image, height=image_height, width=image_width, resize_mode="crop"
|
||||
)
|
||||
# Convert back to PIL
|
||||
dest_image = dest_image.squeeze(0)
|
||||
if dest_image.min() < 0:
|
||||
dest_image = (dest_image + 1) / 2
|
||||
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
|
||||
|
||||
if dest_image.shape[0] == 1:
|
||||
# Gray scale image
|
||||
dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
|
||||
else:
|
||||
# RGB scale image: (C, H, W) -> (H, W, C)
|
||||
dest_image = TF.to_pil_image(dest_image)
|
||||
|
||||
dest_image = exif_transpose(dest_image)
|
||||
if not dest_image.mode == "RGB":
|
||||
@@ -1177,11 +1156,7 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
|
||||
is_fsdp = getattr(accelerator.state, "fsdp_plugin", None) is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1225,42 +1200,17 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
transformer_cls = type(unwrap_model(transformer))
|
||||
|
||||
# 1) Validate and pick the transformer model
|
||||
modules_to_save: dict[str, Any] = {}
|
||||
transformer_model = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(unwrap_model(model), transformer_cls):
|
||||
transformer_model = model
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
if transformer_model is None:
|
||||
raise ValueError("No transformer model found in 'models'")
|
||||
|
||||
# 2) Optionally gather FSDP state dict once
|
||||
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
|
||||
|
||||
# 3) Only main process materializes the LoRA state dict
|
||||
transformer_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
peft_kwargs = {}
|
||||
if is_fsdp:
|
||||
peft_kwargs["state_dict"] = state_dict
|
||||
transformer_lora_layers_to_save = None
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(
|
||||
unwrap_model(transformer_model) if is_fsdp else transformer_model,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
if is_fsdp:
|
||||
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
if weights:
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
@@ -1272,20 +1222,13 @@ def main(args):
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
|
||||
if not is_fsdp:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
else:
|
||||
transformer_ = Flux2Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1476,9 +1419,9 @@ def main(args):
|
||||
args.instance_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
validation_image = load_image(args.validation_image_path).convert("RGB")
|
||||
validation_kwargs = {"image": validation_image}
|
||||
if args.validation_prompt is not None:
|
||||
validation_image = load_image(args.validation_image_path).convert("RGB")
|
||||
validation_kwargs = {"image": validation_image}
|
||||
if args.remote_text_encoder:
|
||||
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
|
||||
else:
|
||||
@@ -1487,21 +1430,6 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1533,8 +1461,6 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1695,13 +1621,9 @@ def main(args):
|
||||
cond_model_input = (cond_model_input - latents_bn_mean) / latents_bn_std
|
||||
|
||||
model_input_ids = Flux2Pipeline._prepare_latent_ids(model_input).to(device=model_input.device)
|
||||
cond_model_input_list = [cond_model_input[i].unsqueeze(0) for i in range(cond_model_input.shape[0])]
|
||||
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input_list).to(
|
||||
cond_model_input_ids = Flux2Pipeline._prepare_image_ids(cond_model_input).to(
|
||||
device=cond_model_input.device
|
||||
)
|
||||
cond_model_input_ids = cond_model_input_ids.view(
|
||||
cond_model_input.shape[0], -1, model_input_ids.shape[-1]
|
||||
)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
@@ -1728,9 +1650,6 @@ def main(args):
|
||||
packed_noisy_model_input = Flux2Pipeline._pack_latents(noisy_model_input)
|
||||
packed_cond_model_input = Flux2Pipeline._pack_latents(cond_model_input)
|
||||
|
||||
orig_input_shape = packed_noisy_model_input.shape
|
||||
orig_input_ids_shape = model_input_ids.shape
|
||||
|
||||
# concatenate the model inputs with the cond inputs
|
||||
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_model_input], dim=1)
|
||||
model_input_ids = torch.cat([model_input_ids, cond_model_input_ids], dim=1)
|
||||
@@ -1749,8 +1668,7 @@ def main(args):
|
||||
img_ids=model_input_ids, # B, image_seq_len, 4
|
||||
return_dict=False,
|
||||
)[0]
|
||||
model_pred = model_pred[:, : orig_input_shape[1], :]
|
||||
model_input_ids = model_input_ids[:, : orig_input_ids_shape[1], :]
|
||||
model_pred = model_pred[:, : packed_noisy_model_input.size(1) :]
|
||||
|
||||
model_pred = Flux2Pipeline._unpack_latents_with_ids(model_pred, model_input_ids)
|
||||
|
||||
@@ -1782,7 +1700,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process or is_fsdp:
|
||||
if accelerator.is_main_process:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
@@ -1841,41 +1759,15 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
@@ -1513,12 +1513,14 @@ def main(args):
|
||||
height=model_input.shape[3],
|
||||
width=model_input.shape[4],
|
||||
)
|
||||
print(f"{prompt_embeds_mask.sum(dim=1).tolist()=}")
|
||||
model_pred = transformer(
|
||||
hidden_states=packed_noisy_model_input,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_mask=prompt_embeds_mask,
|
||||
timestep=timesteps / 1000,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=prompt_embeds_mask.sum(dim=1).tolist(),
|
||||
return_dict=False,
|
||||
)[0]
|
||||
model_pred = QwenImagePipeline._unpack_latents(
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
# Latent Perceptual Loss (LPL) for Stable Diffusion XL
|
||||
|
||||
This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada.
|
||||
|
||||
## Overview
|
||||
|
||||
LPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to:
|
||||
|
||||
- Loss of fine details in generated images
|
||||
- Inconsistent image quality
|
||||
- Structural artifacts
|
||||
- Reduced sharpness and realism
|
||||
|
||||
LPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to:
|
||||
|
||||
- Improved image quality and consistency (6-20% FID improvement)
|
||||
- Better preservation of fine details
|
||||
- More stable training, especially at high noise levels
|
||||
- Better handling of structural information
|
||||
- Sharper and more realistic textures
|
||||
|
||||
## Implementation Details
|
||||
|
||||
The LPL implementation follows the paper's methodology and includes several key features:
|
||||
|
||||
1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:
|
||||
- Middle block features
|
||||
- Up block features (configurable number of blocks)
|
||||
- Proper gradient checkpointing for memory efficiency
|
||||
- Features are extracted only for timesteps below the threshold (high SNR)
|
||||
|
||||
2. **Feature Normalization**: Multiple normalization options as validated in the paper:
|
||||
- `default`: Normalize each feature map independently
|
||||
- `shared`: Cross-normalize features using target statistics (recommended)
|
||||
- `batch`: Batch-wise normalization
|
||||
|
||||
3. **Outlier Handling**: Optional removal of outliers in feature maps using:
|
||||
- Quantile-based filtering (2% quantiles)
|
||||
- Morphological operations (opening/closing)
|
||||
- Adaptive thresholding based on standard deviation
|
||||
|
||||
4. **Loss Types**:
|
||||
- MSE loss (default)
|
||||
- L1 loss
|
||||
- Optional power law weighting (2^(-i) for layer i)
|
||||
|
||||
## Usage
|
||||
|
||||
To use LPL in your training, add the following arguments to your training command:
|
||||
|
||||
```bash
|
||||
python examples/research_projects/lpl/train_sdxl_lpl.py \
|
||||
--use_lpl \
|
||||
--lpl_weight 1.0 \ # Weight for LPL loss (1.0-2.0 recommended)
|
||||
--lpl_t_threshold 200 \ # Apply LPL only for timesteps < threshold (high SNR)
|
||||
--lpl_loss_type mse \ # Loss type: "mse" or "l1"
|
||||
--lpl_norm_type shared \ # Normalization type: "default", "shared" (recommended), or "batch"
|
||||
--lpl_pow_law \ # Use power law weighting for layers
|
||||
--lpl_num_blocks 4 \ # Number of up blocks to use (1-4)
|
||||
--lpl_remove_outliers \ # Remove outliers in feature maps
|
||||
--lpl_scale \ # Scale LPL loss by noise level weights
|
||||
--lpl_start 0 \ # Step to start applying LPL
|
||||
# ... other training arguments ...
|
||||
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.
|
||||
- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) focus on more important timesteps.
|
||||
- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases.
|
||||
- `lpl_norm_type`: Feature normalization strategy. "shared" is recommended as it showed best results in the paper.
|
||||
- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance.
|
||||
- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.
|
||||
- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.
|
||||
- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps.
|
||||
- `lpl_start`: Training step to start applying LPL. Can be used to warm up training.
|
||||
|
||||
## Recommendations
|
||||
|
||||
1. **Starting Point** (based on paper results):
|
||||
```bash
|
||||
--use_lpl \
|
||||
--lpl_weight 1.0 \
|
||||
--lpl_t_threshold 200 \
|
||||
--lpl_loss_type mse \
|
||||
--lpl_norm_type shared \
|
||||
--lpl_pow_law \
|
||||
--lpl_num_blocks 4 \
|
||||
--lpl_remove_outliers \
|
||||
--lpl_scale
|
||||
```
|
||||
|
||||
2. **Memory Efficiency**:
|
||||
- Use `--gradient_checkpointing` for memory efficiency (enabled by default)
|
||||
- Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)
|
||||
- Consider using `--lpl_scale` to focus on more important timesteps
|
||||
- Features are extracted only for timesteps below threshold to save memory
|
||||
|
||||
3. **Quality vs Speed**:
|
||||
- Higher `lpl_weight` (1.0-2.0) for better quality
|
||||
- Lower `lpl_t_threshold` (100-200) for faster training
|
||||
- Use `lpl_remove_outliers` for more stable training
|
||||
- `lpl_norm_type shared` provides best quality/speed trade-off
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Feature Extraction
|
||||
|
||||
The LPL implementation extracts features from the VAE decoder in the following order:
|
||||
1. Middle block output
|
||||
2. Up block outputs (configurable number of blocks)
|
||||
|
||||
Each feature map is processed with:
|
||||
1. Optional outlier removal (2% quantiles, morphological operations)
|
||||
2. Feature normalization (shared statistics recommended)
|
||||
3. Loss calculation (MSE or L1)
|
||||
4. Optional power law weighting (2^(-i) for layer i)
|
||||
|
||||
### Loss Calculation
|
||||
|
||||
For each feature map:
|
||||
1. Features are normalized according to the chosen strategy
|
||||
2. Loss is calculated between normalized features
|
||||
3. Outliers are masked out (if enabled)
|
||||
4. Loss is weighted by layer depth (if power law enabled)
|
||||
5. Final loss is averaged across all layers
|
||||
|
||||
### Memory Considerations
|
||||
|
||||
- Gradient checkpointing is used by default
|
||||
- Features are extracted only for timesteps below the threshold
|
||||
- Outlier removal is done in-place to save memory
|
||||
- Feature normalization is done efficiently using vectorized operations
|
||||
- Memory usage scales linearly with number of blocks used
|
||||
|
||||
## Results
|
||||
|
||||
Based on the paper's findings, LPL provides:
|
||||
- 6-20% improvement in FID scores
|
||||
- Better preservation of fine details
|
||||
- More realistic textures and structures
|
||||
- Improved consistency across different resolutions
|
||||
- Better performance on both small and large datasets
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this implementation in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{berrada2025boosting,
|
||||
title={Boosting Latent Diffusion with Perceptual Objectives},
|
||||
author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},
|
||||
booktitle={The Thirteenth International Conference on Learning Representations},
|
||||
year={2025},
|
||||
url={https://openreview.net/forum?id=y4DtzADzd1}
|
||||
}
|
||||
```
|
||||
@@ -1,215 +0,0 @@
|
||||
# Copyright 2025 Berrada et al.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def normalize_tensor(in_feat, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
|
||||
return in_feat / (norm_factor + eps)
|
||||
|
||||
|
||||
def cross_normalize(input, target, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
|
||||
return input / (norm_factor + eps), target / (norm_factor + eps)
|
||||
|
||||
|
||||
def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
|
||||
opening = int(np.ceil(opening / down_f))
|
||||
closing = int(np.ceil(closing / down_f))
|
||||
if opening == 2:
|
||||
opening = 3
|
||||
if closing == 2:
|
||||
closing = 1
|
||||
|
||||
# replace quantile with kth value here.
|
||||
feat_flat = feat.flatten(-2, -1)
|
||||
k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))
|
||||
q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
|
||||
q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]
|
||||
|
||||
m = 2 * feat_flat.std(-1)[..., None, None].detach()
|
||||
mask = (q1 - m < feat) * (feat < q2 + m)
|
||||
|
||||
# dilate the mask.
|
||||
mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float()) # closing
|
||||
mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool() # opening
|
||||
feat = feat * mask
|
||||
return mask, feat
|
||||
|
||||
|
||||
class LatentPerceptualLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
loss_type="mse",
|
||||
grad_ckpt=True,
|
||||
pow_law=False,
|
||||
norm_type="default",
|
||||
num_mid_blocks=4,
|
||||
feature_type="feature",
|
||||
remove_outliers=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
self.decoder = self.vae.decoder
|
||||
# Store scaling factors as tensors on the correct device
|
||||
device = next(self.vae.parameters()).device
|
||||
|
||||
# Get scaling factors with proper defaults and handle None values
|
||||
scale_factor = getattr(self.vae.config, "scaling_factor", None)
|
||||
shift_factor = getattr(self.vae.config, "shift_factor", None)
|
||||
|
||||
# Convert to tensors with proper defaults
|
||||
self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
|
||||
self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)
|
||||
|
||||
self.gradient_checkpointing = grad_ckpt
|
||||
self.pow_law = pow_law
|
||||
self.norm_type = norm_type.lower()
|
||||
self.outlier_mask = remove_outliers
|
||||
self.last_feature_stats = [] # Store feature statistics for logging
|
||||
|
||||
assert feature_type in ["feature", "image"]
|
||||
self.feature_type = feature_type
|
||||
|
||||
assert self.norm_type in ["default", "shared", "batch"]
|
||||
assert num_mid_blocks >= 0 and num_mid_blocks <= 4
|
||||
self.n_blocks = num_mid_blocks
|
||||
|
||||
assert loss_type in ["mse", "l1"]
|
||||
if loss_type == "mse":
|
||||
self.loss_fn = nn.MSELoss(reduction="none")
|
||||
elif loss_type == "l1":
|
||||
self.loss_fn = nn.L1Loss(reduction="none")
|
||||
|
||||
def get_features(self, z, latent_embeds=None, disable_grads=False):
|
||||
with torch.set_grad_enabled(not disable_grads):
|
||||
if self.gradient_checkpointing and not disable_grads:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
features = []
|
||||
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
|
||||
sample = z
|
||||
sample = self.decoder.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.decoder.mid_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
features.append(sample)
|
||||
|
||||
# up
|
||||
for up_block in self.decoder.up_blocks[: self.n_blocks]:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
features.append(sample)
|
||||
return features
|
||||
else:
|
||||
features = []
|
||||
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
|
||||
sample = z
|
||||
sample = self.decoder.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = self.decoder.mid_block(sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
features.append(sample)
|
||||
|
||||
# up
|
||||
for up_block in self.decoder.up_blocks[: self.n_blocks]:
|
||||
sample = up_block(sample, latent_embeds)
|
||||
features.append(sample)
|
||||
return features
|
||||
|
||||
def get_loss(self, input, target, get_hist=False):
|
||||
if self.feature_type == "feature":
|
||||
inp_f = self.get_features(self.shift + input / self.scale)
|
||||
tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
|
||||
losses = []
|
||||
self.last_feature_stats = [] # Reset feature stats
|
||||
|
||||
for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
|
||||
my = torch.ones_like(y).bool()
|
||||
outlier_ratio = 0.0
|
||||
|
||||
if self.outlier_mask:
|
||||
with torch.no_grad():
|
||||
if i == 2:
|
||||
my, y = remove_outliers(y, down_f=2)
|
||||
outlier_ratio = 1.0 - my.float().mean().item()
|
||||
elif i in [3, 4, 5]:
|
||||
my, y = remove_outliers(y, down_f=1)
|
||||
outlier_ratio = 1.0 - my.float().mean().item()
|
||||
|
||||
# Store feature statistics before normalization
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": y.mean().item(),
|
||||
"std": y.std().item(),
|
||||
"outlier_ratio": outlier_ratio,
|
||||
}
|
||||
self.last_feature_stats.append(stats)
|
||||
|
||||
# normalize feature tensors
|
||||
if self.norm_type == "default":
|
||||
x = normalize_tensor(x)
|
||||
y = normalize_tensor(y)
|
||||
elif self.norm_type == "shared":
|
||||
x, y = cross_normalize(x, y, eps=1e-6)
|
||||
|
||||
term_loss = self.loss_fn(x, y) * my
|
||||
# reduce loss term
|
||||
loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
|
||||
term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
|
||||
losses.append(term_loss.mean((1,)))
|
||||
|
||||
if get_hist:
|
||||
return losses
|
||||
else:
|
||||
loss = sum(losses)
|
||||
return loss / len(inp_f)
|
||||
elif self.feature_type == "image":
|
||||
inp_f = self.vae.decode(input / self.scale).sample
|
||||
tar_f = self.vae.decode(target / self.scale).sample
|
||||
return F.mse_loss(inp_f, tar_f)
|
||||
|
||||
def get_first_conv(self, z):
|
||||
sample = self.decoder.conv_in(z)
|
||||
return sample
|
||||
|
||||
def get_first_block(self, z):
|
||||
sample = self.decoder.conv_in(z)
|
||||
sample = self.decoder.mid_block(sample)
|
||||
for resnet in self.decoder.up_blocks[0].resnets:
|
||||
sample = resnet(sample, None)
|
||||
return sample
|
||||
|
||||
def get_first_layer(self, input, target, target_layer="conv"):
|
||||
if target_layer == "conv":
|
||||
feat_in = self.get_first_conv(input)
|
||||
with torch.no_grad():
|
||||
feat_tar = self.get_first_conv(target)
|
||||
else:
|
||||
feat_in = self.get_first_block(input)
|
||||
with torch.no_grad():
|
||||
feat_tar = self.get_first_block(target)
|
||||
|
||||
feat_in, feat_tar = cross_normalize(feat_in, feat_tar)
|
||||
|
||||
return F.mse_loss(feat_in, feat_tar, reduction="mean")
|
||||
File diff suppressed because it is too large
Load Diff
@@ -4,7 +4,7 @@ The `train_text_to_image.py` script shows how to fine-tune stable diffusion mode
|
||||
|
||||
___Note___:
|
||||
|
||||
___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparameters to get the best result on your dataset.___
|
||||
___This script is experimental. The script fine-tunes the whole model and often times the model overfits and runs into issues like catastrophic forgetting. It's recommended to try different hyperparamters to get the best result on your dataset.___
|
||||
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
@@ -18,7 +18,7 @@ cc.initialize_cache("/tmp/sdxl_cache")
|
||||
NUM_DEVICES = jax.device_count()
|
||||
|
||||
# 1. Let's start by downloading the model and loading it into our pipeline class
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned separately and
|
||||
# Adhering to JAX's functional approach, the model's parameters are returned seperatetely and
|
||||
# will have to be passed to the pipeline during inference
|
||||
pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", revision="refs/pr/95", split_head_dim=True
|
||||
|
||||
@@ -7,12 +7,16 @@ import torch
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
|
||||
from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def safe_tokenize(tokenizer, *args, lock, **kwargs):
|
||||
with lock:
|
||||
return tokenizer(*args, **kwargs)
|
||||
|
||||
|
||||
class RequestScopedPipeline:
|
||||
DEFAULT_MUTABLE_ATTRS = [
|
||||
"_all_hooks",
|
||||
@@ -34,40 +38,23 @@ class RequestScopedPipeline:
|
||||
wrap_scheduler: bool = True,
|
||||
):
|
||||
self._base = pipeline
|
||||
|
||||
self.unet = getattr(pipeline, "unet", None)
|
||||
self.vae = getattr(pipeline, "vae", None)
|
||||
self.text_encoder = getattr(pipeline, "text_encoder", None)
|
||||
self.components = getattr(pipeline, "components", None)
|
||||
|
||||
self.transformer = getattr(pipeline, "transformer", None)
|
||||
|
||||
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
|
||||
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
|
||||
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
|
||||
|
||||
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
|
||||
|
||||
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
|
||||
|
||||
self._vae_lock = threading.Lock()
|
||||
self._image_lock = threading.Lock()
|
||||
|
||||
self._auto_detect_mutables = bool(auto_detect_mutables)
|
||||
self._tensor_numel_threshold = int(tensor_numel_threshold)
|
||||
|
||||
self._auto_detected_attrs: List[str] = []
|
||||
|
||||
def _detect_kernel_pipeline(self, pipeline) -> bool:
|
||||
kernel_indicators = [
|
||||
"text_encoding_cache",
|
||||
"memory_manager",
|
||||
"enable_optimizations",
|
||||
"_create_request_context",
|
||||
"get_optimization_stats",
|
||||
]
|
||||
|
||||
return any(hasattr(pipeline, attr) for attr in kernel_indicators)
|
||||
|
||||
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
|
||||
base_sched = getattr(self._base, "scheduler", None)
|
||||
if base_sched is None:
|
||||
@@ -83,21 +70,11 @@ class RequestScopedPipeline:
|
||||
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
|
||||
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
|
||||
try:
|
||||
if hasattr(wrapped_scheduler, "scheduler"):
|
||||
try:
|
||||
copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
|
||||
return BaseAsyncScheduler(copied_scheduler)
|
||||
except Exception:
|
||||
return wrapped_scheduler
|
||||
else:
|
||||
copied_scheduler = copy.copy(wrapped_scheduler)
|
||||
return BaseAsyncScheduler(copied_scheduler)
|
||||
except Exception as e2:
|
||||
logger.warning(
|
||||
f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
|
||||
)
|
||||
return copy.deepcopy(wrapped_scheduler)
|
||||
except Exception as e:
|
||||
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
|
||||
return wrapped_scheduler
|
||||
|
||||
def _autodetect_mutables(self, max_attrs: int = 40):
|
||||
@@ -109,7 +86,6 @@ class RequestScopedPipeline:
|
||||
|
||||
candidates: List[str] = []
|
||||
seen = set()
|
||||
|
||||
for name in dir(self._base):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
@@ -117,7 +93,6 @@ class RequestScopedPipeline:
|
||||
continue
|
||||
if name in ("to", "save_pretrained", "from_pretrained"):
|
||||
continue
|
||||
|
||||
try:
|
||||
val = getattr(self._base, name)
|
||||
except Exception:
|
||||
@@ -125,9 +100,11 @@ class RequestScopedPipeline:
|
||||
|
||||
import types
|
||||
|
||||
# skip callables and modules
|
||||
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
|
||||
continue
|
||||
|
||||
# containers -> candidate
|
||||
if isinstance(val, (dict, list, set, tuple, bytearray)):
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
@@ -228,9 +205,6 @@ class RequestScopedPipeline:
|
||||
|
||||
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
|
||||
|
||||
def _should_wrap_tokenizers(self) -> bool:
|
||||
return True
|
||||
|
||||
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
|
||||
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
|
||||
|
||||
@@ -240,25 +214,6 @@ class RequestScopedPipeline:
|
||||
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
|
||||
local_pipe = copy.deepcopy(self._base)
|
||||
|
||||
try:
|
||||
if (
|
||||
hasattr(local_pipe, "vae")
|
||||
and local_pipe.vae is not None
|
||||
and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
|
||||
):
|
||||
local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
|
||||
|
||||
if (
|
||||
hasattr(local_pipe, "image_processor")
|
||||
and local_pipe.image_processor is not None
|
||||
and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
|
||||
):
|
||||
local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
|
||||
local_pipe.image_processor, self._image_lock
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not wrap vae/image_processor: {e}")
|
||||
|
||||
if local_scheduler is not None:
|
||||
try:
|
||||
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
|
||||
@@ -276,42 +231,47 @@ class RequestScopedPipeline:
|
||||
|
||||
self._clone_mutable_attrs(self._base, local_pipe)
|
||||
|
||||
original_tokenizers = {}
|
||||
# 4) wrap tokenizers on the local pipe with the lock wrapper
|
||||
tokenizer_wrappers = {} # name -> original_tokenizer
|
||||
try:
|
||||
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
tokenizer_wrappers[name] = tok
|
||||
setattr(
|
||||
local_pipe,
|
||||
name,
|
||||
lambda *args, tok=tok, **kwargs: safe_tokenize(
|
||||
tok, *args, lock=self._tokenizer_lock, **kwargs
|
||||
),
|
||||
)
|
||||
|
||||
if self._should_wrap_tokenizers():
|
||||
try:
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
if not isinstance(tok, ThreadSafeTokenizerWrapper):
|
||||
original_tokenizers[name] = tok
|
||||
wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
|
||||
setattr(local_pipe, name, wrapped_tokenizer)
|
||||
# b) wrap tokenizers in components dict
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
if self._is_tokenizer_component(val):
|
||||
tokenizer_wrappers[f"components[{key}]"] = val
|
||||
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
|
||||
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
|
||||
)
|
||||
|
||||
if self._is_tokenizer_component(val):
|
||||
if not isinstance(val, ThreadSafeTokenizerWrapper):
|
||||
original_tokenizers[f"components[{key}]"] = val
|
||||
wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
|
||||
local_pipe.components[key] = wrapped_tokenizer
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
|
||||
result = None
|
||||
cm = getattr(local_pipe, "model_cpu_offload_context", None)
|
||||
|
||||
try:
|
||||
if callable(cm):
|
||||
try:
|
||||
with cm():
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except TypeError:
|
||||
# cm might be a context manager instance rather than callable
|
||||
try:
|
||||
with cm:
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
@@ -319,18 +279,18 @@ class RequestScopedPipeline:
|
||||
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
else:
|
||||
# no offload context available — call directly
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
try:
|
||||
for name, tok in original_tokenizers.items():
|
||||
for name, tok in tokenizer_wrappers.items():
|
||||
if name.startswith("components["):
|
||||
key = name[len("components[") : -1]
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
local_pipe.components[key] = tok
|
||||
local_pipe.components[key] = tok
|
||||
else:
|
||||
setattr(local_pipe, name, tok)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error restoring original tokenizers: {e}")
|
||||
logger.debug(f"Error restoring wrapped tokenizers: {e}")
|
||||
|
||||
@@ -1,86 +0,0 @@
|
||||
class ThreadSafeTokenizerWrapper:
|
||||
def __init__(self, tokenizer, lock):
|
||||
self._tokenizer = tokenizer
|
||||
self._lock = lock
|
||||
|
||||
self._thread_safe_methods = {
|
||||
"__call__",
|
||||
"encode",
|
||||
"decode",
|
||||
"tokenize",
|
||||
"encode_plus",
|
||||
"batch_encode_plus",
|
||||
"batch_decode",
|
||||
}
|
||||
|
||||
def __getattr__(self, name):
|
||||
attr = getattr(self._tokenizer, name)
|
||||
|
||||
if name in self._thread_safe_methods and callable(attr):
|
||||
|
||||
def wrapped_method(*args, **kwargs):
|
||||
with self._lock:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
return wrapped_method
|
||||
|
||||
return attr
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
with self._lock:
|
||||
return self._tokenizer(*args, **kwargs)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_"):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self._tokenizer, name, value)
|
||||
|
||||
def __dir__(self):
|
||||
return dir(self._tokenizer)
|
||||
|
||||
|
||||
class ThreadSafeVAEWrapper:
|
||||
def __init__(self, vae, lock):
|
||||
self._vae = vae
|
||||
self._lock = lock
|
||||
|
||||
def __getattr__(self, name):
|
||||
attr = getattr(self._vae, name)
|
||||
if name in {"decode", "encode", "forward"} and callable(attr):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
with self._lock:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
return attr
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_"):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self._vae, name, value)
|
||||
|
||||
|
||||
class ThreadSafeImageProcessorWrapper:
|
||||
def __init__(self, proc, lock):
|
||||
self._proc = proc
|
||||
self._lock = lock
|
||||
|
||||
def __getattr__(self, name):
|
||||
attr = getattr(self._proc, name)
|
||||
if name in {"postprocess", "preprocess"} and callable(attr):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
with self._lock:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
return attr
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_"):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self._proc, name, value)
|
||||
@@ -1,94 +1,11 @@
|
||||
"""
|
||||
# Cosmos 2 Predict
|
||||
|
||||
Download checkpoint
|
||||
```bash
|
||||
hf download nvidia/Cosmos-Predict2-2B-Text2Image
|
||||
```
|
||||
|
||||
convert checkpoint
|
||||
```bash
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2-2B-Text2Image/snapshots/acdb5fde992a73ef0355f287977d002cbfd127e0/model.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--transformer_type Cosmos-2.0-Diffusion-2B-Text2Image \
|
||||
--text_encoder_path google-t5/t5-11b \
|
||||
--tokenizer_path google-t5/t5-11b \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/cosmos-p2-t2i-2b \
|
||||
--save_pipeline
|
||||
```
|
||||
|
||||
# Cosmos 2.5 Predict
|
||||
|
||||
Download checkpoint
|
||||
```bash
|
||||
hf download nvidia/Cosmos-Predict2.5-2B
|
||||
```
|
||||
|
||||
Convert checkpoint
|
||||
```bash
|
||||
# pre-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
|
||||
--save_pipeline
|
||||
|
||||
# post-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
|
||||
--save_pipeline
|
||||
```
|
||||
|
||||
## 14B
|
||||
|
||||
```bash
|
||||
hf download nvidia/Cosmos-Predict2.5-14B
|
||||
```
|
||||
|
||||
```bash
|
||||
# pre-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-14B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
|
||||
--save_pipeline
|
||||
|
||||
# post-trained
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Predict-Base-14B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
|
||||
--save_pipeline
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration, T5EncoderModel, T5TokenizerFast
|
||||
from transformers import T5EncoderModel, T5TokenizerFast
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLCosmos,
|
||||
@@ -100,9 +17,7 @@ from diffusers import (
|
||||
CosmosVideoToWorldPipeline,
|
||||
EDMEulerScheduler,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
UniPCMultistepScheduler,
|
||||
)
|
||||
from diffusers.pipelines.cosmos.pipeline_cosmos2_5_predict import Cosmos2_5_PredictBasePipeline
|
||||
|
||||
|
||||
def remove_keys_(key: str, state_dict: Dict[str, Any]):
|
||||
@@ -318,44 +233,6 @@ TRANSFORMER_CONFIGS = {
|
||||
"concat_padding_mask": True,
|
||||
"extra_pos_embed_type": None,
|
||||
},
|
||||
"Cosmos-2.5-Predict-Base-2B": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 16,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 28,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 3.0, 3.0),
|
||||
"concat_padding_mask": True,
|
||||
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
|
||||
"extra_pos_embed_type": None,
|
||||
"use_crossattn_projection": True,
|
||||
"crossattn_proj_in_channels": 100352,
|
||||
"encoder_hidden_states_channels": 1024,
|
||||
},
|
||||
"Cosmos-2.5-Predict-Base-14B": {
|
||||
"in_channels": 16 + 1,
|
||||
"out_channels": 16,
|
||||
"num_attention_heads": 40,
|
||||
"attention_head_dim": 128,
|
||||
"num_layers": 36,
|
||||
"mlp_ratio": 4.0,
|
||||
"text_embed_dim": 1024,
|
||||
"adaln_lora_dim": 256,
|
||||
"max_size": (128, 240, 240),
|
||||
"patch_size": (1, 2, 2),
|
||||
"rope_scale": (1.0, 3.0, 3.0),
|
||||
"concat_padding_mask": True,
|
||||
# NOTE: source config has pos_emb_learnable: 'True' - but params are missing
|
||||
"extra_pos_embed_type": None,
|
||||
"use_crossattn_projection": True,
|
||||
"crossattn_proj_in_channels": 100352,
|
||||
"encoder_hidden_states_channels": 1024,
|
||||
},
|
||||
}
|
||||
|
||||
VAE_KEYS_RENAME_DICT = {
|
||||
@@ -457,9 +334,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
elif "Cosmos-2.0" in transformer_type:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
elif "Cosmos-2.5" in transformer_type:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
else:
|
||||
assert False
|
||||
|
||||
@@ -473,7 +347,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
print(key, "->", new_key, flush=True)
|
||||
update_state_dict_(original_state_dict, key, new_key)
|
||||
|
||||
for key in list(original_state_dict.keys()):
|
||||
@@ -482,21 +355,6 @@ def convert_transformer(transformer_type: str, ckpt_path: str, weights_only: boo
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
expected_keys = set(transformer.state_dict().keys())
|
||||
mapped_keys = set(original_state_dict.keys())
|
||||
missing_keys = expected_keys - mapped_keys
|
||||
unexpected_keys = mapped_keys - expected_keys
|
||||
if missing_keys:
|
||||
print(f"ERROR: missing keys ({len(missing_keys)} from state_dict:", flush=True, file=sys.stderr)
|
||||
for k in missing_keys:
|
||||
print(k)
|
||||
sys.exit(1)
|
||||
if unexpected_keys:
|
||||
print(f"ERROR: unexpected keys ({len(unexpected_keys)}) from state_dict:", flush=True, file=sys.stderr)
|
||||
for k in unexpected_keys:
|
||||
print(k)
|
||||
sys.exit(2)
|
||||
|
||||
transformer.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
@@ -586,34 +444,6 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def save_pipeline_cosmos2_5(args, transformer, vae):
|
||||
text_encoder_path = args.text_encoder_path or "nvidia/Cosmos-Reason1-7B"
|
||||
tokenizer_path = args.tokenizer_path or "Qwen/Qwen2.5-VL-7B-Instruct"
|
||||
|
||||
text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
||||
text_encoder_path, torch_dtype="auto", device_map="cpu"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
|
||||
|
||||
scheduler = UniPCMultistepScheduler(
|
||||
use_karras_sigmas=True,
|
||||
use_flow_sigmas=True,
|
||||
prediction_type="flow_prediction",
|
||||
sigma_max=200.0,
|
||||
sigma_min=0.01,
|
||||
)
|
||||
|
||||
pipe = Cosmos2_5_PredictBasePipeline(
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
transformer=transformer,
|
||||
vae=vae,
|
||||
scheduler=scheduler,
|
||||
safety_checker=lambda *args, **kwargs: None,
|
||||
)
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--transformer_type", type=str, default=None, choices=list(TRANSFORMER_CONFIGS.keys()))
|
||||
@@ -621,10 +451,10 @@ def get_args():
|
||||
"--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--vae_type", type=str, default="wan2.1", choices=["wan2.1", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
||||
"--vae_type", type=str, default=None, choices=["none", *list(VAE_CONFIGS.keys())], help="Type of VAE"
|
||||
)
|
||||
parser.add_argument("--text_encoder_path", type=str, default=None)
|
||||
parser.add_argument("--tokenizer_path", type=str, default=None)
|
||||
parser.add_argument("--text_encoder_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--tokenizer_path", type=str, default="google-t5/t5-11b")
|
||||
parser.add_argument("--save_pipeline", action="store_true")
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
|
||||
@@ -647,6 +477,8 @@ if __name__ == "__main__":
|
||||
if args.save_pipeline:
|
||||
assert args.transformer_ckpt_path is not None
|
||||
assert args.vae_type is not None
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
|
||||
if args.transformer_ckpt_path is not None:
|
||||
weights_only = "Cosmos-1.0" in args.transformer_type
|
||||
@@ -658,26 +490,17 @@ if __name__ == "__main__":
|
||||
if args.vae_type is not None:
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
vae = convert_vae(args.vae_type)
|
||||
elif "Cosmos-2.0" in args.transformer_type or "Cosmos-2.5" in args.transformer_type:
|
||||
else:
|
||||
vae = AutoencoderKLWan.from_pretrained(
|
||||
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
|
||||
)
|
||||
else:
|
||||
raise AssertionError(f"{args.transformer_type} not supported")
|
||||
|
||||
if not args.save_pipeline:
|
||||
vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.save_pipeline:
|
||||
if "Cosmos-1.0" in args.transformer_type:
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
save_pipeline_cosmos_1_0(args, transformer, vae)
|
||||
elif "Cosmos-2.0" in args.transformer_type:
|
||||
assert args.text_encoder_path is not None
|
||||
assert args.tokenizer_path is not None
|
||||
save_pipeline_cosmos_2_0(args, transformer, vae)
|
||||
elif "Cosmos-2.5" in args.transformer_type:
|
||||
save_pipeline_cosmos2_5(args, transformer, vae)
|
||||
else:
|
||||
raise AssertionError(f"{args.transformer_type} not supported")
|
||||
assert False
|
||||
|
||||
@@ -1,886 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Input Patchify Projections
|
||||
"patchify_proj": "proj_in",
|
||||
"audio_patchify_proj": "audio_proj_in",
|
||||
# Modulation Parameters
|
||||
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
|
||||
# substrings of the other modulation parameters below
|
||||
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
|
||||
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
|
||||
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
|
||||
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
|
||||
# Transformer Blocks
|
||||
# Per-Block Cross Attention Modulatin Parameters
|
||||
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# Decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
# Common
|
||||
# For all 3D ResNets
|
||||
"res_blocks": "resnets",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"ups": "upsamplers",
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if key.startswith("adaln_single."):
|
||||
new_key = key.replace("adaln_single.", "time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
if key.startswith("audio_adaln_single."):
|
||||
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
if key.startswith("per_channel_statistics"):
|
||||
new_key = ".".join(["decoder", key])
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
"text_proj_in",
|
||||
)
|
||||
|
||||
transformer_state_dict, connector_state_dict = {}, {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith(connector_prefixes):
|
||||
connector_state_dict[key] = value
|
||||
else:
|
||||
transformer_state_dict[key] = value
|
||||
|
||||
return transformer_state_dict, connector_state_dict
|
||||
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 4,
|
||||
"audio_out_channels": 4,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 2,
|
||||
"audio_attention_head_dim": 4,
|
||||
"audio_cross_attention_dim": 8,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 2,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 16,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 16,
|
||||
"text_proj_in_factor": 3,
|
||||
"video_connector_num_attention_heads": 4,
|
||||
"video_connector_attention_head_dim": 8,
|
||||
"video_connector_num_layers": 1,
|
||||
"video_connector_num_learnable_registers": None,
|
||||
"audio_connector_num_attention_heads": 4,
|
||||
"audio_connector_attention_head_dim": 8,
|
||||
"audio_connector_num_layers": 1,
|
||||
"audio_connector_num_learnable_registers": None,
|
||||
"connector_rope_base_seq_len": 32,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(transformer_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, transformer_state_dict)
|
||||
|
||||
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
if len(connector_state_dict) == 0:
|
||||
raise ValueError("No connector weights found in the provided state dict.")
|
||||
|
||||
with init_empty_weights():
|
||||
connectors = LTX2TextConnectors.from_config(diffusers_config)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(connector_state_dict, key, new_key)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, connector_state_dict)
|
||||
|
||||
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vocoder
|
||||
|
||||
|
||||
def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
return config
|
||||
|
||||
|
||||
def convert_ltx2_spatial_latent_upsampler(
|
||||
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
|
||||
):
|
||||
with init_empty_weights():
|
||||
latent_upsampler = LTX2LatentUpsamplerModel(**config)
|
||||
|
||||
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
latent_upsampler.to(dtype)
|
||||
return latent_upsampler
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
if repo_id is None and filename is None:
|
||||
raise ValueError("Please supply at least one of `repo_id` or `filename`")
|
||||
|
||||
if repo_id is not None:
|
||||
if filename is None:
|
||||
raise ValueError("If repo_id is specified, filename must also be specified.")
|
||||
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
else:
|
||||
ckpt_path = filename
|
||||
|
||||
_, ext = os.path.splitext(ckpt_path)
|
||||
if ext in [".safetensors", ".sft"]:
|
||||
state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||
# Ensure that the key prefix ends with a dot (.)
|
||||
if not prefix.endswith("."):
|
||||
prefix = prefix + "."
|
||||
|
||||
model_state_dict = {}
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id",
|
||||
default="Lightricks/LTX-2",
|
||||
type=str,
|
||||
help="HF Hub repo id with LTX 2.0 checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--combined_filename",
|
||||
default="ltx-2-19b-dev.safetensors",
|
||||
type=str,
|
||||
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
|
||||
)
|
||||
parser.add_argument("--vae_prefix", default="vae.", type=str)
|
||||
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
|
||||
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
|
||||
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
|
||||
|
||||
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_model_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 base text encoder model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_upsampler_filename",
|
||||
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
|
||||
type=str,
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
|
||||
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
|
||||
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
|
||||
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
|
||||
parser.add_argument(
|
||||
"--full_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upsample_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
|
||||
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
|
||||
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
|
||||
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
|
||||
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
|
||||
|
||||
combined_ckpt = None
|
||||
load_combined_models = any(
|
||||
[
|
||||
args.vae,
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.text_encoder,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
)
|
||||
if args.combined_filename is not None and load_combined_models:
|
||||
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
|
||||
|
||||
if args.vae or args.full_pipeline or args.upsample_pipeline:
|
||||
if args.vae_filename is not None:
|
||||
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
|
||||
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
if args.audio_vae or args.full_pipeline:
|
||||
if args.audio_vae_filename is not None:
|
||||
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
|
||||
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
|
||||
|
||||
if args.dit or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
if args.connectors or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
|
||||
|
||||
if args.vocoder or args.full_pipeline:
|
||||
if args.vocoder_filename is not None:
|
||||
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
|
||||
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.text_encoder or args.full_pipeline:
|
||||
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
|
||||
original_latent_upsampler_ckpt,
|
||||
latent_upsampler_config,
|
||||
dtype=vae_dtype,
|
||||
)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
2
setup.py
2
setup.py
@@ -274,7 +274,7 @@ version_range_max = max(sys.version_info[1], 10) + 1
|
||||
|
||||
setup(
|
||||
name="diffusers",
|
||||
version="0.37.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
version="0.36.0.dev0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
|
||||
description="State-of-the-art diffusion in PyTorch and JAX.",
|
||||
long_description=open("README.md", "r", encoding="utf-8").read(),
|
||||
long_description_content_type="text/markdown",
|
||||
|
||||
@@ -23,7 +23,6 @@ from .utils import (
|
||||
is_torchao_available,
|
||||
is_torchsde_available,
|
||||
is_transformers_available,
|
||||
is_transformers_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -194,8 +193,6 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
@@ -226,7 +223,6 @@ else:
|
||||
"FluxControlNetModel",
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"GlmImageTransformer2DModel",
|
||||
"HiDreamImageTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -240,7 +236,6 @@ else:
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LongCatImageTransformer2DModel",
|
||||
"LTX2VideoTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"Lumina2Transformer2DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
@@ -284,7 +279,6 @@ else:
|
||||
"WanAnimateTransformer3DModel",
|
||||
"WanTransformer3DModel",
|
||||
"WanVACETransformer3DModel",
|
||||
"ZImageControlNetModel",
|
||||
"ZImageTransformer2DModel",
|
||||
"attention_backend",
|
||||
]
|
||||
@@ -358,7 +352,6 @@ else:
|
||||
"KDPM2AncestralDiscreteScheduler",
|
||||
"KDPM2DiscreteScheduler",
|
||||
"LCMScheduler",
|
||||
"LTXEulerAncestralRFScheduler",
|
||||
"PNDMScheduler",
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
@@ -423,8 +416,6 @@ else:
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
@@ -471,7 +462,6 @@ else:
|
||||
"CogView4ControlPipeline",
|
||||
"CogView4Pipeline",
|
||||
"ConsisIDPipeline",
|
||||
"Cosmos2_5_PredictBasePipeline",
|
||||
"Cosmos2TextToImagePipeline",
|
||||
"Cosmos2VideoToWorldPipeline",
|
||||
"CosmosTextToWorldPipeline",
|
||||
@@ -494,7 +484,6 @@ else:
|
||||
"FluxKontextPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"GlmImagePipeline",
|
||||
"HiDreamImagePipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -546,11 +535,7 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LTX2LatentUpsamplePipeline",
|
||||
"LTX2Pipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
@@ -685,10 +670,7 @@ else:
|
||||
"WuerstchenCombinedPipeline",
|
||||
"WuerstchenDecoderPipeline",
|
||||
"WuerstchenPriorPipeline",
|
||||
"ZImageControlNetInpaintPipeline",
|
||||
"ZImageControlNetPipeline",
|
||||
"ZImageImg2ImgPipeline",
|
||||
"ZImageOmniPipeline",
|
||||
"ZImagePipeline",
|
||||
]
|
||||
)
|
||||
@@ -950,8 +932,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -982,7 +962,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxControlNetModel,
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -996,7 +975,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
@@ -1039,7 +1017,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WanAnimateTransformer3DModel,
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
ZImageControlNetModel,
|
||||
ZImageTransformer2DModel,
|
||||
attention_backend,
|
||||
)
|
||||
@@ -1105,7 +1082,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
LCMScheduler,
|
||||
LTXEulerAncestralRFScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
@@ -1153,8 +1129,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
QwenImageLayeredModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
@@ -1197,7 +1171,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4ControlPipeline,
|
||||
CogView4Pipeline,
|
||||
ConsisIDPipeline,
|
||||
Cosmos2_5_PredictBasePipeline,
|
||||
Cosmos2TextToImagePipeline,
|
||||
Cosmos2VideoToWorldPipeline,
|
||||
CosmosTextToWorldPipeline,
|
||||
@@ -1220,7 +1193,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
GlmImagePipeline,
|
||||
HiDreamImagePipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
@@ -1272,11 +1244,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTXConditionPipeline,
|
||||
LTXI2VLongMultiPromptPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
@@ -1409,10 +1377,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WuerstchenCombinedPipeline,
|
||||
WuerstchenDecoderPipeline,
|
||||
WuerstchenPriorPipeline,
|
||||
ZImageControlNetInpaintPipeline,
|
||||
ZImageControlNetPipeline,
|
||||
ZImageImg2ImgPipeline,
|
||||
ZImageOmniPipeline,
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
|
||||
@@ -25,7 +25,6 @@ if is_torch_available():
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .frequency_decoupled_guidance import FrequencyDecoupledGuidance
|
||||
from .guider_utils import BaseGuidance
|
||||
from .magnitude_aware_guidance import MagnitudeAwareGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, GuiderOutput, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class MagnitudeAwareGuidance(BaseGuidance):
|
||||
"""
|
||||
Magnitude-Aware Mitigation for Boosted Guidance (MAMBO-G): https://huggingface.co/papers/2508.03442
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `10.0`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
alpha (`float`, defaults to `8.0`):
|
||||
The alpha parameter for the magnitude-aware guidance. Higher values cause more aggressive supression of
|
||||
guidance scale when the magnitude of the guidance update is large.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 10.0,
|
||||
alpha: float = 8.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
enabled: bool = True,
|
||||
):
|
||||
super().__init__(start, stop, enabled)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.alpha = alpha
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(self, data: Dict[str, Tuple[torch.Tensor, torch.Tensor]]) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch(data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def prepare_inputs_from_block_state(
|
||||
self, data: "BlockState", input_fields: Dict[str, Union[str, Tuple[str, str]]]
|
||||
) -> List["BlockState"]:
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for tuple_idx, input_prediction in zip(tuple_indices, self._input_predictions):
|
||||
data_batch = self._prepare_batch_from_block_state(input_fields, data, tuple_idx, input_prediction)
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> GuiderOutput:
|
||||
pred = None
|
||||
|
||||
if not self._is_mambo_g_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = mambo_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.alpha,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return GuiderOutput(pred=pred, pred_cond=pred_cond, pred_uncond=pred_uncond)
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_mambo_g_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_mambo_g_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
def mambo_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
alpha: float = 8.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
dim = list(range(1, len(pred_cond.shape)))
|
||||
diff = pred_cond - pred_uncond
|
||||
ratio = torch.norm(diff, dim=dim, keepdim=True) / torch.norm(pred_uncond, dim=dim, keepdim=True)
|
||||
guidance_scale_final = (
|
||||
guidance_scale * torch.exp(-alpha * ratio)
|
||||
if use_original_formulation
|
||||
else 1.0 + (guidance_scale - 1.0) * torch.exp(-alpha * ratio)
|
||||
)
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale_final * diff
|
||||
|
||||
return pred
|
||||
@@ -67,7 +67,6 @@ if is_torch_available():
|
||||
"SD3LoraLoaderMixin",
|
||||
"AuraFlowLoraLoaderMixin",
|
||||
"StableDiffusionXLLoraLoaderMixin",
|
||||
"LTX2LoraLoaderMixin",
|
||||
"LTXVideoLoraLoaderMixin",
|
||||
"LoraLoaderMixin",
|
||||
"FluxLoraLoaderMixin",
|
||||
@@ -122,7 +121,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
KandinskyLoraLoaderMixin,
|
||||
LoraLoaderMixin,
|
||||
LTX2LoraLoaderMixin,
|
||||
LTXVideoLoraLoaderMixin,
|
||||
Lumina2LoraLoaderMixin,
|
||||
Mochi1LoraLoaderMixin,
|
||||
|
||||
@@ -2140,54 +2140,6 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict, non_diffusers_prefix="diffusion_model"):
|
||||
# Remove the prefix
|
||||
state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{non_diffusers_prefix}.")}
|
||||
converted_state_dict = {k.removeprefix(f"{non_diffusers_prefix}."): v for k, v in state_dict.items()}
|
||||
|
||||
if non_diffusers_prefix == "diffusion_model":
|
||||
rename_dict = {
|
||||
"patchify_proj": "proj_in",
|
||||
"audio_patchify_proj": "audio_proj_in",
|
||||
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
|
||||
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
|
||||
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
|
||||
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
|
||||
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
else:
|
||||
rename_dict = {"aggregate_embed": "text_proj_in"}
|
||||
|
||||
# Apply renaming
|
||||
renamed_state_dict = {}
|
||||
for key, value in converted_state_dict.items():
|
||||
new_key = key[:]
|
||||
for old_pattern, new_pattern in rename_dict.items():
|
||||
new_key = new_key.replace(old_pattern, new_pattern)
|
||||
renamed_state_dict[new_key] = value
|
||||
|
||||
# Handle adaln_single -> time_embed and audio_adaln_single -> audio_time_embed
|
||||
final_state_dict = {}
|
||||
for key, value in renamed_state_dict.items():
|
||||
if key.startswith("adaln_single."):
|
||||
new_key = key.replace("adaln_single.", "time_embed.")
|
||||
final_state_dict[new_key] = value
|
||||
elif key.startswith("audio_adaln_single."):
|
||||
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
|
||||
final_state_dict[new_key] = value
|
||||
else:
|
||||
final_state_dict[key] = value
|
||||
|
||||
# Add transformer prefix
|
||||
prefix = "transformer" if non_diffusers_prefix == "diffusion_model" else "connectors"
|
||||
final_state_dict = {f"{prefix}.{k}": v for k, v in final_state_dict.items()}
|
||||
|
||||
return final_state_dict
|
||||
|
||||
|
||||
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
|
||||
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if has_diffusion_model:
|
||||
|
||||
@@ -48,7 +48,6 @@ from .lora_conversion_utils import (
|
||||
_convert_non_diffusers_flux2_lora_to_diffusers,
|
||||
_convert_non_diffusers_hidream_lora_to_diffusers,
|
||||
_convert_non_diffusers_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltx2_lora_to_diffusers,
|
||||
_convert_non_diffusers_ltxv_lora_to_diffusers,
|
||||
_convert_non_diffusers_lumina2_lora_to_diffusers,
|
||||
_convert_non_diffusers_qwen_lora_to_diffusers,
|
||||
@@ -75,7 +74,6 @@ logger = logging.get_logger(__name__)
|
||||
TEXT_ENCODER_NAME = "text_encoder"
|
||||
UNET_NAME = "unet"
|
||||
TRANSFORMER_NAME = "transformer"
|
||||
LTX2_CONNECTOR_NAME = "connectors"
|
||||
|
||||
_MODULE_NAME_TO_ATTRIBUTE_MAP_FLUX = {"x_embedder": "in_channels"}
|
||||
|
||||
@@ -214,7 +212,7 @@ class StableDiffusionLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
@@ -641,7 +639,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
@@ -1081,7 +1079,7 @@ class SD3LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -1377,7 +1375,7 @@ class AuraFlowLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -1659,7 +1657,7 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
|
||||
if not (has_lora_keys or has_norm_keys):
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
transformer_lora_state_dict = {
|
||||
k: state_dict.get(k)
|
||||
@@ -2506,7 +2504,7 @@ class CogVideoXLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -2703,7 +2701,7 @@ class Mochi1LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -2906,7 +2904,7 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3013,233 +3011,6 @@ class LTXVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class LTX2LoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`LTX2VideoTransformer3DModel`]. Specific to [`LTX2Pipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer", "connectors"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
connectors_name = LTX2_CONNECTOR_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
final_state_dict = state_dict
|
||||
is_non_diffusers_format = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
has_connector = any(k.startswith("text_embedding_projection.") for k in state_dict)
|
||||
if is_non_diffusers_format:
|
||||
final_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(state_dict)
|
||||
if has_connector:
|
||||
connectors_state_dict = _convert_non_diffusers_ltx2_lora_to_diffusers(
|
||||
state_dict, "text_embedding_projection"
|
||||
)
|
||||
final_state_dict.update(connectors_state_dict)
|
||||
out = (final_state_dict, metadata) if return_lora_metadata else final_state_dict
|
||||
return out
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
|
||||
adapter_name: Optional[str] = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
transformer_peft_state_dict = {
|
||||
k: v for k, v in state_dict.items() if k.startswith(f"{self.transformer_name}.")
|
||||
}
|
||||
connectors_peft_state_dict = {k: v for k, v in state_dict.items() if k.startswith(f"{self.connectors_name}.")}
|
||||
self.load_lora_into_transformer(
|
||||
transformer_peft_state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
if connectors_peft_state_dict:
|
||||
self.load_lora_into_transformer(
|
||||
connectors_peft_state_dict,
|
||||
transformer=getattr(self, self.connectors_name)
|
||||
if not hasattr(self, "connectors")
|
||||
else self.connectors,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
prefix=self.connectors_name,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
prefix: str = "transformer",
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {prefix}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
prefix=prefix,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: Optional[dict] = None,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: List[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`SanaTransformer2DModel`]. Specific to [`SanaPipeline`].
|
||||
@@ -3333,7 +3104,7 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3536,7 +3307,7 @@ class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3740,7 +3511,7 @@ class Lumina2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -3940,7 +3711,7 @@ class KandinskyLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -4194,7 +3965,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
@@ -4471,7 +4242,7 @@ class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
|
||||
)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
|
||||
if load_into_transformer_2:
|
||||
@@ -4691,7 +4462,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -4894,7 +4665,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5100,7 +4871,7 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5306,7 +5077,7 @@ class ZImageLoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
@@ -5509,7 +5280,7 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
raise ValueError("Invalid LoRA checkpoint.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
|
||||
@@ -63,12 +63,9 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTX2TextConnectors": lambda model_cls, weights: weights,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -49,7 +49,6 @@ from .single_file_utils import (
|
||||
convert_stable_cascade_unet_single_file_to_diffusers,
|
||||
convert_wan_transformer_to_diffusers,
|
||||
convert_wan_vae_to_diffusers,
|
||||
convert_z_image_controlnet_checkpoint_to_diffusers,
|
||||
convert_z_image_transformer_checkpoint_to_diffusers,
|
||||
create_controlnet_diffusers_config_from_ldm,
|
||||
create_unet_diffusers_config_from_ldm,
|
||||
@@ -162,7 +161,7 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"QwenImageTransformer2DModel": {
|
||||
"checkpoint_mapping_fn": lambda checkpoint, **kwargs: checkpoint,
|
||||
"checkpoint_mapping_fn": lambda x: x,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"Flux2Transformer2DModel": {
|
||||
@@ -173,18 +172,11 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_z_image_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"ZImageControlNetModel": {
|
||||
"checkpoint_mapping_fn": convert_z_image_controlnet_checkpoint_to_diffusers,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _should_convert_state_dict_to_diffusers(model_state_dict, checkpoint_state_dict):
|
||||
model_state_dict_keys = set(model_state_dict.keys())
|
||||
checkpoint_state_dict_keys = set(checkpoint_state_dict.keys())
|
||||
is_subset = model_state_dict_keys.issubset(checkpoint_state_dict_keys)
|
||||
is_match = model_state_dict_keys == checkpoint_state_dict_keys
|
||||
return not (is_subset and is_match)
|
||||
return not set(model_state_dict.keys()).issubset(set(checkpoint_state_dict.keys()))
|
||||
|
||||
|
||||
def _get_single_file_loadable_mapping_class(cls):
|
||||
|
||||
@@ -120,12 +120,7 @@ CHECKPOINT_KEY_NAMES = {
|
||||
"hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
|
||||
"instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
|
||||
"lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
|
||||
"z-image-turbo": [
|
||||
"model.diffusion_model.layers.0.adaLN_modulation.0.weight",
|
||||
"layers.0.adaLN_modulation.0.weight",
|
||||
],
|
||||
"z-image-turbo-controlnet": "control_all_x_embedder.2-1.weight",
|
||||
"z-image-turbo-controlnet-2.x": "control_layers.14.adaLN_modulation.0.weight",
|
||||
"z-image-turbo": "cap_embedder.0.weight",
|
||||
"sana": [
|
||||
"blocks.0.cross_attn.q_linear.weight",
|
||||
"blocks.0.cross_attn.q_linear.bias",
|
||||
@@ -225,9 +220,6 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
||||
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
||||
"z-image-turbo": {"pretrained_model_name_or_path": "Tongyi-MAI/Z-Image-Turbo"},
|
||||
"z-image-turbo-controlnet": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union"},
|
||||
"z-image-turbo-controlnet-2.0": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.0"},
|
||||
"z-image-turbo-controlnet-2.1": {"pretrained_model_name_or_path": "hlky/Z-Image-Turbo-Fun-Controlnet-Union-2.1"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -731,7 +723,10 @@ def infer_diffusers_model_type(checkpoint):
|
||||
):
|
||||
model_type = "instruct-pix2pix"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["z-image-turbo"]):
|
||||
elif (
|
||||
CHECKPOINT_KEY_NAMES["z-image-turbo"] in checkpoint
|
||||
and checkpoint[CHECKPOINT_KEY_NAMES["z-image-turbo"]].shape[0] == 2560
|
||||
):
|
||||
model_type = "z-image-turbo"
|
||||
|
||||
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
|
||||
@@ -784,18 +779,6 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet-2.x"] in checkpoint:
|
||||
before_proj_weight = checkpoint.get("control_noise_refiner.0.before_proj.weight", None)
|
||||
if before_proj_weight is None:
|
||||
model_type = "z-image-turbo-controlnet-2.0"
|
||||
elif before_proj_weight is not None and torch.all(before_proj_weight == 0.0):
|
||||
model_type = "z-image-turbo-controlnet-2.0"
|
||||
else:
|
||||
model_type = "z-image-turbo-controlnet-2.1"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["z-image-turbo-controlnet"] in checkpoint:
|
||||
model_type = "z-image-turbo-controlnet"
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -3859,7 +3842,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
".attention.k_norm.weight": ".attention.norm_k.weight",
|
||||
".attention.q_norm.weight": ".attention.norm_q.weight",
|
||||
".attention.out.weight": ".attention.to_out.0.weight",
|
||||
"model.diffusion_model.": "",
|
||||
}
|
||||
|
||||
def convert_z_image_fused_attention(key: str, state_dict: dict[str, object]) -> None:
|
||||
@@ -3894,9 +3876,6 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
|
||||
update_state_dict(converted_state_dict, key, new_key)
|
||||
|
||||
if "norm_final.weight" in converted_state_dict.keys():
|
||||
_ = converted_state_dict.pop("norm_final.weight")
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(converted_state_dict.keys()):
|
||||
@@ -3906,17 +3885,3 @@ def convert_z_image_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_z_image_controlnet_checkpoint_to_diffusers(checkpoint, config, **kwargs):
|
||||
if config["add_control_noise_refiner"] is None:
|
||||
return checkpoint
|
||||
elif config["add_control_noise_refiner"] == "control_noise_refiner":
|
||||
return checkpoint
|
||||
elif config["add_control_noise_refiner"] == "control_layers":
|
||||
converted_state_dict = {
|
||||
key: checkpoint.pop(key) for key in list(checkpoint.keys()) if not key.startswith("control_noise_refiner.")
|
||||
}
|
||||
return converted_state_dict
|
||||
else:
|
||||
raise ValueError("Unknown Z-Image Turbo ControlNet type.")
|
||||
|
||||
@@ -41,8 +41,6 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
@@ -68,7 +66,6 @@ if is_torch_available():
|
||||
_import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]
|
||||
_import_structure["controlnets.controlnet_union"] = ["ControlNetUnionModel"]
|
||||
_import_structure["controlnets.controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
|
||||
_import_structure["controlnets.controlnet_z_image"] = ["ZImageControlNetModel"]
|
||||
_import_structure["controlnets.multicontrolnet"] = ["MultiControlNetModel"]
|
||||
_import_structure["controlnets.multicontrolnet_union"] = ["MultiControlNetUnionModel"]
|
||||
_import_structure["embeddings"] = ["ImageProjection"]
|
||||
@@ -98,7 +95,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_easyanimate"] = ["EasyAnimateTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
|
||||
@@ -107,7 +103,6 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
@@ -157,8 +152,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -188,7 +181,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
SD3MultiControlNetModel,
|
||||
SparseControlNetModel,
|
||||
UNetControlNetXSModel,
|
||||
ZImageControlNetModel,
|
||||
)
|
||||
from .embeddings import ImageProjection
|
||||
from .modeling_utils import ModelMixin
|
||||
@@ -209,7 +201,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
EasyAnimateTransformer3DModel,
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
@@ -219,7 +210,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
|
||||
@@ -90,6 +90,10 @@ class ContextParallelConfig:
|
||||
)
|
||||
if self.ring_degree < 1 or self.ulysses_degree < 1:
|
||||
raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
|
||||
if self.ring_degree > 1 and self.ulysses_degree > 1:
|
||||
raise ValueError(
|
||||
"Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
|
||||
)
|
||||
if self.rotate_method != "allgather":
|
||||
raise NotImplementedError(
|
||||
f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
|
||||
|
||||
@@ -1106,51 +1106,6 @@ def _sage_attention_backward_op(
|
||||
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
|
||||
|
||||
|
||||
def _npu_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# Not implemented yet.
|
||||
def _npu_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
|
||||
|
||||
|
||||
# ===== Context parallel =====
|
||||
|
||||
|
||||
@@ -1177,103 +1132,6 @@ def _all_to_all_single(x: torch.Tensor, group) -> torch.Tensor:
|
||||
return x
|
||||
|
||||
|
||||
def _all_to_all_dim_exchange(x: torch.Tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None) -> torch.Tensor:
|
||||
"""
|
||||
Perform dimension sharding / reassembly across processes using _all_to_all_single.
|
||||
|
||||
This utility reshapes and redistributes tensor `x` across the given process group, across sequence dimension or
|
||||
head dimension flexibly by accepting scatter_idx and gather_idx.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor):
|
||||
Input tensor. Expected shapes:
|
||||
- When scatter_idx=2, gather_idx=1: (batch_size, seq_len_local, num_heads, head_dim)
|
||||
- When scatter_idx=1, gather_idx=2: (batch_size, seq_len, num_heads_local, head_dim)
|
||||
scatter_idx (int) :
|
||||
Dimension along which the tensor is partitioned before all-to-all.
|
||||
gather_idx (int):
|
||||
Dimension along which the output is reassembled after all-to-all.
|
||||
group :
|
||||
Distributed process group for the Ulysses group.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Tensor with globally exchanged dimensions.
|
||||
- For (scatter_idx=2 → gather_idx=1): (batch_size, seq_len, num_heads_local, head_dim)
|
||||
- For (scatter_idx=1 → gather_idx=2): (batch_size, seq_len_local, num_heads, head_dim)
|
||||
"""
|
||||
group_world_size = torch.distributed.get_world_size(group)
|
||||
|
||||
if scatter_idx == 2 and gather_idx == 1:
|
||||
# Used before Ulysses sequence parallel (SP) attention. Scatters the gathers sequence
|
||||
# dimension and scatters head dimension
|
||||
batch_size, seq_len_local, num_heads, head_dim = x.shape
|
||||
seq_len = seq_len_local * group_world_size
|
||||
num_heads_local = num_heads // group_world_size
|
||||
|
||||
# B, S_LOCAL, H, D -> group_world_size, S_LOCAL, B, H_LOCAL, D
|
||||
x_temp = (
|
||||
x.reshape(batch_size, seq_len_local, group_world_size, num_heads_local, head_dim)
|
||||
.transpose(0, 2)
|
||||
.contiguous()
|
||||
)
|
||||
|
||||
if group_world_size > 1:
|
||||
out = _all_to_all_single(x_temp, group=group)
|
||||
else:
|
||||
out = x_temp
|
||||
# group_world_size, S_LOCAL, B, H_LOCAL, D -> B, S, H_LOCAL, D
|
||||
out = out.reshape(seq_len, batch_size, num_heads_local, head_dim).permute(1, 0, 2, 3).contiguous()
|
||||
out = out.reshape(batch_size, seq_len, num_heads_local, head_dim)
|
||||
return out
|
||||
elif scatter_idx == 1 and gather_idx == 2:
|
||||
# Used after ulysses sequence parallel in unified SP. gathers the head dimension
|
||||
# scatters back the sequence dimension.
|
||||
batch_size, seq_len, num_heads_local, head_dim = x.shape
|
||||
num_heads = num_heads_local * group_world_size
|
||||
seq_len_local = seq_len // group_world_size
|
||||
|
||||
# B, S, H_LOCAL, D -> group_world_size, H_LOCAL, S_LOCAL, B, D
|
||||
x_temp = (
|
||||
x.reshape(batch_size, group_world_size, seq_len_local, num_heads_local, head_dim)
|
||||
.permute(1, 3, 2, 0, 4)
|
||||
.reshape(group_world_size, num_heads_local, seq_len_local, batch_size, head_dim)
|
||||
)
|
||||
|
||||
if group_world_size > 1:
|
||||
output = _all_to_all_single(x_temp, group)
|
||||
else:
|
||||
output = x_temp
|
||||
output = output.reshape(num_heads, seq_len_local, batch_size, head_dim).transpose(0, 2).contiguous()
|
||||
output = output.reshape(batch_size, seq_len_local, num_heads, head_dim)
|
||||
return output
|
||||
else:
|
||||
raise ValueError("Invalid scatter/gather indices for _all_to_all_dim_exchange.")
|
||||
|
||||
|
||||
class SeqAllToAllDim(torch.autograd.Function):
|
||||
"""
|
||||
all_to_all operation for unified sequence parallelism. uses _all_to_all_dim_exchange, see _all_to_all_dim_exchange
|
||||
for more info.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, group, input, scatter_id=2, gather_id=1):
|
||||
ctx.group = group
|
||||
ctx.scatter_id = scatter_id
|
||||
ctx.gather_id = gather_id
|
||||
return _all_to_all_dim_exchange(input, scatter_id, gather_id, group)
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_outputs):
|
||||
grad_input = SeqAllToAllDim.apply(
|
||||
ctx.group,
|
||||
grad_outputs,
|
||||
ctx.gather_id, # reversed
|
||||
ctx.scatter_id, # reversed
|
||||
)
|
||||
return (None, grad_input, None, None)
|
||||
|
||||
|
||||
class TemplatedRingAttention(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(
|
||||
@@ -1334,10 +1192,7 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
out = out.to(torch.float32)
|
||||
lse = lse.to(torch.float32)
|
||||
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
lse = lse.unsqueeze(-1)
|
||||
lse = lse.unsqueeze(-1)
|
||||
if prev_out is not None:
|
||||
out = prev_out - torch.nn.functional.sigmoid(lse - prev_lse) * (prev_out - out)
|
||||
lse = prev_lse - torch.nn.functional.logsigmoid(prev_lse - lse)
|
||||
@@ -1398,7 +1253,7 @@ class TemplatedRingAttention(torch.autograd.Function):
|
||||
|
||||
grad_query, grad_key, grad_value = (x.to(grad_out.dtype) for x in (grad_query, grad_key, grad_value))
|
||||
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
@@ -1493,69 +1348,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
|
||||
x.flatten(0, 1).permute(1, 2, 0, 3).contiguous() for x in (grad_query, grad_key, grad_value)
|
||||
)
|
||||
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def _templated_unified_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor],
|
||||
dropout_p: float,
|
||||
is_causal: bool,
|
||||
scale: Optional[float],
|
||||
enable_gqa: bool,
|
||||
return_lse: bool,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
scatter_idx: int = 2,
|
||||
gather_idx: int = 1,
|
||||
):
|
||||
"""
|
||||
Unified Sequence Parallelism attention combining Ulysses and ring attention. See: https://arxiv.org/abs/2405.07719
|
||||
"""
|
||||
ulysses_mesh = _parallel_config.context_parallel_config._ulysses_mesh
|
||||
ulysses_group = ulysses_mesh.get_group()
|
||||
|
||||
query = SeqAllToAllDim.apply(ulysses_group, query, scatter_idx, gather_idx)
|
||||
key = SeqAllToAllDim.apply(ulysses_group, key, scatter_idx, gather_idx)
|
||||
value = SeqAllToAllDim.apply(ulysses_group, value, scatter_idx, gather_idx)
|
||||
out = TemplatedRingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
context_layer, lse, *_ = out
|
||||
else:
|
||||
context_layer = out
|
||||
# context_layer is of shape (B, S, H_LOCAL, D)
|
||||
output = SeqAllToAllDim.apply(
|
||||
ulysses_group,
|
||||
context_layer,
|
||||
gather_idx,
|
||||
scatter_idx,
|
||||
)
|
||||
if return_lse:
|
||||
# lse is of shape (B, S, H_LOCAL, 1)
|
||||
# Refer to:
|
||||
# https://github.com/huggingface/diffusers/pull/12693#issuecomment-3627519544
|
||||
if is_torch_version("<", "2.9.0"):
|
||||
lse = lse.unsqueeze(-1) # (B, S, H_LOCAL, 1)
|
||||
lse = SeqAllToAllDim.apply(ulysses_group, lse, gather_idx, scatter_idx)
|
||||
lse = lse.squeeze(-1)
|
||||
return (output, lse)
|
||||
return output
|
||||
return grad_query, grad_key, grad_value, None, None, None, None, None, None, None, None
|
||||
|
||||
|
||||
def _templated_context_parallel_attention(
|
||||
@@ -1581,25 +1374,7 @@ def _templated_context_parallel_attention(
|
||||
raise ValueError("GQA is not yet supported for templated attention.")
|
||||
|
||||
# TODO: add support for unified attention with ring/ulysses degree both being > 1
|
||||
if (
|
||||
_parallel_config.context_parallel_config.ring_degree > 1
|
||||
and _parallel_config.context_parallel_config.ulysses_degree > 1
|
||||
):
|
||||
return _templated_unified_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
enable_gqa,
|
||||
return_lse,
|
||||
forward_op,
|
||||
backward_op,
|
||||
_parallel_config,
|
||||
)
|
||||
elif _parallel_config.context_parallel_config.ring_degree > 1:
|
||||
if _parallel_config.context_parallel_config.ring_degree > 1:
|
||||
return TemplatedRingAttention.apply(
|
||||
query,
|
||||
key,
|
||||
@@ -1645,7 +1420,6 @@ def _flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -1653,9 +1427,6 @@ def _flash_attention(
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
if _parallel_config is None:
|
||||
out = flash_attn_func(
|
||||
q=query,
|
||||
@@ -1698,7 +1469,6 @@ def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -1706,9 +1476,6 @@ def _flash_attention_hub(
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
@@ -1845,15 +1612,11 @@ def _flash_attention_3(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
out, lse = _wrapped_flash_attn_3(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -1873,7 +1636,6 @@ def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
@@ -1884,8 +1646,6 @@ def _flash_attention_3_hub(
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
@@ -2025,16 +1785,12 @@ def _aiter_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for aiter attention")
|
||||
|
||||
if not return_lse and torch.is_grad_enabled():
|
||||
# aiter requires return_lse=True by assertion when gradients are enabled.
|
||||
out, lse, *_ = aiter_flash_attn_func(
|
||||
@@ -2125,43 +1881,6 @@ def _native_flex_attention(
|
||||
return out
|
||||
|
||||
|
||||
def _prepare_additive_attn_mask(
|
||||
attn_mask: torch.Tensor, target_dtype: torch.dtype, reshape_4d: bool = True
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Convert a 2D attention mask to an additive mask, optionally reshaping to 4D for SDPA.
|
||||
|
||||
This helper is used by both native SDPA and xformers backends to handle both boolean and additive masks.
|
||||
|
||||
Args:
|
||||
attn_mask: 2D tensor [batch_size, seq_len_k]
|
||||
- Boolean: True means attend, False means mask out
|
||||
- Additive: 0.0 means attend, -inf means mask out
|
||||
target_dtype: The dtype to convert the mask to (usually query.dtype)
|
||||
reshape_4d: If True, reshape from [batch_size, seq_len_k] to [batch_size, 1, 1, seq_len_k] for broadcasting
|
||||
|
||||
Returns:
|
||||
Additive mask tensor where 0.0 means attend and -inf means mask out. Shape is [batch_size, seq_len_k] if
|
||||
reshape_4d=False, or [batch_size, 1, 1, seq_len_k] if reshape_4d=True.
|
||||
"""
|
||||
# Check if the mask is boolean or already additive
|
||||
if attn_mask.dtype == torch.bool:
|
||||
# Convert boolean to additive: True -> 0.0, False -> -inf
|
||||
attn_mask = torch.where(attn_mask, 0.0, float("-inf"))
|
||||
# Convert to target dtype
|
||||
attn_mask = attn_mask.to(dtype=target_dtype)
|
||||
else:
|
||||
# Already additive mask - just ensure correct dtype
|
||||
attn_mask = attn_mask.to(dtype=target_dtype)
|
||||
|
||||
# Optionally reshape to 4D for broadcasting in attention mechanisms
|
||||
if reshape_4d:
|
||||
batch_size, seq_len_k = attn_mask.shape
|
||||
attn_mask = attn_mask.view(batch_size, 1, 1, seq_len_k)
|
||||
|
||||
return attn_mask
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.NATIVE,
|
||||
constraints=[_check_device, _check_shape],
|
||||
@@ -2181,19 +1900,6 @@ def _native_attention(
|
||||
) -> torch.Tensor:
|
||||
if return_lse:
|
||||
raise ValueError("Native attention backend does not support setting `return_lse=True`.")
|
||||
|
||||
# Reshape 2D mask to 4D for SDPA
|
||||
# SDPA accepts both boolean masks (torch.bool) and additive masks (float)
|
||||
if (
|
||||
attn_mask is not None
|
||||
and attn_mask.ndim == 2
|
||||
and attn_mask.shape[0] == query.shape[0]
|
||||
and attn_mask.shape[1] == key.shape[1]
|
||||
):
|
||||
# Just reshape [batch_size, seq_len_k] -> [batch_size, 1, 1, seq_len_k]
|
||||
# SDPA handles both boolean and additive masks correctly
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1)
|
||||
|
||||
if _parallel_config is None:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
out = torch.nn.functional.scaled_dot_product_attention(
|
||||
@@ -2322,7 +2028,6 @@ def _native_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -2330,9 +2035,6 @@ def _native_flash_attention(
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for aiter attention")
|
||||
|
||||
lse = None
|
||||
if _parallel_config is None and not return_lse:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
@@ -2406,52 +2108,34 @@ def _native_math_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._NATIVE_NPU,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_npu_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for NPU attention")
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
if _parallel_config is None:
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
None,
|
||||
scale,
|
||||
None,
|
||||
return_lse,
|
||||
forward_op=_npu_attention_forward_op,
|
||||
backward_op=_npu_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(1), # num_heads
|
||||
input_layout="BNSD",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
out = out.transpose(1, 2).contiguous()
|
||||
return out
|
||||
|
||||
|
||||
@@ -2464,13 +2148,10 @@ def _native_xla_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for XLA attention")
|
||||
if return_lse:
|
||||
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
@@ -2494,14 +2175,11 @@ def _sage_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
lse = None
|
||||
if _parallel_config is None:
|
||||
out = sageattn(
|
||||
@@ -2545,14 +2223,11 @@ def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
@@ -2634,14 +2309,11 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2661,14 +2333,11 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda_sm90(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2688,14 +2357,11 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2715,14 +2381,11 @@ def _sage_qk_int8_pv_fp16_triton_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_triton(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2760,34 +2423,10 @@ def _xformers_attention(
|
||||
attn_mask = xops.LowerTriangularMask()
|
||||
elif attn_mask is not None:
|
||||
if attn_mask.ndim == 2:
|
||||
# Convert 2D mask to 4D for xformers
|
||||
# Mask can be boolean (True=attend, False=mask) or additive (0.0=attend, -inf=mask)
|
||||
# xformers requires 4D additive masks [batch, heads, seq_q, seq_k]
|
||||
# Need memory alignment - create larger tensor and slice for alignment
|
||||
original_seq_len = attn_mask.size(1)
|
||||
aligned_seq_len = ((original_seq_len + 7) // 8) * 8 # Round up to multiple of 8
|
||||
|
||||
# Create aligned 4D tensor and slice to ensure proper memory layout
|
||||
aligned_mask = torch.zeros(
|
||||
(batch_size, num_heads_q, seq_len_q, aligned_seq_len),
|
||||
dtype=query.dtype,
|
||||
device=query.device,
|
||||
)
|
||||
# Convert to 4D additive mask (handles both boolean and additive inputs)
|
||||
mask_additive = _prepare_additive_attn_mask(
|
||||
attn_mask, target_dtype=query.dtype
|
||||
) # [batch, 1, 1, seq_len_k]
|
||||
# Broadcast to [batch, heads, seq_q, seq_len_k]
|
||||
aligned_mask[:, :, :, :original_seq_len] = mask_additive
|
||||
# Mask out the padding (already -inf from zeros -> where with default)
|
||||
aligned_mask[:, :, :, original_seq_len:] = float("-inf")
|
||||
|
||||
# Slice to actual size with proper alignment
|
||||
attn_mask = aligned_mask[:, :, :, :seq_len_kv]
|
||||
attn_mask = attn_mask.view(attn_mask.size(0), 1, attn_mask.size(1), 1)
|
||||
elif attn_mask.ndim != 4:
|
||||
raise ValueError("Only 2D and 4D attention masks are supported for xformers attention.")
|
||||
elif attn_mask.ndim == 4:
|
||||
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
|
||||
attn_mask = attn_mask.expand(batch_size, num_heads_q, seq_len_q, seq_len_kv).type_as(query)
|
||||
|
||||
if enable_gqa:
|
||||
if num_heads_q % num_heads_kv != 0:
|
||||
|
||||
@@ -10,8 +10,6 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
|
||||
@@ -102,14 +102,14 @@ def get_block(
|
||||
attention_head_dim: int,
|
||||
norm_type: str,
|
||||
act_fn: str,
|
||||
qkv_multiscales: Tuple[int, ...] = (),
|
||||
qkv_mutliscales: Tuple[int, ...] = (),
|
||||
):
|
||||
if block_type == "ResBlock":
|
||||
block = ResBlock(in_channels, out_channels, norm_type, act_fn)
|
||||
|
||||
elif block_type == "EfficientViTBlock":
|
||||
block = EfficientViTBlock(
|
||||
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_multiscales
|
||||
in_channels, attention_head_dim=attention_head_dim, norm_type=norm_type, qkv_multiscales=qkv_mutliscales
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -247,7 +247,7 @@ class Encoder(nn.Module):
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type="rms_norm",
|
||||
act_fn="silu",
|
||||
qkv_multiscales=qkv_multiscales[i],
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
down_block_list.append(block)
|
||||
|
||||
@@ -339,7 +339,7 @@ class Decoder(nn.Module):
|
||||
attention_head_dim=attention_head_dim,
|
||||
norm_type=norm_type[i],
|
||||
act_fn=act_fn[i],
|
||||
qkv_multiscales=qkv_multiscales[i],
|
||||
qkv_mutliscales=qkv_multiscales[i],
|
||||
)
|
||||
up_block_list.append(block)
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -410,7 +410,7 @@ class HunyuanImageDecoder2D(nn.Module):
|
||||
return h
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
class AutoencoderKLHunyuanImage(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
r"""
|
||||
A VAE model for 2D images with spatial tiling support.
|
||||
|
||||
@@ -486,6 +486,27 @@ class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromO
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor):
|
||||
|
||||
batch_size, num_channels, height, width = x.shape
|
||||
|
||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -584,7 +584,7 @@ class HunyuanImageRefinerDecoder3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
class AutoencoderKLHunyuanImageRefiner(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanImage-2.1 Refiner.
|
||||
@@ -685,6 +685,27 @@ class AutoencoderKLHunyuanImageRefiner(ModelMixin, AutoencoderMixin, ConfigMixin
|
||||
self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..activations import get_activation
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
from .vae import DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -625,7 +625,7 @@ class HunyuanVideo15Decoder3D(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
class AutoencoderKLHunyuanVideo15(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used for
|
||||
HunyuanVideo-1.5.
|
||||
@@ -723,6 +723,27 @@ class AutoencoderKLHunyuanVideo15(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
self.tile_latent_min_width = tile_latent_min_width or self.tile_latent_min_width
|
||||
self.tile_overlap_factor = tile_overlap_factor or self.tile_overlap_factor
|
||||
|
||||
def disable_tiling(self) -> None:
|
||||
r"""
|
||||
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_tiling = False
|
||||
|
||||
def enable_slicing(self) -> None:
|
||||
r"""
|
||||
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
|
||||
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
|
||||
"""
|
||||
self.use_slicing = True
|
||||
|
||||
def disable_slicing(self) -> None:
|
||||
r"""
|
||||
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
|
||||
decoding in one step.
|
||||
"""
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, _, height, width = x.shape
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,804 +0,0 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
class LTX2AudioCausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution that pads asymmetrically along the causal axis.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
if self.causality_axis == "none":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis in {"width", "width-compatibility"}:
|
||||
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis == "height":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
self.padding = padding
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class LTX2AudioPixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
class LTX2AudioAttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
norm_type: str = "group",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(x)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
batch, channels, height, width = q.shape
|
||||
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
|
||||
k = k.reshape(batch, channels, height * width).contiguous()
|
||||
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
|
||||
attn = torch.nn.functional.softmax(attn, dim=2)
|
||||
|
||||
v = v.reshape(batch, channels, height * width)
|
||||
attn = attn.permute(0, 2, 1).contiguous()
|
||||
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
return x + h_
|
||||
|
||||
|
||||
class LTX2AudioResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
norm_type: str = "group",
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
if causality_axis is not None:
|
||||
self.conv1 = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
if norm_type == "group":
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
if causality_axis is not None:
|
||||
self.conv2 = LTX2AudioCausalConv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
if causality_axis is not None:
|
||||
self.conv_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
if causality_axis is not None:
|
||||
self.nin_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
h = self.norm1(x)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class LTX2AudioDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.with_conv:
|
||||
# Padding tuple is in the order: (left, right, top, bottom).
|
||||
if self.causality_axis == "none":
|
||||
pad = (0, 1, 0, 1)
|
||||
elif self.causality_axis == "width":
|
||||
pad = (2, 0, 0, 1)
|
||||
elif self.causality_axis == "height":
|
||||
pad = (0, 1, 2, 0)
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pad = (1, 0, 0, 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
|
||||
f" and `width-compatibility`."
|
||||
)
|
||||
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# with_conv=False implies that causality_axis is "none"
|
||||
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioUpsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
if causality_axis is not None:
|
||||
self.conv = LTX2AudioCausalConv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
if self.causality_axis is None or self.causality_axis == "none":
|
||||
pass
|
||||
elif self.causality_axis == "height":
|
||||
x = x[:, :, 1:, :]
|
||||
elif self.causality_axis == "width":
|
||||
x = x[:, :, :, 1:]
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioAudioPatchifier:
|
||||
"""
|
||||
Patchifier for spectrogram/audio latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
is_causal: bool = True,
|
||||
):
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
|
||||
batch, channels, time, freq = audio_latents.shape
|
||||
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
|
||||
batch, time, _ = audio_latents.shape
|
||||
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
return self._patch_size
|
||||
|
||||
|
||||
class LTX2AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels
|
||||
base_resolution = resolution
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != self.num_resolutions - 1:
|
||||
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
final_block_channels = block_in
|
||||
z_channels = 2 * latent_channels if double_z else latent_channels
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_states = block(hidden_states, temb=None)
|
||||
if stage.attn:
|
||||
hidden_states = stage.attn[block_idx](hidden_states)
|
||||
|
||||
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
|
||||
hidden_states = stage.downsample(hidden_states)
|
||||
|
||||
hidden_states = self.mid.block_1(hidden_states, temb=None)
|
||||
hidden_states = self.mid.attn_1(hidden_states)
|
||||
hidden_states = self.mid.block_2(hidden_states, temb=None)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
|
||||
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
|
||||
convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
self.patchifier = LTX2AudioAudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels * self.channel_multipliers[-1]
|
||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != 0:
|
||||
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res *= 2
|
||||
|
||||
self.up.insert(0, stage)
|
||||
|
||||
final_block_channels = block_in
|
||||
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_, _, frames, mel_bins = sample.shape
|
||||
|
||||
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
if self.causality_axis is not None:
|
||||
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||
|
||||
target_channels = self.out_ch
|
||||
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
|
||||
|
||||
hidden_features = self.conv_in(sample)
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = self.up[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_features = block(hidden_features, temb=None)
|
||||
if stage.attn:
|
||||
hidden_features = stage.attn[block_idx](hidden_features)
|
||||
|
||||
if level != 0 and hasattr(stage, "upsample"):
|
||||
hidden_features = stage.upsample(hidden_features)
|
||||
|
||||
if self.give_pre_end:
|
||||
return hidden_features
|
||||
|
||||
hidden = self.norm_out(hidden_features)
|
||||
hidden = self.non_linearity(hidden)
|
||||
decoded_output = self.conv_out(hidden)
|
||||
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
|
||||
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
target_time = target_frames
|
||||
target_freq = target_mel_bins
|
||||
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0),
|
||||
0,
|
||||
max(time_padding_needed, 0),
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX2 audio VAE for encoding and decoding audio latent representations.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 2,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
norm_type: str = "pixel",
|
||||
causality_axis: Optional[str] = "height",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
|
||||
if causality_axis not in supported_causality_axes:
|
||||
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
|
||||
|
||||
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
|
||||
|
||||
self.encoder = LTX2AudioEncoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
double_z=double_z,
|
||||
)
|
||||
|
||||
self.decoder = LTX2AudioDecoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
# Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over
|
||||
# the entire dataset and stored in model's checkpoint under AudioVAE state_dict
|
||||
latents_std = torch.zeros((base_channels,))
|
||||
latents_mean = torch.ones((base_channels,))
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
# TODO: calculate programmatically instead of hardcoding
|
||||
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
|
||||
# TODO: confirm whether the mel compression ratio below is correct
|
||||
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.encoder(x)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(z)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
posterior = self.encode(sample).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
if not return_dict:
|
||||
return (dec.sample,)
|
||||
return dec
|
||||
@@ -41,11 +41,9 @@ class CacheMixin:
|
||||
Enable caching techniques on the model.
|
||||
|
||||
Args:
|
||||
config (`Union[PyramidAttentionBroadcastConfig, FasterCacheConfig, FirstBlockCacheConfig]`):
|
||||
config (`Union[PyramidAttentionBroadcastConfig]`):
|
||||
The configuration for applying the caching technique. Currently supported caching techniques are:
|
||||
- [`~hooks.PyramidAttentionBroadcastConfig`]
|
||||
- [`~hooks.FasterCacheConfig`]
|
||||
- [`~hooks.FirstBlockCacheConfig`]
|
||||
|
||||
Example:
|
||||
|
||||
|
||||
115
src/diffusers/models/controlnet.py
Normal file
115
src/diffusers/models/controlnet.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate
|
||||
from .controlnets.controlnet import ( # noqa
|
||||
ControlNetConditioningEmbedding,
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
zero_module,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetOutput(ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ControlNetModel(ControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
mid_block_type=mid_block_type,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
class_embed_type=class_embed_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
num_class_embeds=num_class_embeds,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
70
src/diffusers/models/controlnet_flux.py
Normal file
70
src/diffusers/models/controlnet_flux.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxControlNetOutput(FluxControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class FluxControlNetModel(FluxControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
guidance_embeds=guidance_embeds,
|
||||
axes_dims_rope=axes_dims_rope,
|
||||
num_mode=num_mode,
|
||||
conditioning_embedding_channels=conditioning_embedding_channels,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(FluxMultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
68
src/diffusers/models/controlnet_sd3.py
Normal file
68
src/diffusers/models/controlnet_sd3.py
Normal file
@@ -0,0 +1,68 @@
|
||||
# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SD3ControlNetOutput(SD3ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SD3ControlNetModel(SD3ControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 18,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 18,
|
||||
joint_attention_dim: int = 4096,
|
||||
caption_projection_dim: int = 1152,
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
sample_size=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
caption_projection_dim=caption_projection_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
out_channels=out_channels,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
extra_conditioning_channels=extra_conditioning_channels,
|
||||
)
|
||||
|
||||
|
||||
class SD3MultiControlNetModel(SD3MultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
116
src/diffusers/models/controlnet_sparsectrl.py
Normal file
116
src/diffusers/models/controlnet_sparsectrl.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sparsectrl import ( # noqa
|
||||
SparseControlNetConditioningEmbedding,
|
||||
SparseControlNetModel,
|
||||
SparseControlNetOutput,
|
||||
zero_module,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SparseControlNetOutput(SparseControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
|
||||
deprecate(
|
||||
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetModel(SparseControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
motion_num_attention_heads: int = 8,
|
||||
concat_conditioning_mask: bool = True,
|
||||
use_simplified_condition_embedding: bool = True,
|
||||
):
|
||||
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
motion_max_seq_length=motion_max_seq_length,
|
||||
motion_num_attention_heads=motion_num_attention_heads,
|
||||
concat_conditioning_mask=concat_conditioning_mask,
|
||||
use_simplified_condition_embedding=use_simplified_condition_embedding,
|
||||
)
|
||||
@@ -19,7 +19,6 @@ if is_torch_available():
|
||||
)
|
||||
from .controlnet_union import ControlNetUnionModel
|
||||
from .controlnet_xs import ControlNetXSAdapter, ControlNetXSOutput, UNetControlNetXSModel
|
||||
from .controlnet_z_image import ZImageControlNetModel
|
||||
from .multicontrolnet import MultiControlNetModel
|
||||
from .multicontrolnet_union import MultiControlNetUnionModel
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ..attention import AttentionMixin
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..controlnets.controlnet import zero_module
|
||||
@@ -31,7 +31,6 @@ from ..transformers.transformer_qwenimage import (
|
||||
QwenImageTransformerBlock,
|
||||
QwenTimestepProjEmbeddings,
|
||||
RMSNorm,
|
||||
compute_text_seq_len_from_mask,
|
||||
)
|
||||
|
||||
|
||||
@@ -137,7 +136,7 @@ class QwenImageControlNetModel(
|
||||
return_dict: bool = True,
|
||||
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
|
||||
"""
|
||||
The [`QwenImageControlNetModel`] forward method.
|
||||
The [`FluxTransformer2DModel`] forward method.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
|
||||
@@ -148,39 +147,24 @@ class QwenImageControlNetModel(
|
||||
The scale factor for ControlNet outputs.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
|
||||
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
|
||||
encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`, *optional*):
|
||||
Mask for the encoder hidden states. Expected to have 1.0 for valid tokens and 0.0 for padding tokens.
|
||||
Used in the attention processor to prevent attending to padding tokens. The mask can have any pattern
|
||||
(not just contiguous valid tokens followed by padding) since it's applied element-wise in attention.
|
||||
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
|
||||
from the embeddings of input conditions.
|
||||
timestep ( `torch.LongTensor`):
|
||||
Used to indicate denoising step.
|
||||
img_shapes (`List[Tuple[int, int, int]]`, *optional*):
|
||||
Image shapes for RoPE computation.
|
||||
txt_seq_lens (`List[int]`, *optional*):
|
||||
**Deprecated**. Not needed anymore, we use `encoder_hidden_states` instead to infer text sequence
|
||||
length.
|
||||
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
|
||||
A list of tensors that if specified are added to the residuals of transformer blocks.
|
||||
joint_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
|
||||
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
If `return_dict` is True, a [`~models.controlnet.ControlNetOutput`] is returned, otherwise a `tuple` where
|
||||
the first element is the controlnet block samples.
|
||||
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
|
||||
`tuple` where the first element is the sample tensor.
|
||||
"""
|
||||
# Handle deprecated txt_seq_lens parameter
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` to `QwenImageControlNetModel.forward()` is deprecated and will be removed in "
|
||||
"version 0.39.0. The text sequence length is now automatically inferred from `encoder_hidden_states` "
|
||||
"and `encoder_hidden_states_mask`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
|
||||
if joint_attention_kwargs is not None:
|
||||
joint_attention_kwargs = joint_attention_kwargs.copy()
|
||||
lora_scale = joint_attention_kwargs.pop("scale", 1.0)
|
||||
@@ -202,47 +186,32 @@ class QwenImageControlNetModel(
|
||||
|
||||
temb = self.time_text_embed(timestep, hidden_states)
|
||||
|
||||
# Use the encoder_hidden_states sequence length for RoPE computation and normalize mask
|
||||
text_seq_len, _, encoder_hidden_states_mask = compute_text_seq_len_from_mask(
|
||||
encoder_hidden_states, encoder_hidden_states_mask
|
||||
)
|
||||
|
||||
image_rotary_emb = self.pos_embed(img_shapes, max_txt_seq_len=text_seq_len, device=hidden_states.device)
|
||||
image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
|
||||
|
||||
timestep = timestep.to(hidden_states.dtype)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
encoder_hidden_states = self.txt_in(encoder_hidden_states)
|
||||
|
||||
# Construct joint attention mask once to avoid reconstructing in every block
|
||||
block_attention_kwargs = joint_attention_kwargs.copy() if joint_attention_kwargs is not None else {}
|
||||
if encoder_hidden_states_mask is not None:
|
||||
# Build joint mask: [text_mask, all_ones_for_image]
|
||||
batch_size, image_seq_len = hidden_states.shape[:2]
|
||||
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
|
||||
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
|
||||
block_attention_kwargs["attention_mask"] = joint_attention_mask
|
||||
|
||||
block_samples = ()
|
||||
for block in self.transformer_blocks:
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
None, # Don't pass encoder_hidden_states_mask (using attention_mask instead)
|
||||
encoder_hidden_states_mask,
|
||||
temb,
|
||||
image_rotary_emb,
|
||||
block_attention_kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_hidden_states_mask=None, # Don't pass (using attention_mask instead)
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
temb=temb,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
joint_attention_kwargs=block_attention_kwargs,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
block_samples = block_samples + (hidden_states,)
|
||||
|
||||
@@ -298,15 +267,6 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[QwenImageControlNetOutput, Tuple]:
|
||||
if txt_seq_lens is not None:
|
||||
deprecate(
|
||||
"txt_seq_lens",
|
||||
"0.39.0",
|
||||
"Passing `txt_seq_lens` to `QwenImageMultiControlNetModel.forward()` is deprecated and will be "
|
||||
"removed in version 0.39.0. The text sequence length is now automatically inferred from "
|
||||
"`encoder_hidden_states` and `encoder_hidden_states_mask`.",
|
||||
standard_warn=False,
|
||||
)
|
||||
# ControlNet-Union with multiple conditions
|
||||
# only load one ControlNet for saving memories
|
||||
if len(self.nets) == 1:
|
||||
@@ -321,6 +281,7 @@ class QwenImageMultiControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, F
|
||||
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||
timestep=timestep,
|
||||
img_shapes=img_shapes,
|
||||
txt_seq_lens=txt_seq_lens,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
@@ -1,845 +0,0 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import List, Literal, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import PeftAdapterMixin
|
||||
from ...loaders.single_file_model import FromOriginalModelMixin
|
||||
from ...models.attention_processor import Attention
|
||||
from ...models.normalization import RMSNorm
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..controlnets.controlnet import zero_module
|
||||
from ..modeling_utils import ModelMixin
|
||||
|
||||
|
||||
ADALN_EMBED_DIM = 256
|
||||
SEQ_MULTI_OF = 32
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.TimestepEmbedder
|
||||
class TimestepEmbedder(nn.Module):
|
||||
def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
|
||||
super().__init__()
|
||||
if mid_size is None:
|
||||
mid_size = out_size
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(frequency_embedding_size, mid_size, bias=True),
|
||||
nn.SiLU(),
|
||||
nn.Linear(mid_size, out_size, bias=True),
|
||||
)
|
||||
|
||||
self.frequency_embedding_size = frequency_embedding_size
|
||||
|
||||
@staticmethod
|
||||
def timestep_embedding(t, dim, max_period=10000):
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half
|
||||
)
|
||||
args = t[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
|
||||
|
||||
def forward(self, t):
|
||||
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
|
||||
weight_dtype = self.mlp[0].weight.dtype
|
||||
compute_dtype = getattr(self.mlp[0], "compute_dtype", None)
|
||||
if weight_dtype.is_floating_point:
|
||||
t_freq = t_freq.to(weight_dtype)
|
||||
elif compute_dtype is not None:
|
||||
t_freq = t_freq.to(compute_dtype)
|
||||
t_emb = self.mlp(t_freq)
|
||||
return t_emb
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZSingleStreamAttnProcessor
|
||||
class ZSingleStreamAttnProcessor:
|
||||
"""
|
||||
Processor for Z-Image single stream attention that adapts the existing Attention class to match the behavior of the
|
||||
original Z-ImageAttention module.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"ZSingleStreamAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
freqs_cis: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(hidden_states)
|
||||
value = attn.to_v(hidden_states)
|
||||
|
||||
query = query.unflatten(-1, (attn.heads, -1))
|
||||
key = key.unflatten(-1, (attn.heads, -1))
|
||||
value = value.unflatten(-1, (attn.heads, -1))
|
||||
|
||||
# Apply Norms
|
||||
if attn.norm_q is not None:
|
||||
query = attn.norm_q(query)
|
||||
if attn.norm_k is not None:
|
||||
key = attn.norm_k(key)
|
||||
|
||||
# Apply RoPE
|
||||
def apply_rotary_emb(x_in: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
with torch.amp.autocast("cuda", enabled=False):
|
||||
x = torch.view_as_complex(x_in.float().reshape(*x_in.shape[:-1], -1, 2))
|
||||
freqs_cis = freqs_cis.unsqueeze(2)
|
||||
x_out = torch.view_as_real(x * freqs_cis).flatten(3)
|
||||
return x_out.type_as(x_in) # todo
|
||||
|
||||
if freqs_cis is not None:
|
||||
query = apply_rotary_emb(query, freqs_cis)
|
||||
key = apply_rotary_emb(key, freqs_cis)
|
||||
|
||||
# Cast to correct dtype
|
||||
dtype = query.dtype
|
||||
query, key = query.to(dtype), key.to(dtype)
|
||||
|
||||
# From [batch, seq_len] to [batch, 1, 1, seq_len] -> broadcast to [batch, heads, seq_len, seq_len]
|
||||
if attention_mask is not None and attention_mask.ndim == 2:
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Compute joint attention
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape back
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.to(dtype)
|
||||
|
||||
output = attn.to_out[0](hidden_states)
|
||||
if len(attn.to_out) > 1: # dropout
|
||||
output = attn.to_out[1](output)
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.FeedForward
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
||||
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
||||
|
||||
def _forward_silu_gating(self, x1, x3):
|
||||
return F.silu(x1) * x3
|
||||
|
||||
def forward(self, x):
|
||||
return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.select_per_token
|
||||
def select_per_token(
|
||||
value_noisy: torch.Tensor,
|
||||
value_clean: torch.Tensor,
|
||||
noise_mask: torch.Tensor,
|
||||
seq_len: int,
|
||||
) -> torch.Tensor:
|
||||
noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1)
|
||||
return torch.where(
|
||||
noise_mask_expanded == 1,
|
||||
value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
value_clean.unsqueeze(1).expand(-1, seq_len, -1),
|
||||
)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
|
||||
class ZImageTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // n_heads,
|
||||
heads=n_heads,
|
||||
qk_norm="rms_norm" if qk_norm else None,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=ZSingleStreamAttnProcessor(),
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
noise_mask: Optional[torch.Tensor] = None,
|
||||
adaln_noisy: Optional[torch.Tensor] = None,
|
||||
adaln_clean: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if self.modulation:
|
||||
seq_len = x.shape[1]
|
||||
|
||||
if noise_mask is not None:
|
||||
# Per-token modulation: different modulation for noisy/clean tokens
|
||||
mod_noisy = self.adaLN_modulation(adaln_noisy)
|
||||
mod_clean = self.adaLN_modulation(adaln_clean)
|
||||
|
||||
scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
|
||||
scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
|
||||
|
||||
gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
|
||||
gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
|
||||
|
||||
scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
|
||||
scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
|
||||
|
||||
scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
|
||||
scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
|
||||
gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
|
||||
gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
|
||||
else:
|
||||
# Global modulation: same modulation for all tokens (avoid double select)
|
||||
mod = self.adaLN_modulation(adaln_input)
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(x) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||
)
|
||||
x = x + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(x) * scale_mlp))
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(self.attention_norm1(x), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||
x = x + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.RopeEmbedder
|
||||
class RopeEmbedder:
|
||||
def __init__(
|
||||
self,
|
||||
theta: float = 256.0,
|
||||
axes_dims: List[int] = (16, 56, 56),
|
||||
axes_lens: List[int] = (64, 128, 128),
|
||||
):
|
||||
self.theta = theta
|
||||
self.axes_dims = axes_dims
|
||||
self.axes_lens = axes_lens
|
||||
assert len(axes_dims) == len(axes_lens), "axes_dims and axes_lens must have the same length"
|
||||
self.freqs_cis = None
|
||||
|
||||
@staticmethod
|
||||
def precompute_freqs_cis(dim: List[int], end: List[int], theta: float = 256.0):
|
||||
with torch.device("cpu"):
|
||||
freqs_cis = []
|
||||
for i, (d, e) in enumerate(zip(dim, end)):
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float64, device="cpu") / d))
|
||||
timestep = torch.arange(e, device=freqs.device, dtype=torch.float64)
|
||||
freqs = torch.outer(timestep, freqs).float()
|
||||
freqs_cis_i = torch.polar(torch.ones_like(freqs), freqs).to(torch.complex64) # complex64
|
||||
freqs_cis.append(freqs_cis_i)
|
||||
|
||||
return freqs_cis
|
||||
|
||||
def __call__(self, ids: torch.Tensor):
|
||||
assert ids.ndim == 2
|
||||
assert ids.shape[-1] == len(self.axes_dims)
|
||||
device = ids.device
|
||||
|
||||
if self.freqs_cis is None:
|
||||
self.freqs_cis = self.precompute_freqs_cis(self.axes_dims, self.axes_lens, theta=self.theta)
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
else:
|
||||
# Ensure freqs_cis are on the same device as ids
|
||||
if self.freqs_cis[0].device != device:
|
||||
self.freqs_cis = [freqs_cis.to(device) for freqs_cis in self.freqs_cis]
|
||||
|
||||
result = []
|
||||
for i in range(len(self.axes_dims)):
|
||||
index = ids[:, i]
|
||||
result.append(self.freqs_cis[i][index])
|
||||
return torch.cat(result, dim=-1)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class ZImageControlTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
layer_id: int,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
n_kv_heads: int,
|
||||
norm_eps: float,
|
||||
qk_norm: bool,
|
||||
modulation=True,
|
||||
block_id=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
# Refactored to use diffusers Attention with custom processor
|
||||
# Original Z-Image params: dim, n_heads, n_kv_heads, qk_norm
|
||||
self.attention = Attention(
|
||||
query_dim=dim,
|
||||
cross_attention_dim=None,
|
||||
dim_head=dim // n_heads,
|
||||
heads=n_heads,
|
||||
qk_norm="rms_norm" if qk_norm else None,
|
||||
eps=1e-5,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
processor=ZSingleStreamAttnProcessor(),
|
||||
)
|
||||
|
||||
self.feed_forward = FeedForward(dim=dim, hidden_dim=int(dim / 3 * 8))
|
||||
self.layer_id = layer_id
|
||||
|
||||
self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
|
||||
|
||||
self.modulation = modulation
|
||||
if modulation:
|
||||
self.adaLN_modulation = nn.Sequential(nn.Linear(min(dim, ADALN_EMBED_DIM), 4 * dim, bias=True))
|
||||
|
||||
# Control variant start
|
||||
self.block_id = block_id
|
||||
if block_id == 0:
|
||||
self.before_proj = zero_module(nn.Linear(self.dim, self.dim))
|
||||
self.after_proj = zero_module(nn.Linear(self.dim, self.dim))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
c: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
adaln_input: Optional[torch.Tensor] = None,
|
||||
):
|
||||
# Control
|
||||
if self.block_id == 0:
|
||||
c = self.before_proj(c) + x
|
||||
all_c = []
|
||||
else:
|
||||
all_c = list(torch.unbind(c))
|
||||
c = all_c.pop(-1)
|
||||
|
||||
# Compared to `ZImageTransformerBlock` x -> c
|
||||
if self.modulation:
|
||||
assert adaln_input is not None
|
||||
scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
|
||||
gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
|
||||
scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
|
||||
|
||||
# Attention block
|
||||
attn_out = self.attention(
|
||||
self.attention_norm1(c) * scale_msa, attention_mask=attn_mask, freqs_cis=freqs_cis
|
||||
)
|
||||
c = c + gate_msa * self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
c = c + gate_mlp * self.ffn_norm2(self.feed_forward(self.ffn_norm1(c) * scale_mlp))
|
||||
else:
|
||||
# Attention block
|
||||
attn_out = self.attention(self.attention_norm1(c), attention_mask=attn_mask, freqs_cis=freqs_cis)
|
||||
c = c + self.attention_norm2(attn_out)
|
||||
|
||||
# FFN block
|
||||
c = c + self.ffn_norm2(self.feed_forward(self.ffn_norm1(c)))
|
||||
|
||||
# Control
|
||||
c_skip = self.after_proj(c)
|
||||
all_c += [c_skip, c]
|
||||
c = torch.stack(all_c)
|
||||
return c
|
||||
|
||||
|
||||
class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
control_layers_places: List[int] = None,
|
||||
control_refiner_layers_places: List[int] = None,
|
||||
control_in_dim=None,
|
||||
add_control_noise_refiner: Optional[Literal["control_layers", "control_noise_refiner"]] = None,
|
||||
all_patch_size=(2,),
|
||||
all_f_patch_size=(1,),
|
||||
dim=3840,
|
||||
n_refiner_layers=2,
|
||||
n_heads=30,
|
||||
n_kv_heads=30,
|
||||
norm_eps=1e-5,
|
||||
qk_norm=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.control_layers_places = control_layers_places
|
||||
self.control_in_dim = control_in_dim
|
||||
self.control_refiner_layers_places = control_refiner_layers_places
|
||||
self.add_control_noise_refiner = add_control_noise_refiner
|
||||
|
||||
assert 0 in self.control_layers_places
|
||||
|
||||
# control blocks
|
||||
self.control_layers = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
|
||||
for i in self.control_layers_places
|
||||
]
|
||||
)
|
||||
|
||||
# control patch embeddings
|
||||
all_x_embedder = {}
|
||||
for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
|
||||
x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True)
|
||||
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
|
||||
|
||||
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
|
||||
if self.add_control_noise_refiner == "control_layers":
|
||||
self.control_noise_refiner = None
|
||||
elif self.add_control_noise_refiner == "control_noise_refiner":
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageControlTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
block_id=layer_id,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.control_noise_refiner = nn.ModuleList(
|
||||
[
|
||||
ZImageTransformerBlock(
|
||||
1000 + layer_id,
|
||||
dim,
|
||||
n_heads,
|
||||
n_kv_heads,
|
||||
norm_eps,
|
||||
qk_norm,
|
||||
modulation=True,
|
||||
)
|
||||
for layer_id in range(n_refiner_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.t_scale: Optional[float] = None
|
||||
self.t_embedder: Optional[TimestepEmbedder] = None
|
||||
self.all_x_embedder: Optional[nn.ModuleDict] = None
|
||||
self.cap_embedder: Optional[nn.Sequential] = None
|
||||
self.rope_embedder: Optional[RopeEmbedder] = None
|
||||
self.noise_refiner: Optional[nn.ModuleList] = None
|
||||
self.context_refiner: Optional[nn.ModuleList] = None
|
||||
self.x_pad_token: Optional[nn.Parameter] = None
|
||||
self.cap_pad_token: Optional[nn.Parameter] = None
|
||||
|
||||
@classmethod
|
||||
def from_transformer(cls, controlnet, transformer):
|
||||
controlnet.t_scale = transformer.t_scale
|
||||
controlnet.t_embedder = transformer.t_embedder
|
||||
controlnet.all_x_embedder = transformer.all_x_embedder
|
||||
controlnet.cap_embedder = transformer.cap_embedder
|
||||
controlnet.rope_embedder = transformer.rope_embedder
|
||||
controlnet.noise_refiner = transformer.noise_refiner
|
||||
controlnet.context_refiner = transformer.context_refiner
|
||||
controlnet.x_pad_token = transformer.x_pad_token
|
||||
controlnet.cap_pad_token = transformer.cap_pad_token
|
||||
return controlnet
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.create_coordinate_grid
|
||||
def create_coordinate_grid(size, start=None, device=None):
|
||||
if start is None:
|
||||
start = (0 for _ in size)
|
||||
axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
|
||||
grids = torch.meshgrid(axes, indexing="ij")
|
||||
return torch.stack(grids, dim=-1)
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
|
||||
def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
|
||||
"""Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
|
||||
pH, pW, pF = patch_size, patch_size, f_patch_size
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
|
||||
def _pad_with_ids(
|
||||
self,
|
||||
feat: torch.Tensor,
|
||||
pos_grid_size: Tuple,
|
||||
pos_start: Tuple,
|
||||
device: torch.device,
|
||||
noise_mask_val: Optional[int] = None,
|
||||
):
|
||||
"""Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
|
||||
ori_len = len(feat)
|
||||
pad_len = (-ori_len) % SEQ_MULTI_OF
|
||||
total_len = ori_len + pad_len
|
||||
|
||||
# Pos IDs
|
||||
ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
|
||||
if pad_len > 0:
|
||||
pad_pos_ids = (
|
||||
self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
|
||||
.flatten(0, 2)
|
||||
.repeat(pad_len, 1)
|
||||
)
|
||||
pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
|
||||
padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
|
||||
pad_mask = torch.cat(
|
||||
[
|
||||
torch.zeros(ori_len, dtype=torch.bool, device=device),
|
||||
torch.ones(pad_len, dtype=torch.bool, device=device),
|
||||
]
|
||||
)
|
||||
else:
|
||||
pos_ids = ori_pos_ids
|
||||
padded_feat = feat
|
||||
pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)
|
||||
|
||||
noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level
|
||||
return padded_feat, pos_ids, pad_mask, total_len, noise_mask
|
||||
|
||||
# Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
|
||||
def patchify_and_embed(
|
||||
self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
|
||||
):
|
||||
"""Patchify for basic mode: single image per batch item."""
|
||||
device = all_image[0].device
|
||||
all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
|
||||
all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []
|
||||
|
||||
for image, cap_feat in zip(all_image, all_cap_feats):
|
||||
# Caption
|
||||
cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
|
||||
cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
|
||||
)
|
||||
all_cap_out.append(cap_out)
|
||||
all_cap_pos_ids.append(cap_pos_ids)
|
||||
all_cap_pad_mask.append(cap_pad_mask)
|
||||
|
||||
# Image
|
||||
img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
|
||||
img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
|
||||
img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
|
||||
)
|
||||
all_img_out.append(img_out)
|
||||
all_img_size.append(size)
|
||||
all_img_pos_ids.append(img_pos_ids)
|
||||
all_img_pad_mask.append(img_pad_mask)
|
||||
|
||||
return (
|
||||
all_img_out,
|
||||
all_cap_out,
|
||||
all_img_size,
|
||||
all_img_pos_ids,
|
||||
all_cap_pos_ids,
|
||||
all_img_pad_mask,
|
||||
all_cap_pad_mask,
|
||||
)
|
||||
|
||||
def patchify(
|
||||
self,
|
||||
all_image: List[torch.Tensor],
|
||||
patch_size: int,
|
||||
f_patch_size: int,
|
||||
):
|
||||
pH = pW = patch_size
|
||||
pF = f_patch_size
|
||||
all_image_out = []
|
||||
|
||||
for i, image in enumerate(all_image):
|
||||
### Process Image
|
||||
C, F, H, W = image.size()
|
||||
F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
|
||||
|
||||
image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
|
||||
# "c f pf h ph w pw -> (f h w) (pf ph pw c)"
|
||||
image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
|
||||
|
||||
image_ori_len = len(image)
|
||||
image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
|
||||
|
||||
# padded feature
|
||||
image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
|
||||
all_image_out.append(image_padded_feat)
|
||||
|
||||
return all_image_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: List[torch.Tensor],
|
||||
t,
|
||||
cap_feats: List[torch.Tensor],
|
||||
control_context: List[torch.Tensor],
|
||||
conditioning_scale: float = 1.0,
|
||||
patch_size=2,
|
||||
f_patch_size=1,
|
||||
):
|
||||
if (
|
||||
self.t_scale is None
|
||||
or self.t_embedder is None
|
||||
or self.all_x_embedder is None
|
||||
or self.cap_embedder is None
|
||||
or self.rope_embedder is None
|
||||
or self.noise_refiner is None
|
||||
or self.context_refiner is None
|
||||
or self.x_pad_token is None
|
||||
or self.cap_pad_token is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Required modules are `None`, use `from_transformer` to share required modules from `transformer`."
|
||||
)
|
||||
|
||||
assert patch_size in self.config.all_patch_size
|
||||
assert f_patch_size in self.config.all_f_patch_size
|
||||
|
||||
bsz = len(x)
|
||||
device = x[0].device
|
||||
t = t * self.t_scale
|
||||
t = self.t_embedder(t)
|
||||
|
||||
(
|
||||
x,
|
||||
cap_feats,
|
||||
x_size,
|
||||
x_pos_ids,
|
||||
cap_pos_ids,
|
||||
x_inner_pad_mask,
|
||||
cap_inner_pad_mask,
|
||||
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
|
||||
|
||||
x_item_seqlens = [len(_) for _ in x]
|
||||
assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
|
||||
x_max_item_seqlen = max(x_item_seqlens)
|
||||
|
||||
control_context = self.patchify(control_context, patch_size, f_patch_size)
|
||||
control_context = torch.cat(control_context, dim=0)
|
||||
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context)
|
||||
|
||||
control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
control_context = list(control_context.split(x_item_seqlens, dim=0))
|
||||
|
||||
control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
|
||||
|
||||
# x embed & refine
|
||||
x = torch.cat(x, dim=0)
|
||||
x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
|
||||
|
||||
# Match t_embedder output dtype to x for layerwise casting compatibility
|
||||
adaln_input = t.type_as(x)
|
||||
x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
|
||||
x = list(x.split(x_item_seqlens, dim=0))
|
||||
x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
|
||||
|
||||
x = pad_sequence(x, batch_first=True, padding_value=0.0)
|
||||
x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
|
||||
|
||||
x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(x_item_seqlens):
|
||||
x_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if self.add_control_noise_refiner is not None:
|
||||
if self.add_control_noise_refiner == "control_layers":
|
||||
layers = self.control_layers
|
||||
elif self.add_control_noise_refiner == "control_noise_refiner":
|
||||
layers = self.control_noise_refiner
|
||||
else:
|
||||
raise ValueError(f"Unsupported `add_control_noise_refiner` type: {self.add_control_noise_refiner}.")
|
||||
for layer in layers:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
control_context = self._gradient_checkpointing_func(
|
||||
layer, control_context, x, x_attn_mask, x_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
control_context = layer(control_context, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
|
||||
hints = torch.unbind(control_context)[:-1]
|
||||
control_context = torch.unbind(control_context)[-1]
|
||||
noise_refiner_block_samples = {
|
||||
layer_idx: hints[idx] * conditioning_scale
|
||||
for idx, layer_idx in enumerate(self.control_refiner_layers_places)
|
||||
}
|
||||
else:
|
||||
noise_refiner_block_samples = None
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer_idx, layer in enumerate(self.noise_refiner):
|
||||
x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
if noise_refiner_block_samples is not None:
|
||||
if layer_idx in noise_refiner_block_samples:
|
||||
x = x + noise_refiner_block_samples[layer_idx]
|
||||
else:
|
||||
for layer_idx, layer in enumerate(self.noise_refiner):
|
||||
x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
if noise_refiner_block_samples is not None:
|
||||
if layer_idx in noise_refiner_block_samples:
|
||||
x = x + noise_refiner_block_samples[layer_idx]
|
||||
|
||||
# cap embed & refine
|
||||
cap_item_seqlens = [len(_) for _ in cap_feats]
|
||||
cap_max_item_seqlen = max(cap_item_seqlens)
|
||||
|
||||
cap_feats = torch.cat(cap_feats, dim=0)
|
||||
cap_feats = self.cap_embedder(cap_feats)
|
||||
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
|
||||
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
|
||||
cap_freqs_cis = list(
|
||||
self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
|
||||
)
|
||||
|
||||
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
|
||||
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
|
||||
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
|
||||
|
||||
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(cap_item_seqlens):
|
||||
cap_attn_mask[i, :seq_len] = 1
|
||||
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
else:
|
||||
for layer in self.context_refiner:
|
||||
cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
|
||||
|
||||
# unified
|
||||
unified = []
|
||||
unified_freqs_cis = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
|
||||
unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
|
||||
assert unified_item_seqlens == [len(_) for _ in unified]
|
||||
unified_max_item_seqlen = max(unified_item_seqlens)
|
||||
|
||||
unified = pad_sequence(unified, batch_first=True, padding_value=0.0)
|
||||
unified_freqs_cis = pad_sequence(unified_freqs_cis, batch_first=True, padding_value=0.0)
|
||||
unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
|
||||
for i, seq_len in enumerate(unified_item_seqlens):
|
||||
unified_attn_mask[i, :seq_len] = 1
|
||||
|
||||
## ControlNet start
|
||||
if not self.add_control_noise_refiner:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = self._gradient_checkpointing_func(
|
||||
layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
for layer in self.control_noise_refiner:
|
||||
control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)
|
||||
|
||||
# unified
|
||||
control_context_unified = []
|
||||
for i in range(bsz):
|
||||
x_len = x_item_seqlens[i]
|
||||
cap_len = cap_item_seqlens[i]
|
||||
control_context_unified.append(torch.cat([control_context[i][:x_len], cap_feats[i][:cap_len]]))
|
||||
control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
|
||||
|
||||
for layer in self.control_layers:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
control_context_unified = self._gradient_checkpointing_func(
|
||||
layer, control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
else:
|
||||
control_context_unified = layer(
|
||||
control_context_unified, unified, unified_attn_mask, unified_freqs_cis, adaln_input
|
||||
)
|
||||
|
||||
hints = torch.unbind(control_context_unified)[:-1]
|
||||
controlnet_block_samples = {
|
||||
layer_idx: hints[idx] * conditioning_scale for idx, layer_idx in enumerate(self.control_layers_places)
|
||||
}
|
||||
return controlnet_block_samples
|
||||
@@ -47,7 +47,6 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.distributed_utils import is_torch_dist_rank_zero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -355,9 +354,8 @@ def _load_shard_file(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries, disable_mmap=disable_mmap)
|
||||
state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
@@ -403,7 +401,6 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=None,
|
||||
ignore_mismatched_sizes=False,
|
||||
low_cpu_mem_usage=False,
|
||||
disable_mmap=False,
|
||||
):
|
||||
# Do not spawn anymore workers than you need
|
||||
num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
|
||||
@@ -430,15 +427,10 @@ def _load_shard_files_with_threadpool(
|
||||
state_dict_folder=state_dict_folder,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
disable_mmap=disable_mmap,
|
||||
)
|
||||
|
||||
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
|
||||
if not is_torch_dist_rank_zero():
|
||||
tqdm_kwargs["disable"] = True
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
with logging.tqdm(**tqdm_kwargs) as pbar:
|
||||
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
|
||||
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user