Mirror of https://github.com/huggingface/diffusers.git (synced 2026-02-04 18:05:17 +08:00)

Compare commits: paulinebm- ... modular-wa

42 Commits
| SHA1 |
|---|
| ce58446a36 |
| 3833ca425f |
| 02c7adc356 |
| a3cc0e7a52 |
| 2a6cdc0b3e |
| 1791306739 |
| df6516a716 |
| 5794ffffbe |
| 4fb44bdf91 |
| b7a81582ae |
| 4b64b5603f |
| 2bb640f8ea |
| 2dc9d2af50 |
| 57e57cfae0 |
| 644169433f |
| 632765a5ee |
| d36564f06a |
| 441b69eabf |
| d568c9773f |
| 3981c955ce |
| 1903383e94 |
| 08f8b7af9a |
| 2f66edc880 |
| be38f41f9f |
| 91e5134175 |
| a812c87465 |
| 8b9f817ef5 |
| b1f06b780a |
| 8600b4c10d |
| c10bdd9b73 |
| dab000e88b |
| 9fb6b89d49 |
| 6fb4c99f5a |
| 961b9b27d3 |
| 8f30bfff1f |
| b4be29bda2 |
| 98479a94c2 |
| ade1059ae2 |
| 41a6e86faf |
| 9b5a244653 |
| 417f6b2d33 |
| e46354d2d0 |
.github/workflows/benchmark.yml (vendored, 4)

@@ -28,7 +28,7 @@ jobs:
      options: --shm-size "16gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: NVIDIA-SMI
@@ -58,7 +58,7 @@ jobs:
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: benchmark_test_reports
        path: benchmarks/${{ env.BASE_PATH }}
.github/workflows/build_docker_images.yml (vendored, 4)

@@ -28,7 +28,7 @@ jobs:
      uses: docker/setup-buildx-action@v1

    - name: Check out code
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6

    - name: Find Changed Dockerfiles
      id: file_changes
@@ -99,7 +99,7 @@ jobs:

    steps:
    - name: Checkout repository
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v1
    - name: Login to Docker Hub
.github/workflows/build_pr_documentation.yml (vendored, 4)

@@ -17,10 +17,10 @@ jobs:
    steps:
    - name: Checkout repository
-     uses: actions/checkout@v4
+     uses: actions/checkout@v6

    - name: Set up Python
-     uses: actions/setup-python@v5
+     uses: actions/setup-python@v6
      with:
        python-version: '3.10'

@@ -38,6 +38,7 @@ jobs:
    # If ref is 'refs/heads/main' => set 'main'
    # Else it must be a tag => set {tag}
    - name: Set checkout_ref and path_in_repo
      env:
        EVENT_NAME: ${{ github.event_name }}
        EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
        GITHUB_REF: ${{ github.ref }}
@@ -65,13 +66,13 @@ jobs:
      run: |
        echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
        echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
      with:
        ref: ${{ env.CHECKOUT_REF }}

    # Setup + install dependencies
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.10"
    - name: Install dependencies
.github/workflows/nightly_tests.yml (vendored, 46)

@@ -28,7 +28,7 @@ jobs:
      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: Install dependencies
@@ -44,7 +44,7 @@ jobs:

    - name: Pipeline Tests Artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: test-pipelines.json
        path: reports
@@ -64,7 +64,7 @@ jobs:
      options: --shm-size "16gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: NVIDIA-SMI
@@ -97,7 +97,7 @@ jobs:
        cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pipeline_${{ matrix.module }}_test_reports
        path: reports
@@ -119,7 +119,7 @@ jobs:
        module: [models, schedulers, lora, others, single_file, examples]
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -167,7 +167,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_${{ matrix.module }}_cuda_test_reports
        path: reports
@@ -184,7 +184,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -211,7 +211,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_compile_test_reports
        path: reports
@@ -228,7 +228,7 @@ jobs:
      options: --shm-size "16gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: NVIDIA-SMI
@@ -263,7 +263,7 @@ jobs:
        cat reports/tests_big_gpu_torch_cuda_failures_short.txt
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_cuda_big_gpu_test_reports
        path: reports
@@ -280,7 +280,7 @@ jobs:
      shell: bash
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -321,7 +321,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_minimum_version_cuda_test_reports
        path: reports
@@ -355,7 +355,7 @@ jobs:
      options: --shm-size "20gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: NVIDIA-SMI
@@ -391,7 +391,7 @@ jobs:
        cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_cuda_${{ matrix.config.backend }}_reports
        path: reports
@@ -408,7 +408,7 @@ jobs:
      options: --shm-size "20gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: NVIDIA-SMI
@@ -441,7 +441,7 @@ jobs:
        cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_cuda_pipeline_level_quant_reports
        path: reports
@@ -466,7 +466,7 @@ jobs:
      image: diffusers/diffusers-pytorch-cpu
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -474,7 +474,7 @@ jobs:
      run: mkdir -p combined_reports

    - name: Download all test reports
-     uses: actions/download-artifact@v4
+     uses: actions/download-artifact@v7
      with:
        path: artifacts

@@ -500,7 +500,7 @@ jobs:
        cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY

    - name: Upload consolidated report
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: consolidated_test_report
        path: ${{ env.CONSOLIDATED_REPORT_PATH }}
@@ -514,7 +514,7 @@ jobs:
    #
    # steps:
    # - name: Checkout diffusers
-   #   uses: actions/checkout@v3
+   #   uses: actions/checkout@v6
    #   with:
    #     fetch-depth: 2
    #
@@ -554,7 +554,7 @@ jobs:
    #
    # - name: Test suite reports artifacts
    #   if: ${{ always() }}
-   #   uses: actions/upload-artifact@v4
+   #   uses: actions/upload-artifact@v6
    #   with:
    #     name: torch_mps_test_reports
    #     path: reports
@@ -570,7 +570,7 @@ jobs:
    #
    # steps:
    # - name: Checkout diffusers
-   #   uses: actions/checkout@v3
+   #   uses: actions/checkout@v6
    #   with:
    #     fetch-depth: 2
    #
@@ -610,7 +610,7 @@ jobs:
    #
    # - name: Test suite reports artifacts
    #   if: ${{ always() }}
-   #   uses: actions/upload-artifact@v4
+   #   uses: actions/upload-artifact@v6
    #   with:
    #     name: torch_mps_test_reports
    #     path: reports
@@ -10,10 +10,10 @@ jobs:
    runs-on: ubuntu-22.04

    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6

    - name: Setup Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: '3.8'
.github/workflows/pr_dependency_test.yml (vendored, 4)

@@ -18,9 +18,9 @@ jobs:
  check_dependencies:
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.8"
    - name: Install dependencies
.github/workflows/pr_modular_tests.yml (vendored, 38)

@@ -1,3 +1,4 @@
name: Fast PR tests for Modular

on:
@@ -35,9 +36,9 @@ jobs:
  check_code_quality:
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.10"
    - name: Install dependencies
@@ -55,9 +56,9 @@ jobs:
    needs: check_code_quality
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.10"
    - name: Install dependencies
@@ -77,23 +78,13 @@ jobs:

  run_fast_tests:
    needs: [check_code_quality, check_repository_consistency]
-   strategy:
-     fail-fast: false
-     matrix:
-       config:
-         - name: Fast PyTorch Modular Pipeline CPU tests
-           framework: pytorch_pipelines
-           runner: aws-highmemory-32-plus
-           image: diffusers/diffusers-pytorch-cpu
-           report: torch_cpu_modular_pipelines

-   name: ${{ matrix.config.name }}
+   name: Fast PyTorch Modular Pipeline CPU tests

    runs-on:
-     group: ${{ matrix.config.runner }}
+     group: aws-highmemory-32-plus

    container:
-     image: ${{ matrix.config.image }}
+     image: diffusers/diffusers-pytorch-cpu
      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/

    defaults:
@@ -102,7 +93,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -118,22 +109,19 @@ jobs:
        python utils/print_env.py

    - name: Run fast PyTorch Pipeline CPU tests
-     if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
      run: |
        pytest -n 8 --max-worker-restart=0 --dist=loadfile \
          -k "not Flax and not Onnx" \
-         --make-reports=tests_${{ matrix.config.report }} \
+         --make-reports=tests_torch_cpu_modular_pipelines \
          tests/modular_pipelines

    - name: Failure short reports
      if: ${{ failure() }}
-     run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
+     run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
-       name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
+       name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
        path: reports
.github/workflows/pr_test_fetcher.yml (vendored, 12)

@@ -28,7 +28,7 @@ jobs:
      test_map: ${{ steps.set_matrix.outputs.test_map }}
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 0
    - name: Install dependencies
@@ -42,7 +42,7 @@ jobs:
      run: |
        python utils/tests_fetcher.py | tee test_preparation.txt
    - name: Report fetched tests
-     uses: actions/upload-artifact@v3
+     uses: actions/upload-artifact@v6
      with:
        name: test_fetched
        path: test_preparation.txt
@@ -83,7 +83,7 @@ jobs:
      shell: bash
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -109,7 +109,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v3
+     uses: actions/upload-artifact@v6
      with:
        name: ${{ matrix.modules }}_test_reports
        path: reports
@@ -138,7 +138,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -164,7 +164,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pr_${{ matrix.config.report }}_test_reports
        path: reports
.github/workflows/pr_tests.yml (vendored, 20)

@@ -31,9 +31,9 @@ jobs:
  check_code_quality:
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.8"
    - name: Install dependencies
@@ -51,9 +51,9 @@ jobs:
    needs: check_code_quality
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.8"
    - name: Install dependencies
@@ -108,7 +108,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -153,7 +153,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
        path: reports
@@ -185,7 +185,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -211,7 +211,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pr_${{ matrix.config.report }}_test_reports
        path: reports
@@ -236,7 +236,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -273,7 +273,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pr_main_test_reports
        path: reports
.github/workflows/pr_tests_gpu.yml (vendored, 24)

@@ -32,9 +32,9 @@ jobs:
  check_code_quality:
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.8"
    - name: Install dependencies
@@ -52,9 +52,9 @@ jobs:
    needs: check_code_quality
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.8"
    - name: Install dependencies
@@ -83,7 +83,7 @@ jobs:
      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: Install dependencies
@@ -100,7 +100,7 @@ jobs:
        echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
    - name: Pipeline Tests Artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: test-pipelines.json
        path: reports
@@ -120,7 +120,7 @@ jobs:
      options: --shm-size "16gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -170,7 +170,7 @@ jobs:
        cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pipeline_${{ matrix.module }}_test_reports
        path: reports
@@ -193,7 +193,7 @@ jobs:
        module: [models, schedulers, lora, others]
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -239,7 +239,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_cuda_test_reports_${{ matrix.module }}
        path: reports
@@ -255,7 +255,7 @@ jobs:
      options: --gpus all --shm-size "16gb" --ipc host
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -287,7 +287,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: examples_test_reports
        path: reports
@@ -18,9 +18,9 @@ jobs:
  check_torch_dependencies:
    runs-on: ubuntu-22.04
    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6
    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.8"
    - name: Install dependencies
.github/workflows/push_tests.yml (vendored, 24)

@@ -29,7 +29,7 @@ jobs:
      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: Install dependencies
@@ -46,7 +46,7 @@ jobs:
        echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
    - name: Pipeline Tests Artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: test-pipelines.json
        path: reports
@@ -66,7 +66,7 @@ jobs:
      options: --shm-size "16gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: NVIDIA-SMI
@@ -98,7 +98,7 @@ jobs:
        cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pipeline_${{ matrix.module }}_test_reports
        path: reports
@@ -120,7 +120,7 @@ jobs:
        module: [models, schedulers, lora, others, single_file]
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -155,7 +155,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_cuda_test_reports_${{ matrix.module }}
        path: reports
@@ -172,7 +172,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -199,7 +199,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_compile_test_reports
        path: reports
@@ -216,7 +216,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -240,7 +240,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_xformers_test_reports
        path: reports
@@ -256,7 +256,7 @@ jobs:
      options: --gpus all --shm-size "16gb" --ipc host
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -286,7 +286,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: examples_test_reports
        path: reports
.github/workflows/push_tests_fast.yml (vendored, 4)

@@ -54,7 +54,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -88,7 +88,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pr_${{ matrix.config.report }}_test_reports
        path: reports
.github/workflows/push_tests_mps.yml (vendored, 4)

@@ -23,7 +23,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -65,7 +65,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pr_torch_mps_test_reports
        path: reports
.github/workflows/pypi_publish.yaml (vendored, 8)

@@ -15,10 +15,10 @@ jobs:
      latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
    steps:
    - name: Checkout Repo
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6

    - name: Set up Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: '3.8'

@@ -40,12 +40,12 @@ jobs:

    steps:
    - name: Checkout Repo
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}

    - name: Setup Python
-     uses: actions/setup-python@v4
+     uses: actions/setup-python@v6
      with:
        python-version: "3.8"
.github/workflows/release_tests_fast.yml (vendored, 28)

@@ -27,7 +27,7 @@ jobs:
      pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: Install dependencies
@@ -44,7 +44,7 @@ jobs:
        echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
    - name: Pipeline Tests Artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: test-pipelines.json
        path: reports
@@ -64,7 +64,7 @@ jobs:
      options: --shm-size "16gb" --ipc host --gpus all
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
    - name: NVIDIA-SMI
@@ -94,7 +94,7 @@ jobs:
        cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: pipeline_${{ matrix.module }}_test_reports
        path: reports
@@ -116,7 +116,7 @@ jobs:
        module: [models, schedulers, lora, others, single_file]
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -149,7 +149,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_cuda_${{ matrix.module }}_test_reports
        path: reports
@@ -166,7 +166,7 @@ jobs:
      shell: bash
    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -205,7 +205,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_minimum_version_cuda_test_reports
        path: reports
@@ -222,7 +222,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -247,7 +247,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_compile_test_reports
        path: reports
@@ -264,7 +264,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -288,7 +288,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: torch_xformers_test_reports
        path: reports
@@ -305,7 +305,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2

@@ -336,7 +336,7 @@ jobs:

    - name: Test suite reports artifacts
      if: ${{ always() }}
-     uses: actions/upload-artifact@v4
+     uses: actions/upload-artifact@v6
      with:
        name: examples_test_reports
        path: reports
.github/workflows/run_tests_from_a_pr.yml (vendored, 2)

@@ -57,7 +57,7 @@ jobs:
      shell: bash -e {0}

    - name: Checkout PR branch
-     uses: actions/checkout@v4
+     uses: actions/checkout@v6
      with:
        ref: refs/pull/${{ inputs.pr_number }}/head
.github/workflows/ssh-pr-runner.yml (vendored, 2)

@@ -27,7 +27,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
.github/workflows/ssh-runner.yml (vendored, 2)

@@ -35,7 +35,7 @@ jobs:

    steps:
    - name: Checkout diffusers
-     uses: actions/checkout@v3
+     uses: actions/checkout@v6
      with:
        fetch-depth: 2
.github/workflows/stale.yml (vendored, 4)

@@ -15,10 +15,10 @@ jobs:
    env:
      GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
    steps:
-   - uses: actions/checkout@v2
+   - uses: actions/checkout@v6

    - name: Setup Python
-     uses: actions/setup-python@v1
+     uses: actions/setup-python@v6
      with:
        python-version: 3.8
.github/workflows/trufflehog.yml (vendored, 2)

@@ -8,7 +8,7 @@ jobs:
    runs-on: ubuntu-22.04
    steps:
    - name: Checkout code
-     uses: actions/checkout@v4
+     uses: actions/checkout@v6
      with:
        fetch-depth: 0
    - name: Secret Scanning
.github/workflows/typos.yml (vendored, 2)

@@ -8,7 +8,7 @@ jobs:
    runs-on: ubuntu-22.04

    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6

    - name: typos-action
      uses: crate-ci/typos@v1.12.4
.github/workflows/update_metadata.yml (vendored, 2)

@@ -15,7 +15,7 @@ jobs:
      shell: bash -l {0}

    steps:
-   - uses: actions/checkout@v3
+   - uses: actions/checkout@v6

    - name: Setup environment
      run: |
@@ -54,6 +54,8 @@
      title: Batch inference
    - local: training/distributed_inference
      title: Distributed inference
+   - local: hybrid_inference/overview
+     title: Remote inference
    title: Inference
  - isExpanded: false
    sections:
@@ -88,17 +90,6 @@
      title: FreeU
    title: Community optimizations
    title: Inference optimization
- - isExpanded: false
-   sections:
-   - local: hybrid_inference/overview
-     title: Overview
-   - local: hybrid_inference/vae_decode
-     title: VAE Decode
-   - local: hybrid_inference/vae_encode
-     title: VAE Encode
-   - local: hybrid_inference/api_reference
-     title: API Reference
-   title: Hybrid Inference
  - isExpanded: false
    sections:
    - local: modular_diffusers/overview
@@ -270,6 +261,8 @@
      title: Outputs
    - local: api/quantization
      title: Quantization
+   - local: hybrid_inference/api_reference
+     title: Remote inference
    - local: api/parallel
      title: Parallel inference
    title: Main Classes
@@ -367,6 +360,8 @@
      title: LatteTransformer3DModel
    - local: api/models/longcat_image_transformer2d
      title: LongCatImageTransformer2DModel
+   - local: api/models/ltx2_video_transformer3d
+     title: LTX2VideoTransformer3DModel
    - local: api/models/ltx_video_transformer3d
      title: LTXVideoTransformer3DModel
    - local: api/models/lumina2_transformer2d
@@ -443,6 +438,10 @@
      title: AutoencoderKLHunyuanVideo
    - local: api/models/autoencoder_kl_hunyuan_video15
      title: AutoencoderKLHunyuanVideo15
+   - local: api/models/autoencoderkl_audio_ltx_2
+     title: AutoencoderKLLTX2Audio
+   - local: api/models/autoencoderkl_ltx_2
+     title: AutoencoderKLLTX2Video
    - local: api/models/autoencoderkl_ltx_video
      title: AutoencoderKLLTXVideo
    - local: api/models/autoencoderkl_magvit
@@ -678,6 +677,8 @@
      title: Kandinsky 5.0 Video
    - local: api/pipelines/latte
      title: Latte
+   - local: api/pipelines/ltx2
+     title: LTX-2
    - local: api/pipelines/ltx_video
      title: LTXVideo
    - local: api/pipelines/mochi
docs/source/en/api/models/autoencoderkl_audio_ltx_2.md (new file, 29 lines)

<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTX2Audio

The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. It is used for encoding and decoding audio latent representations.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLLTX2Audio

vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLLTX2Audio

[[autodoc]] AutoencoderKLLTX2Audio
- encode
- decode
- all
docs/source/en/api/models/autoencoderkl_ltx_2.md (new file, 29 lines)

<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# AutoencoderKLLTX2Video

The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import AutoencoderKLLTX2Video

vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
```

## AutoencoderKLLTX2Video

[[autodoc]] AutoencoderKLLTX2Video
- decode
- encode
- all
@@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co

## FluxControlNetOutput

-[[autodoc]] models.controlnet_flux.FluxControlNetOutput
+[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput
@@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr

## SparseControlNetOutput

-[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput
+[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput
docs/source/en/api/models/ltx2_video_transformer3d.md (new file, 26 lines)

<!-- Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->

# LTX2VideoTransformer3DModel

A Diffusion Transformer model for 3D data from [LTX-2](https://huggingface.co/Lightricks/LTX-2), introduced by Lightricks.

The model can be loaded with the following code snippet.

```python
import torch
from diffusers import LTX2VideoTransformer3DModel

transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
```

## LTX2VideoTransformer3DModel

[[autodoc]] LTX2VideoTransformer3DModel
@@ -30,6 +30,10 @@

The ChronoEdit pipeline is developed by the ChronoEdit Team. The original code is available on [GitHub](https://github.com/nv-tlabs/ChronoEdit), and pretrained models can be found in the [nvidia/ChronoEdit](https://huggingface.co/collections/nvidia/chronoedit) collection on Hugging Face.

+Available Models/LoRAs:
+- [nvidia/ChronoEdit-14B-Diffusers](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers)
+- [nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Upscaler-Lora)
+- [nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora](https://huggingface.co/nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora)

### Image Editing

@@ -100,6 +104,7 @@ Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.pn
import torch
import numpy as np
from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
from diffusers.schedulers import UniPCMultistepScheduler
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel
from PIL import Image
@@ -109,9 +114,8 @@ image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encod
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
-lora_path = hf_hub_download(repo_id=model_id, filename="lora/chronoedit_distill_lora.safetensors")
-pipe.load_lora_weights(lora_path)
-pipe.fuse_lora(lora_scale=1.0)
+pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
+pipe.fuse_lora(adapter_names=["distill"], lora_scale=1.0)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
pipe.to("cuda")

@@ -145,6 +149,57 @@ export_to_video(output, "output.mp4", fps=16)
Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output.png")
```

+### Inference with Multiple LoRAs
+
+```py
+import torch
+import numpy as np
+from diffusers import AutoencoderKLWan, ChronoEditTransformer3DModel, ChronoEditPipeline
+from diffusers.schedulers import UniPCMultistepScheduler
+from diffusers.utils import export_to_video, load_image
+from transformers import CLIPVisionModel
+from PIL import Image
+
+model_id = "nvidia/ChronoEdit-14B-Diffusers"
+image_encoder = CLIPVisionModel.from_pretrained(model_id, subfolder="image_encoder", torch_dtype=torch.float32)
+vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
+transformer = ChronoEditTransformer3DModel.from_pretrained(model_id, subfolder="transformer", torch_dtype=torch.bfloat16)
+pipe = ChronoEditPipeline.from_pretrained(model_id, image_encoder=image_encoder, transformer=transformer, vae=vae, torch_dtype=torch.bfloat16)
+pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora", weight_name="paintbrush_lora_diffusers.safetensors", adapter_name="paintbrush")
+pipe.load_lora_weights("nvidia/ChronoEdit-14B-Diffusers", weight_name="lora/chronoedit_distill_lora.safetensors", adapter_name="distill")
+pipe.fuse_lora(adapter_names=["paintbrush", "distill"], lora_scale=1.0)
+pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=2.0)
+pipe.to("cuda")
+
+image = load_image(
+    "https://raw.githubusercontent.com/nv-tlabs/ChronoEdit/refs/heads/main/assets/images/input_paintbrush.png"
+)
+max_area = 720 * 1280
+aspect_ratio = image.height / image.width
+mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
+height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
+width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
+print("width", width, "height", height)
+image = image.resize((width, height))
+prompt = (
+    "Turn the pencil sketch in the image into an actual object that is consistent with the image’s content. The user wants to change the sketch to a crown and a hat."
+)
+
+output = pipe(
+    image=image,
+    prompt=prompt,
+    height=height,
+    width=width,
+    num_frames=5,
+    num_inference_steps=8,
+    guidance_scale=1.0,
+    enable_temporal_reasoning=False,
+    num_temporal_reasoning_steps=0,
+).frames[0]
+export_to_video(output, "output.mp4", fps=16)
+Image.fromarray((output[-1] * 255).clip(0, 255).astype("uint8")).save("output_1.png")
+```

## ChronoEditPipeline

[[autodoc]] ChronoEditPipeline
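As a follow-up to the multi-LoRA example above, the sketch below shows how the same two adapters can be kept separate and toggled at runtime instead of being fused into the base weights. It is a minimal sketch that assumes the ChronoEdit pipeline exposes the standard diffusers PEFT adapter helpers (`set_adapters`, `disable_lora`, `enable_lora`), which is the case for pipelines that support `load_lora_weights`; it is not part of the official ChronoEdit docs.

```py
# Sketch: switch LoRA adapters at runtime instead of calling fuse_lora().
# Assumes `pipe` was built as in the multi-LoRA example above.
pipe.load_lora_weights(
    "nvidia/ChronoEdit-14B-Diffusers-Paint-Brush-Lora",
    weight_name="paintbrush_lora_diffusers.safetensors",
    adapter_name="paintbrush",
)
pipe.load_lora_weights(
    "nvidia/ChronoEdit-14B-Diffusers",
    weight_name="lora/chronoedit_distill_lora.safetensors",
    adapter_name="distill",
)

# Activate both adapters with individual weights; adjust the weights to blend their effects.
pipe.set_adapters(["paintbrush", "distill"], adapter_weights=[1.0, 1.0])

# Temporarily run the base model without any LoRA, then re-enable them.
pipe.disable_lora()
pipe.enable_lora()
```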
docs/source/en/api/pipelines/ltx2.md (new file, 43 lines)

<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# LTX-2

LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.

You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.

The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).

## LTX2Pipeline

[[autodoc]] LTX2Pipeline
- all
- __call__

## LTX2ImageToVideoPipeline

[[autodoc]] LTX2ImageToVideoPipeline
- all
- __call__

## LTX2LatentUpsamplePipeline

[[autodoc]] LTX2LatentUpsamplePipeline
- all
- __call__

## LTX2PipelineOutput

[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput
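The new page above only lists the autodoc entries. For orientation, a text-to-video call would look roughly like the sketch below; it is a hypothetical illustration that assumes `LTX2Pipeline` follows the usual diffusers video-pipeline pattern (`from_pretrained` on `Lightricks/LTX-2`, a prompt-driven call returning `.frames`), so parameter names such as `num_frames` and `num_inference_steps` are assumptions rather than the documented API.

```python
# Hypothetical usage sketch; the exact LTX2Pipeline signature may differ.
import torch
from diffusers import LTX2Pipeline
from diffusers.utils import export_to_video

pipeline = LTX2Pipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16).to("cuda")

output = pipeline(
    prompt="A drone shot over a foggy coastline at sunrise",
    num_frames=97,           # assumed parameter name, following other diffusers video pipelines
    num_inference_steps=40,  # assumed parameter name
)
export_to_video(output.frames[0], "ltx2.mp4", fps=24)
```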
@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
-- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
+- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.

- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.

@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)

<details>
<summary>Show example code</summary>

```python
import torch
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
@@ -474,6 +474,12 @@ export_to_video(video, "output.mp4", fps=24)

</details>

+## LTXI2VLongMultiPromptPipeline
+
+[[autodoc]] LTXI2VLongMultiPromptPipeline
+- all
+- __call__
+
## LTXPipeline

[[autodoc]] LTXPipeline
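To make the guidance note in the first hunk above concrete, the short sketch below loads the base LTX-Video checkpoint and sets `guidance_scale`; the checkpoint id and the `1.0`/`5.0` values come from the text above, while the prompt and step count are illustrative only.

```python
# Illustrative sketch of the guidance_scale recommendation above.
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipeline = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16).to("cuda")

video = pipeline(
    prompt="A river winding through a pine forest, aerial view",
    guidance_scale=5.0,  # use 1.0 instead for guidance-distilled variants
    num_inference_steps=50,
).frames[0]
export_to_video(video, "ltx.mp4", fps=24)
```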
@@ -95,7 +95,7 @@ image.save("qwen_fewsteps.png")

With [`QwenImageEditPlusPipeline`], one can provide multiple images as input reference.

-```
+```py
import torch
from PIL import Image
from diffusers import QwenImageEditPlusPipeline
@@ -37,7 +37,8 @@ The following SkyReels-V2 models are supported in Diffusers:
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
+- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)

This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).

> [!TIP]
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
@@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p

The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.

</hfoption>
</hfoptions>

### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication

[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
@@ -1,9 +1,11 @@
-# Hybrid Inference API Reference
+# Remote inference

-## Remote Decode
+Remote inference provides access to an [Inference Endpoint](https://huggingface.co/docs/inference-endpoints/index) to offload local generation requirements for decoding and encoding.
+
+## remote_decode

[[autodoc]] utils.remote_utils.remote_decode

-## Remote Encode
+## remote_encode

[[autodoc]] utils.remote_utils.remote_encode
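For orientation, a round trip through the two helpers documented above could look like the sketch below. The SDXL endpoint URL and the `0.13025` scaling factor are the values used elsewhere on this page; the rest (the input image and output filename) is illustrative only.

```py
# Illustrative encode/decode round trip with the documented helpers.
from diffusers.utils import load_image
from diffusers.utils.remote_utils import remote_decode, remote_encode

endpoint = "https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud/"  # SDXL VAE endpoint from the table below

image = load_image(
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
)

latent = remote_encode(endpoint=endpoint, image=image, scaling_factor=0.13025)
decoded = remote_decode(endpoint=endpoint, tensor=latent, scaling_factor=0.13025)
decoded.save("roundtrip.jpg")
```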
@@ -10,51 +10,296 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Hybrid Inference
|
||||
|
||||
**Empowering local AI builders with Hybrid Inference**
|
||||
|
||||
# Remote inference
|
||||
|
||||
> [!TIP]
|
||||
> Hybrid Inference is an [experimental feature](https://huggingface.co/blog/remote_vae).
|
||||
> Feedback can be provided [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
> This is currently an experimental feature, and if you have any feedback, please feel free to leave it [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
Remote inference offloads the decoding and encoding process to a remote endpoint to relax the memory requirements for local inference with large models. This feature is powered by [Inference Endpoints](https://huggingface.co/docs/inference-endpoints/index). Refer to the table below for the supported models and endpoint.
|
||||
|
||||
| Model | Endpoint | Checkpoint | Support |
|
||||
|---|---|---|---|
|
||||
| Stable Diffusion v1 | https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud | [stabilityai/sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse) | encode/decode |
|
||||
| Stable Diffusion XL | https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud | [madebyollin/sdxl-vae-fp16-fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix) | encode/decode |
|
||||
| Flux | https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud | [black-forest-labs/FLUX.1-schnell](https://huggingface.co/black-forest-labs/FLUX.1-schnell) | encode/decode |
|
||||
| HunyuanVideo | https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud | [hunyuanvideo-community/HunyuanVideo](https://huggingface.co/hunyuanvideo-community/HunyuanVideo) | decode |
|
||||
|
||||
This guide will show you how to encode and decode latents with remote inference.
|
||||
|
||||
## Encoding
|
||||
|
||||
Encoding converts images and videos into latent representations. Refer to the table below for the supported VAEs.
|
||||
|
||||
Pass an image to [`~utils.remote_encode`] to encode it. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_encode
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.float16,
|
||||
vae=None,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
init_image = load_image(
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
init_latent = remote_encode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud",
|
||||
image=init_image,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159
|
||||
)
|
||||
```
|
||||
|
||||
## Decoding
|
||||
|
||||
Decoding converts latent representations back into images or videos. Refer to the table below for the available and supported VAEs.
|
||||
|
||||
Set the output type to `"latent"` in the pipeline and set the `vae` to `None`. Pass the latents to the [`~utils.remote_decode`] function. For Flux, the latents are packed so the `height` and `width` also need to be passed. The specific `scaling_factor` and `shift_factor` values for each model can be found in the [Remote inference](../hybrid_inference/api_reference) API reference.
|
||||
|
||||
<hfoptions id="decode">
|
||||
<hfoption id="Flux">
|
||||
|
||||
```py
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16,
|
||||
vae=None,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
prompt = """
|
||||
A photorealistic Apollo-era photograph of a cat in a small astronaut suit with a bubble helmet, standing on the Moon and holding a flagpole planted in the dusty lunar soil. The flag shows a colorful paw-print emblem. Earth glows in the black sky above the stark gray surface, with sharp shadows and high-contrast lighting like vintage NASA photos.
|
||||
"""
|
||||
|
||||
latent = pipeline(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.0,
|
||||
num_inference_steps=4,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
image.save("image.jpg")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="HunyuanVideo">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
"hunyuanvideo-community/HunyuanVideo", subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16, device_map="cuda"
|
||||
)
|
||||
|
||||
latent = pipeline(
|
||||
prompt="A cat walks on the grass, realistic",
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=61,
|
||||
num_inference_steps=30,
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
output_type="mp4",
|
||||
)
|
||||
|
||||
if isinstance(video, bytes):
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Queuing
|
||||
|
||||
Remote inference supports queuing to process multiple generation requests. While the current latent is being decoded, you can queue the next prompt.
|
||||
|
||||
```py
|
||||
import queue
|
||||
import threading
|
||||
from IPython.display import display
|
||||
import torch
from diffusers import StableDiffusionXLPipeline
from diffusers.utils.remote_utils import remote_decode
|
||||
|
||||
def decode_worker(q: queue.Queue):
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
break
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=item,
|
||||
scaling_factor=0.13025,
|
||||
)
|
||||
display(image)
|
||||
q.task_done()
|
||||
|
||||
q = queue.Queue()
|
||||
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def decode(latent: torch.Tensor):
|
||||
q.put(latent)
|
||||
|
||||
prompts = [
|
||||
"A grainy Apollo-era style photograph of a cat in a snug astronaut suit with a bubble helmet, standing on the lunar surface and gripping a flag with a paw-print emblem. The gray Moon landscape stretches behind it, Earth glowing vividly in the black sky, shadows crisp and high-contrast.",
|
||||
"A vintage 1960s sci-fi pulp magazine cover illustration of a heroic cat astronaut planting a flag on the Moon. Bold, saturated colors, exaggerated space gear, playful typography floating in the background, Earth painted in bright blues and greens.",
|
||||
"A hyper-detailed cinematic shot of a cat astronaut on the Moon holding a fluttering flag, fur visible through the helmet glass, lunar dust scattering under its feet. The vastness of space and Earth in the distance create an epic, awe-inspiring tone.",
|
||||
"A colorful cartoon drawing of a happy cat wearing a chunky, oversized spacesuit, proudly holding a flag with a big paw print on it. The Moon’s surface is simplified with craters drawn like doodles, and Earth in the sky has a smiling face.",
|
||||
"A monochrome 1969-style press photo of a “first cat on the Moon” moment. The cat, in a tiny astronaut suit, stands by a planted flag, with grainy textures, scratches, and a blurred Earth in the background, mimicking old archival space photos."
|
||||
]
|
||||
|
||||
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
    vae=None,
    device_map="cuda"
)

pipeline.unet = pipeline.unet.to(memory_format=torch.channels_last)
pipeline.unet = torch.compile(pipeline.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
_ = pipeline(
    prompt=prompts[0],
    output_type="latent",
)

for prompt in prompts:
    latent = pipeline(
        prompt=prompt,
        output_type="latent",
    ).images
    decode(latent)

q.put(None)
thread.join()
```
|
||||
|
||||
## Why use Hybrid Inference?

Hybrid Inference offers a fast and simple way to offload local generation requirements.

- 🚀 **Reduced Requirements:** Access powerful models without expensive hardware.
- 💎 **Without Compromise:** Achieve the highest quality without sacrificing performance.
- 💰 **Cost Effective:** It's free! 🤑
- 🎯 **Diverse Use Cases:** Fully compatible with Diffusers 🧨 and the wider community.
- 🔧 **Developer-Friendly:** Simple requests, fast responses.

## Available Models

* **VAE Decode 🖼️:** Quickly decode latent representations into high-quality images without compromising performance or workflow speed.
* **VAE Encode 🔢:** Efficiently encode images into latent representations for generation and training.
* **Text Encoders 📃 (coming soon):** Compute text embeddings for your prompts quickly and accurately, ensuring a smooth and high-quality workflow.

## Integrations

* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.

## Changelog

- March 10 2025: Added VAE encode
- March 2 2025: Initial release with VAE decoding

## Contents

The documentation is organized into three sections:

* **VAE Decode** Learn the basics of how to use VAE Decode with Hybrid Inference.
* **VAE Encode** Learn the basics of how to use VAE Encode with Hybrid Inference.
* **API Reference** Dive into task-specific settings and parameters.

## Benchmarks

The tables demonstrate the memory requirements for encoding and decoding with Stable Diffusion v1.5 and SDXL on different GPUs.

For the majority of these GPUs, the memory usage dictates whether other models (text encoders, UNet/transformer) need to be offloaded or whether tiled encoding is required. Both techniques increase inference time and impact quality.
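The snippet below is a minimal sketch of the local workarounds these tables refer to, assuming an SDXL pipeline on a memory-constrained GPU. With Hybrid Inference, the VAE runs remotely instead, so neither workaround is needed.

```py
import torch
from diffusers import StableDiffusionXLPipeline

pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# Workaround 1: keep only the active component on the GPU (adds transfer overhead).
pipeline.enable_model_cpu_offload()

# Workaround 2: encode and decode latents in tiles to cap peak VRAM (slower, can affect quality).
pipeline.vae.enable_tiling()
```

In practice you would pick one of the two; Hybrid Inference avoids both by removing the VAE from the local pipeline (`vae=None`) and decoding remotely.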
|
||||
<details><summary>Encoding - Stable Diffusion v1.5</summary>
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
|
||||
|
||||
</details>
|
||||
<details><summary>Encoding SDXL</summary>
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
|
||||
|
||||
</details>
|
||||
<details><summary>Decoding - Stable Diffusion v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>Decoding SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## Resources
|
||||
|
||||
- Remote inference is also supported in [SD.Next](https://github.com/vladmandic/sdnext) and [ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae).
|
||||
- Refer to the [Remote VAEs for decoding with Inference Endpoints](https://huggingface.co/blog/remote_vae) blog post to learn more.
|
||||
@@ -1,345 +0,0 @@
|
||||
# Getting Started: VAE Decode with Hybrid Inference
|
||||
|
||||
VAE decode is an essential component of diffusion models - turning latent representations into images or videos.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE decode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs, the memory usage dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled decoding has to be used, which increases the time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.031 | 5.60% | 0.031 (0%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.148 | 20.00% | 0.301 (+103%) | 5.60% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.05 | 8.40% | 0.050 (0%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.224 | 30.00% | 0.356 (+59%) | 8.40% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.066 | 11.30% | 0.066 (0%) | 11.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.284 | 40.50% | 0.454 (+60%) | 11.40% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.062 | 5.20% | 0.062 (0%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.253 | 18.50% | 0.464 (+83%) | 5.20% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.07 | 12.80% | 0.070 (0%) | 12.80% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.286 | 45.30% | 0.466 (+63%) | 12.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.102 | 15.90% | 0.102 (0%) | 15.90% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.421 | 56.30% | 0.746 (+77%) | 16.00% |
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
| --- | --- | --- | --- | --- | --- |
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.057 | 10.00% | 0.057 (0%) | 10.00% |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.256 | 35.50% | 0.257 (+0.4%) | 35.50% |
|
||||
| NVIDIA GeForce RTX 4080 | 512x512 | 0.092 | 15.00% | 0.092 (0%) | 15.00% |
|
||||
| NVIDIA GeForce RTX 4080 | 1024x1024 | 0.406 | 53.30% | 0.406 (0%) | 53.30% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 512x512 | 0.121 | 20.20% | 0.120 (-0.8%) | 20.20% |
|
||||
| NVIDIA GeForce RTX 4070 Ti | 1024x1024 | 0.519 | 72.00% | 0.519 (0%) | 72.00% |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.107 | 10.50% | 0.107 (0%) | 10.50% |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.459 | 38.00% | 0.460 (+0.2%) | 38.00% |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.121 | 25.60% | 0.121 (0%) | 25.60% |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.524 | 93.00% | 0.524 (0%) | 93.00% |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.183 | 31.80% | 0.183 (0%) | 31.80% |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.794 | 96.40% | 0.794 (0%) | 96.40% |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud](https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud](https://x2dmsqunjd6k9prw.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud](https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
| **HunyuanVideo** | [https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud](https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud) | [`hunyuanvideo-community/HunyuanVideo`](https://hf.co/hunyuanvideo-community/HunyuanVideo) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Here, we show how to use the remote VAE on random tensors.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4, 64, 64], dtype=torch.float16),
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/output.png"/>
|
||||
</figure>
|
||||
|
||||
Usage for Flux is slightly different. Flux latents are packed so we need to send the `height` and `width`.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 4096, 64], dtype=torch.float16),
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/flux_random_latent.png"/>
|
||||
</figure>
|
||||
|
||||
Finally, an example for HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=torch.randn([1, 16, 3, 40, 64], dtype=torch.float16),
|
||||
output_type="mp4",
|
||||
)
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video_1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
But we want to use the VAE on an actual pipeline to get an actual image, not random noise. The example below shows how to do it with SD v1.5.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s another example with Flux.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipe = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-schnell",
|
||||
torch_dtype=torch.bfloat16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
prompt = "Strawberry ice cream, in a stylish modern glass, coconut, splashing milk cream and honey, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious"
|
||||
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
guidance_scale=0.0,
|
||||
num_inference_steps=4,
|
||||
output_type="latent",
|
||||
).images
|
||||
image = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
height=1024,
|
||||
width=1024,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
image.save("test.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/test_1.jpg"/>
|
||||
</figure>
|
||||
|
||||
Here’s an example with HunyuanVideo.
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
|
||||
|
||||
model_id = "hunyuanvideo-community/HunyuanVideo"
|
||||
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
|
||||
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe = HunyuanVideoPipeline.from_pretrained(
|
||||
model_id, transformer=transformer, vae=None, torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
latent = pipe(
|
||||
prompt="A cat walks on the grass, realistic",
|
||||
height=320,
|
||||
width=512,
|
||||
num_frames=61,
|
||||
num_inference_steps=30,
|
||||
output_type="latent",
|
||||
).frames
|
||||
|
||||
video = remote_decode(
|
||||
endpoint="https://o7ywnmrahorts457.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
output_type="mp4",
|
||||
)
|
||||
|
||||
if isinstance(video, bytes):
|
||||
with open("video.mp4", "wb") as f:
|
||||
f.write(video)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/video.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
|
||||
### Queueing
|
||||
|
||||
One of the great benefits of using a remote VAE is that we can queue multiple generation requests. While the current latent is being processed for decoding, we can already queue another one. This helps improve concurrency.
|
||||
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import queue
|
||||
import threading
|
||||
from IPython.display import display
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
def decode_worker(q: queue.Queue):
|
||||
while True:
|
||||
item = q.get()
|
||||
if item is None:
|
||||
break
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=item,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
display(image)
|
||||
q.task_done()
|
||||
|
||||
q = queue.Queue()
|
||||
thread = threading.Thread(target=decode_worker, args=(q,), daemon=True)
|
||||
thread.start()
|
||||
|
||||
def decode(latent: torch.Tensor):
|
||||
q.put(latent)
|
||||
|
||||
prompts = [
|
||||
"Blueberry ice cream, in a stylish modern glass , ice cubes, nuts, mint leaves, splashing milk cream, in a gradient purple background, fluid motion, dynamic movement, cinematic lighting, Mysterious",
|
||||
"Lemonade in a glass, mint leaves, in an aqua and white background, flowers, ice cubes, halo, fluid motion, dynamic movement, soft lighting, digital painting, rule of thirds composition, Art by Greg rutkowski, Coby whitmore",
|
||||
"Comic book art, beautiful, vintage, pastel neon colors, extremely detailed pupils, delicate features, light on face, slight smile, Artgerm, Mary Blair, Edmund Dulac, long dark locks, bangs, glowing, fashionable style, fairytale ambience, hot pink.",
|
||||
"Masterpiece, vanilla cone ice cream garnished with chocolate syrup, crushed nuts, choco flakes, in a brown background, gold, cinematic lighting, Art by WLOP",
|
||||
"A bowl of milk, falling cornflakes, berries, blueberries, in a white background, soft lighting, intricate details, rule of thirds, octane render, volumetric lighting",
|
||||
"Cold Coffee with cream, crushed almonds, in a glass, choco flakes, ice cubes, wet, in a wooden background, cinematic lighting, hyper realistic painting, art by Carne Griffiths, octane render, volumetric lighting, fluid motion, dynamic movement, muted colors,",
|
||||
]
|
||||
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
"Lykon/dreamshaper-8",
|
||||
torch_dtype=torch.float16,
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
pipe.unet = pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
_ = pipe(
|
||||
prompt=prompts[0],
|
||||
output_type="latent",
|
||||
)
|
||||
|
||||
for prompt in prompts:
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
output_type="latent",
|
||||
).images
|
||||
decode(latent)
|
||||
|
||||
q.put(None)
|
||||
thread.join()
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<video
|
||||
alt="queue.mp4"
|
||||
autoplay loop autobuffer muted playsinline
|
||||
>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/queue.mp4" type="video/mp4">
|
||||
</video>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -1,183 +0,0 @@
|
||||
# Getting Started: VAE Encode with Hybrid Inference
|
||||
|
||||
VAE encode is used for training, image-to-image, and image-to-video - turning images or videos into latent representations.
|
||||
|
||||
## Memory
|
||||
|
||||
These tables demonstrate the VRAM requirements for VAE encode with SD v1 and SD XL on different GPUs.
|
||||
|
||||
For the majority of these GPUs, the memory usage dictates that other models (text encoders, UNet/Transformer) must be offloaded, or that tiled encoding has to be used, which increases the time taken and impacts quality.
|
||||
|
||||
<details><summary>SD v1.5</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory (%) | Tiled Time (secs) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|-------------:|--------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.015 | 3.51901 | 0.015 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.004 | 1.3154 | 0.005 | 1.3154 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.402 | 47.1852 | 0.496 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.078 | 12.2658 | 0.094 | 3.51901 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.023 | 5.30105 | 0.023 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.006 | 1.98152 | 0.006 | 1.98152 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 0.574 | 71.08 | 0.656 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.111 | 18.4772 | 0.14 | 5.30105 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.032 | 3.52782 | 0.032 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.01 | 1.31869 | 0.009 | 1.31869 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 0.742 | 47.3033 | 0.954 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.136 | 12.2965 | 0.207 | 3.52782 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.036 | 8.51761 | 0.036 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.01 | 3.18387 | 0.01 | 3.18387 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | 0.863 | 86.7424 | 1.191 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.157 | 29.6888 | 0.227 | 8.51761 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.051 | 10.6941 | 0.051 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.015 | 3.99743 | 0.015 | 3.99743 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | 1.217 | 96.054 | 1.482 | 10.6941 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.223 | 37.2751 | 0.327 | 10.6941 |
|
||||
|
||||
|
||||
</details>
|
||||
|
||||
<details><summary>SDXL</summary>
|
||||
|
||||
| GPU | Resolution | Time (seconds) | Memory Consumed (%) | Tiled Time (seconds) | Tiled Memory (%) |
|
||||
|:------------------------------|:-------------|-----------------:|----------------------:|-----------------------:|-------------------:|
|
||||
| NVIDIA GeForce RTX 4090 | 512x512 | 0.029 | 4.95707 | 0.029 | 4.95707 |
|
||||
| NVIDIA GeForce RTX 4090 | 256x256 | 0.007 | 2.29666 | 0.007 | 2.29666 |
|
||||
| NVIDIA GeForce RTX 4090 | 2048x2048 | 0.873 | 66.3452 | 0.863 | 15.5649 |
|
||||
| NVIDIA GeForce RTX 4090 | 1024x1024 | 0.142 | 15.5479 | 0.143 | 15.5479 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 512x512 | 0.044 | 7.46735 | 0.044 | 7.46735 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 256x256 | 0.01 | 3.4597 | 0.01 | 3.4597 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 2048x2048 | 1.317 | 87.1615 | 1.291 | 23.447 |
|
||||
| NVIDIA GeForce RTX 4080 SUPER | 1024x1024 | 0.213 | 23.4215 | 0.214 | 23.4215 |
|
||||
| NVIDIA GeForce RTX 3090 | 512x512 | 0.058 | 5.65638 | 0.058 | 5.65638 |
|
||||
| NVIDIA GeForce RTX 3090 | 256x256 | 0.016 | 2.45081 | 0.016 | 2.45081 |
|
||||
| NVIDIA GeForce RTX 3090 | 2048x2048 | 1.755 | 77.8239 | 1.614 | 18.4193 |
|
||||
| NVIDIA GeForce RTX 3090 | 1024x1024 | 0.265 | 18.4023 | 0.265 | 18.4023 |
|
||||
| NVIDIA GeForce RTX 3080 | 512x512 | 0.064 | 13.6568 | 0.064 | 13.6568 |
|
||||
| NVIDIA GeForce RTX 3080 | 256x256 | 0.018 | 5.91728 | 0.018 | 5.91728 |
|
||||
| NVIDIA GeForce RTX 3080 | 2048x2048 | OOM | OOM | 1.866 | 44.4717 |
|
||||
| NVIDIA GeForce RTX 3080 | 1024x1024 | 0.302 | 44.4308 | 0.302 | 44.4308 |
|
||||
| NVIDIA GeForce RTX 3070 | 512x512 | 0.093 | 17.1465 | 0.093 | 17.1465 |
|
||||
| NVIDIA GeForce RTX 3070 | 256x256 | 0.025 | 7.42931 | 0.026 | 7.42931 |
|
||||
| NVIDIA GeForce RTX 3070 | 2048x2048 | OOM | OOM | 2.674 | 55.8355 |
|
||||
| NVIDIA GeForce RTX 3070 | 1024x1024 | 0.443 | 55.7841 | 0.443 | 55.7841 |
|
||||
|
||||
</details>
|
||||
|
||||
## Available VAEs
|
||||
|
||||
| | **Endpoint** | **Model** |
|
||||
|:-:|:-----------:|:--------:|
|
||||
| **Stable Diffusion v1** | [https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud](https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud) | [`stabilityai/sd-vae-ft-mse`](https://hf.co/stabilityai/sd-vae-ft-mse) |
|
||||
| **Stable Diffusion XL** | [https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud](https://xjqqhmyn62rog84g.us-east-1.aws.endpoints.huggingface.cloud) | [`madebyollin/sdxl-vae-fp16-fix`](https://hf.co/madebyollin/sdxl-vae-fp16-fix) |
|
||||
| **Flux** | [https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud](https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud) | [`black-forest-labs/FLUX.1-schnell`](https://hf.co/black-forest-labs/FLUX.1-schnell) |
|
||||
|
||||
|
||||
> [!TIP]
|
||||
> Model support can be requested [here](https://github.com/huggingface/diffusers/issues/new?template=remote-vae-pilot-feedback.yml).
|
||||
|
||||
|
||||
## Code
|
||||
|
||||
> [!TIP]
|
||||
> Install `diffusers` from `main` to run the code: `pip install git+https://github.com/huggingface/diffusers@main`
|
||||
|
||||
|
||||
A helper method simplifies interacting with Hybrid Inference.
|
||||
|
||||
```python
|
||||
from diffusers.utils.remote_utils import remote_encode
|
||||
```
|
||||
|
||||
### Basic example
|
||||
|
||||
Let's encode an image, then decode it to demonstrate.
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg"/>
|
||||
</figure>
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/astronaut.jpg?download=true")
|
||||
|
||||
latent = remote_encode(
|
||||
endpoint="https://ptccx55jz97f9zgo.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
|
||||
decoded = remote_decode(
|
||||
endpoint="https://whhx50ex1aryqvw6.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.3611,
|
||||
shift_factor=0.1159,
|
||||
)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/decoded.png"/>
|
||||
</figure>
|
||||
|
||||
|
||||
### Generation
|
||||
|
||||
Now let's look at a generation example: we'll encode the image, generate, and then remotely decode too!
|
||||
|
||||
<details><summary>Code</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import StableDiffusionImg2ImgPipeline
|
||||
from diffusers.utils import load_image
|
||||
from diffusers.utils.remote_utils import remote_decode, remote_encode
|
||||
|
||||
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
vae=None,
|
||||
).to("cuda")
|
||||
|
||||
init_image = load_image(
|
||||
"https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
|
||||
)
|
||||
init_image = init_image.resize((768, 512))
|
||||
|
||||
init_latent = remote_encode(
|
||||
endpoint="https://qc6479g0aac6qwy9.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
image=init_image,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
|
||||
prompt = "A fantasy landscape, trending on artstation"
|
||||
latent = pipe(
|
||||
prompt=prompt,
|
||||
image=init_latent,
|
||||
strength=0.75,
|
||||
output_type="latent",
|
||||
).images
|
||||
|
||||
image = remote_decode(
|
||||
endpoint="https://q1bj3bpq6kzilnsu.us-east-1.aws.endpoints.huggingface.cloud/",
|
||||
tensor=latent,
|
||||
scaling_factor=0.18215,
|
||||
)
|
||||
image.save("fantasy_landscape.jpg")
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<figure class="image flex flex-col items-center justify-center text-center m-0 w-full">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/remote_vae/fantasy_landscape.png"/>
|
||||
</figure>
|
||||
|
||||
## Integrations
|
||||
|
||||
* **[SD.Next](https://github.com/vladmandic/sdnext):** All-in-one UI with direct support for Hybrid Inference.
|
||||
* **[ComfyUI-HFRemoteVae](https://github.com/kijai/ComfyUI-HFRemoteVae):** ComfyUI node for Hybrid Inference.
|
||||
@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
@@ -1929,6 +1929,8 @@ def main(args):
|
||||
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
# Store vae config before potential deletion
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
@@ -1940,6 +1942,8 @@ def main(args):
|
||||
del vae
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
else:
|
||||
vae_scaling_factor = vae.config.scaling_factor
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
|
||||
@@ -2109,13 +2113,13 @@ def main(args):
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
|
||||
if latents_mean is None and latents_std is None:
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
model_input = model_input * vae_scaling_factor
|
||||
if args.pretrained_vae_model_name_or_path is None:
|
||||
model_input = model_input.to(weight_dtype)
|
||||
else:
|
||||
latents_mean = latents_mean.to(device=model_input.device, dtype=model_input.dtype)
|
||||
latents_std = latents_std.to(device=model_input.device, dtype=model_input.dtype)
|
||||
model_input = (model_input - latents_mean) * vae.config.scaling_factor / latents_std
|
||||
model_input = (model_input - latents_mean) * vae_scaling_factor / latents_std
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
|
||||
This way, the text encoder model is not loaded into memory during training.
|
||||
> [!NOTE]
|
||||
> To enable remote text encoding, you must either be logged in to your Hugging Face account (`hf auth login`) or pass a token with `--hub_token`.
|
||||
### FSDP Text Encoder
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
|
||||
This way, the memory cost of the text encoder is distributed across the participating devices.
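A minimal sketch of the idea, assuming a distributed process group is already initialized and `text_encoder` is the loaded text encoder module; the training script wires this up through its own FSDP helpers, so the function below is only illustrative.

```py
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def shard_text_encoder(text_encoder: torch.nn.Module, device: torch.device) -> FSDP:
    # Each rank keeps only a shard of the parameters; full weights are gathered
    # on demand during the forward pass that computes the prompt embeddings.
    return FSDP(
        text_encoder.to(device),
        use_orig_params=True,
        limit_all_gathers=True,
    )
```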
|
||||
### CPU Offloading
|
||||
To offload parts of the model to CPU memory, you can use the `--offload` flag. This offloads the VAE and text encoder to CPU memory and only moves them to the GPU when needed.
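A conceptual sketch of what `--offload` does; the helper below is made up for illustration and is not the script's actual API.

```py
import torch

def encode_on_gpu(vae: torch.nn.Module, pixel_values: torch.Tensor, device: str = "cuda") -> torch.Tensor:
    # Move the VAE to the GPU only for this step, then send it back to CPU
    # memory so the transformer has the VRAM to itself.
    vae.to(device)
    with torch.no_grad():
        latents = vae.encode(pixel_values.to(device)).latent_dist.sample()
    vae.to("cpu")
    return latents
```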
|
||||
### Latent Caching
|
||||
@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
### FSDP on the transformer
|
||||
By configuring Accelerate with FSDP, the transformer blocks are wrapped automatically. For example, set the configuration to:
|
||||
|
||||
```yaml
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_sharding_strategy: HYBRID_SHARD
|
||||
fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
|
||||
fsdp_forward_prefetch: true
|
||||
fsdp_sync_module_states: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_activation_checkpointing: true
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
```
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
@@ -111,6 +111,25 @@ To better track our training experiments, we're using the following flags in the
|
||||
|
||||
## Notes
|
||||
|
||||
### LoRA Rank and Alpha
|
||||
Two key LoRA hyperparameters are LoRA rank and LoRA alpha.
|
||||
- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
|
||||
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by lora_alpha / lora_rank.
|
||||
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength (see the sketch after the tip below):
  - `lora_alpha == rank`: Scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
  - `lora_alpha < rank`: Scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
  - `lora_alpha > rank`: Scaling factor > 1. Amplifies the LoRA's impact. Allows a lower-rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)
|
||||
|
||||
> [!TIP]
|
||||
> A common starting point is to set `lora_alpha` equal to `rank`.
|
||||
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
|
||||
> to give the LoRA updates more influence without increasing parameter count.
|
||||
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
|
||||
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
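The toy sketch below only illustrates the `lora_alpha / rank` scaling described above; the shapes and values are made up for the example and are not tied to the training script.

```py
import torch

rank, lora_alpha = 16, 32                   # hypothetical values -> scaling = 2.0
out_features, in_features = 64, 64

W = torch.randn(out_features, in_features)  # frozen base weight
A = torch.randn(rank, in_features) * 0.01   # trainable LoRA down-projection
B = torch.zeros(out_features, rank)         # trainable LoRA up-projection (zero-initialized)

scaling = lora_alpha / rank
W_effective = W + scaling * (B @ A)         # LoRA update scaled by lora_alpha / rank
```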
|
||||
|
||||
### Additional CLI arguments
|
||||
|
||||
Additionally, we welcome you to explore the following CLI arguments:
|
||||
|
||||
* `--lora_layers`: The transformer modules to apply LoRA training on. Please specify the layers as a comma-separated string. For example, "to_k,to_q,to_v" will result in LoRA training of the attention layers only.
|
||||
|
||||
@@ -44,6 +44,7 @@ import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -75,13 +76,16 @@ from diffusers import (
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_to_cpu_contiguous,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@@ -722,6 +729,7 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -1219,7 +1227,11 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
is_fsdp = accelerator.state.fsdp_plugin is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1263,17 +1275,42 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
transformer_cls = type(unwrap_model(transformer))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
# 1) Validate and pick the transformer model
|
||||
modules_to_save: dict[str, Any] = {}
|
||||
transformer_model = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(unwrap_model(model), transformer_cls):
|
||||
transformer_model = model
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
if transformer_model is None:
|
||||
raise ValueError("No transformer model found in 'models'")
|
||||
|
||||
# 2) Optionally gather FSDP state dict once
|
||||
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
|
||||
|
||||
# 3) Only main process materializes the LoRA state dict
|
||||
transformer_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
peft_kwargs = {}
|
||||
if is_fsdp:
|
||||
peft_kwargs["state_dict"] = state_dict
|
||||
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(
|
||||
unwrap_model(transformer_model) if is_fsdp else transformer_model,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
if is_fsdp:
|
||||
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
@@ -1285,13 +1322,20 @@ def main(args):
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
if not is_fsdp:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
else:
|
||||
transformer_ = Flux2Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
|
||||
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1507,6 +1551,21 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1536,6 +1595,8 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1777,7 +1838,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if accelerator.is_main_process or is_fsdp:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
@@ -1836,15 +1897,41 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
@@ -43,6 +43,7 @@ import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -74,13 +75,16 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_to_cpu_contiguous,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@@ -339,7 +346,7 @@ def parse_args(input_args=None):
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
required=False,
|
||||
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -691,6 +698,7 @@ def parse_args(input_args=None):
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -827,15 +835,28 @@ class DreamBoothDataset(Dataset):
|
||||
dest_image = self.cond_images[i]
|
||||
image_width, image_height = dest_image.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
|
||||
dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
|
||||
image_width, image_height = dest_image.size
|
||||
|
||||
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
dest_image = Flux2ImageProcessor.image_processor.preprocess(
|
||||
image_processor = Flux2ImageProcessor()
|
||||
dest_image = image_processor.preprocess(
|
||||
dest_image, height=image_height, width=image_width, resize_mode="crop"
|
||||
)
|
||||
# Convert back to PIL
|
||||
dest_image = dest_image.squeeze(0)
|
||||
if dest_image.min() < 0:
|
||||
dest_image = (dest_image + 1) / 2
|
||||
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
|
||||
|
||||
if dest_image.shape[0] == 1:
|
||||
# Gray scale image
|
||||
dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
|
||||
else:
|
||||
# RGB scale image: (C, H, W) -> (H, W, C)
|
||||
dest_image = TF.to_pil_image(dest_image)
|
||||
|
||||
dest_image = exif_transpose(dest_image)
|
||||
if not dest_image.mode == "RGB":
|
||||
@@ -1156,7 +1177,11 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
is_fsdp = accelerator.state.fsdp_plugin is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1200,17 +1225,42 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
transformer_cls = type(unwrap_model(transformer))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
# 1) Validate and pick the transformer model
|
||||
modules_to_save: dict[str, Any] = {}
|
||||
transformer_model = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(unwrap_model(model), transformer_cls):
|
||||
transformer_model = model
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
if transformer_model is None:
|
||||
raise ValueError("No transformer model found in 'models'")
|
||||
|
||||
# 2) Optionally gather FSDP state dict once
|
||||
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
|
||||
|
||||
# 3) Only main process materializes the LoRA state dict
|
||||
transformer_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
peft_kwargs = {}
|
||||
if is_fsdp:
|
||||
peft_kwargs["state_dict"] = state_dict
|
||||
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(
|
||||
unwrap_model(transformer_model) if is_fsdp else transformer_model,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
if is_fsdp:
|
||||
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
@@ -1222,13 +1272,20 @@ def main(args):
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
if not is_fsdp:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
else:
|
||||
transformer_ = Flux2Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
|
||||
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1419,9 +1476,9 @@ def main(args):
|
||||
args.instance_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
validation_image = load_image(args.validation_image_path).convert("RGB")
|
||||
validation_kwargs = {"image": validation_image}
|
||||
if args.validation_prompt is not None:
|
||||
validation_image = load_image(args.validation_image_path).convert("RGB")
|
||||
validation_kwargs = {"image": validation_image}
|
||||
if args.remote_text_encoder:
|
||||
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
|
||||
else:
|
||||
@@ -1430,6 +1487,21 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1461,6 +1533,8 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1700,7 +1774,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if accelerator.is_main_process or is_fsdp:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
@@ -1759,15 +1833,41 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
157
examples/research_projects/lpl/README.md
Normal file
157
examples/research_projects/lpl/README.md
Normal file
@@ -0,0 +1,157 @@
|
||||
# Latent Perceptual Loss (LPL) for Stable Diffusion XL
|
||||
|
||||
This directory contains an implementation of Latent Perceptual Loss (LPL) for training Stable Diffusion XL models, based on the paper: [Boosting Latent Diffusion with Perceptual Objectives](https://huggingface.co/papers/2411.04873) (Berrada et al., 2025). LPL is a perceptual loss that operates in the latent space of a VAE, helping to improve the quality and consistency of generated images by bridging the disconnect between the diffusion model and the autoencoder decoder. The implementation is based on the reference implementation provided by Tariq Berrada.
|
||||
|
||||
## Overview
|
||||
|
||||
LPL addresses a key limitation in latent diffusion models (LDMs): the disconnect between the diffusion model training and the autoencoder decoder. While LDMs train in the latent space, they don't receive direct feedback about how well their outputs decode into high-quality images. This can lead to:
|
||||
|
||||
- Loss of fine details in generated images
|
||||
- Inconsistent image quality
|
||||
- Structural artifacts
|
||||
- Reduced sharpness and realism
|
||||
|
||||
LPL works by comparing intermediate features from the VAE decoder between the predicted and target latents. This helps the model learn better perceptual features and can lead to:
|
||||
|
||||
- Improved image quality and consistency (6-20% FID improvement)
|
||||
- Better preservation of fine details
|
||||
- More stable training, especially at high noise levels
|
||||
- Better handling of structural information
|
||||
- Sharper and more realistic textures
|
||||
|
||||
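In schematic form (a simplified summary of what `lpl_loss.py` in this directory computes; the exact reduction order, masking, and the cap on the layer exponent follow the code):

$$
\mathcal{L}_{\mathrm{LPL}} = \frac{1}{N}\sum_{i=1}^{N} w_i \,\big\langle d\big(\hat{\phi}_i,\ \phi_i\big)\big\rangle_{M_i},
\qquad w_i = 2^{-\min(i,\,3)} \ \text{(or } 1 \text{ without power-law weighting)}
$$

where $\phi_i$ and $\hat{\phi}_i$ are the normalized $i$-th VAE decoder feature maps of the target and predicted latents, $M_i$ is the optional outlier mask, $d$ is a squared or absolute difference, and $\langle\cdot\rangle_{M_i}$ averages over the positions kept by the mask.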
## Implementation Details
|
||||
|
||||
The LPL implementation follows the paper's methodology and includes several key features:
|
||||
|
||||
1. **Feature Extraction**: Extracts intermediate features from the VAE decoder, including:
|
||||
- Middle block features
|
||||
- Up block features (configurable number of blocks)
|
||||
- Proper gradient checkpointing for memory efficiency
|
||||
- Features are extracted only for timesteps below the threshold (high SNR)
|
||||
|
||||
2. **Feature Normalization**: Multiple normalization options as validated in the paper (see the sketch after this list):
|
||||
- `default`: Normalize each feature map independently
|
||||
- `shared`: Cross-normalize features using target statistics (recommended)
|
||||
- `batch`: Batch-wise normalization
|
||||
|
||||
3. **Outlier Handling**: Optional removal of outliers in feature maps using:
|
||||
- Quantile-based filtering (2% quantiles)
|
||||
- Morphological operations (opening/closing)
|
||||
- Adaptive thresholding based on standard deviation
|
||||
|
||||
4. **Loss Types**:
|
||||
- MSE loss (default)
|
||||
- L1 loss
|
||||
- Optional power law weighting (2^(-i) for layer i)
|
||||
|
||||
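A minimal sketch of the `default` and `shared` normalization options, mirroring the `normalize_tensor` and `cross_normalize` helpers in `lpl_loss.py` (feature maps are assumed to have shape `(B, C, H, W)`):

```python
import torch


def normalize_tensor(feat, eps=1e-10):
    # "default": normalize a feature map by its own channel-wise L2 norm
    norm = torch.sqrt(torch.sum(feat**2, dim=1, keepdim=True))
    return feat / (norm + eps)


def cross_normalize(pred_feat, target_feat, eps=1e-10):
    # "shared": normalize both tensors with the *target* statistics,
    # so the prediction is compared on the target's scale
    norm = torch.sqrt(torch.sum(target_feat**2, dim=1, keepdim=True))
    return pred_feat / (norm + eps), target_feat / (norm + eps)


# toy example with fake decoder features of shape (batch, channels, height, width)
pred, target = torch.randn(2, 8, 16, 16), torch.randn(2, 8, 16, 16)
x, y = cross_normalize(pred, target)
```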
## Usage
|
||||
|
||||
To use LPL in your training, add the following arguments to your training command (each flag is described under Key Parameters below):

```bash
python examples/research_projects/lpl/train_sdxl_lpl.py \
  --use_lpl \
  --lpl_weight 1.0 \
  --lpl_t_threshold 200 \
  --lpl_loss_type mse \
  --lpl_norm_type shared \
  --lpl_pow_law \
  --lpl_num_blocks 4 \
  --lpl_remove_outliers \
  --lpl_scale \
  --lpl_start 0 \
  # ... other training arguments ...
```
|
||||
|
||||
### Key Parameters
|
||||
|
||||
- `lpl_weight`: Controls the strength of the LPL loss relative to the main diffusion loss. Higher values (1.0-2.0) improve quality but may slow training.
|
||||
- `lpl_t_threshold`: LPL is only applied for timesteps below this threshold (high SNR). Lower values (100-200) focus on more important timesteps.
|
||||
- `lpl_loss_type`: Choose between MSE (default) and L1 loss. MSE is recommended for most cases.
|
||||
- `lpl_norm_type`: Feature normalization strategy. "shared" is recommended as it showed best results in the paper.
|
||||
- `lpl_pow_law`: Whether to use power law weighting (2^(-i) for layer i). Recommended for better feature balance.
|
||||
- `lpl_num_blocks`: Number of up blocks to use for feature extraction (1-4). More blocks capture more features but use more memory.
|
||||
- `lpl_remove_outliers`: Whether to remove outliers in feature maps. Recommended for stable training.
|
||||
- `lpl_scale`: Whether to scale LPL loss by noise level weights. Helps focus on more important timesteps.
|
||||
- `lpl_start`: Training step to start applying LPL. Can be used to warm up training.
|
||||
|
||||
## Recommendations
|
||||
|
||||
1. **Starting Point** (based on paper results):
|
||||
```bash
|
||||
--use_lpl \
|
||||
--lpl_weight 1.0 \
|
||||
--lpl_t_threshold 200 \
|
||||
--lpl_loss_type mse \
|
||||
--lpl_norm_type shared \
|
||||
--lpl_pow_law \
|
||||
--lpl_num_blocks 4 \
|
||||
--lpl_remove_outliers \
|
||||
--lpl_scale
|
||||
```
|
||||
|
||||
2. **Memory Efficiency**:
|
||||
- Use `--gradient_checkpointing` for memory efficiency (enabled by default)
|
||||
- Reduce `lpl_num_blocks` if memory is constrained (2-3 blocks still give good results)
|
||||
- Consider using `--lpl_scale` to focus on more important timesteps
|
||||
- Features are extracted only for timesteps below threshold to save memory
|
||||
|
||||
3. **Quality vs Speed**:
|
||||
- Higher `lpl_weight` (1.0-2.0) for better quality
|
||||
- Lower `lpl_t_threshold` (100-200) for faster training
|
||||
- Use `lpl_remove_outliers` for more stable training
|
||||
- `lpl_norm_type shared` provides best quality/speed trade-off
|
||||
|
||||
## Technical Details
|
||||
|
||||
### Feature Extraction
|
||||
|
||||
The LPL implementation extracts features from the VAE decoder in the following order:
|
||||
1. Middle block output
|
||||
2. Up block outputs (configurable number of blocks)
|
||||
|
||||
Each feature map is processed with:
|
||||
1. Optional outlier removal (2% quantiles, morphological operations)
|
||||
2. Feature normalization (shared statistics recommended)
|
||||
3. Loss calculation (MSE or L1)
|
||||
4. Optional power law weighting (2^(-i) for layer i)
|
||||
|
||||
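A rough sketch of this extraction order, assuming a diffusers `AutoencoderKL`-style `vae` (it mirrors `LatentPerceptualLoss.get_features`, minus gradient checkpointing and the separate grad handling for predicted vs. target latents):

```python
def tap_decoder_features(vae, latents, num_up_blocks=4):
    # Collect the mid-block output plus the outputs of the first `num_up_blocks`
    # decoder up blocks; these are the feature maps the LPL terms compare.
    decoder = vae.decoder
    features = []
    sample = decoder.conv_in(latents)
    sample = decoder.mid_block(sample, None)  # 1. middle block output
    sample = sample.to(next(iter(decoder.up_blocks.parameters())).dtype)
    features.append(sample)
    for up_block in decoder.up_blocks[:num_up_blocks]:  # 2. up block outputs
        sample = up_block(sample, None)
        features.append(sample)
    return features
```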
### Loss Calculation
|
||||
|
||||
For each feature map:
|
||||
1. Features are normalized according to the chosen strategy
|
||||
2. Loss is calculated between normalized features
|
||||
3. Outliers are masked out (if enabled)
|
||||
4. Loss is weighted by layer depth (if power law enabled)
|
||||
5. Final loss is averaged across all layers
|
||||
|
||||
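Putting these steps together, a minimal sketch of the reduction (with `pred_feats`/`target_feats` as returned by a feature tap like the one above and `masks` the optional outlier masks; the `2 ** (-min(i, 3))` cap matches the implementation):

```python
import torch
import torch.nn.functional as F


def aggregate_lpl(pred_feats, target_feats, masks=None, pow_law=True):
    losses = []
    for i, (x, y) in enumerate(zip(pred_feats, target_feats)):
        mask = torch.ones_like(y, dtype=torch.bool) if masks is None else masks[i]
        # masked per-position loss, averaged over the kept spatial positions
        term = F.mse_loss(x, y, reduction="none") * mask
        weight = 2 ** (-min(i, 3)) if pow_law else 1.0
        term = term.sum(dim=(2, 3)) * weight / mask.sum(dim=(2, 3))
        losses.append(term.mean(dim=1))  # mean over channels -> shape (batch,)
    return sum(losses) / len(losses)  # average across feature levels
```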
### Memory Considerations
|
||||
|
||||
- Gradient checkpointing is used by default
|
||||
- Features are extracted only for timesteps below the threshold
|
||||
- Outlier removal is done in-place to save memory
|
||||
- Feature normalization is done efficiently using vectorized operations
|
||||
- Memory usage scales linearly with number of blocks used
|
||||
|
||||
## Results
|
||||
|
||||
Based on the paper's findings, LPL provides:
|
||||
- 6-20% improvement in FID scores
|
||||
- Better preservation of fine details
|
||||
- More realistic textures and structures
|
||||
- Improved consistency across different resolutions
|
||||
- Better performance on both small and large datasets
|
||||
|
||||
## Citation
|
||||
|
||||
If you use this implementation in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@inproceedings{berrada2025boosting,
|
||||
title={Boosting Latent Diffusion with Perceptual Objectives},
|
||||
author={Tariq Berrada and Pietro Astolfi and Melissa Hall and Marton Havasi and Yohann Benchetrit and Adriana Romero-Soriano and Karteek Alahari and Michal Drozdzal and Jakob Verbeek},
|
||||
booktitle={The Thirteenth International Conference on Learning Representations},
|
||||
year={2025},
|
||||
url={https://openreview.net/forum?id=y4DtzADzd1}
|
||||
}
|
||||
```
|
||||
215
examples/research_projects/lpl/lpl_loss.py
Normal file
215
examples/research_projects/lpl/lpl_loss.py
Normal file
@@ -0,0 +1,215 @@
|
||||
# Copyright 2025 Berrada et al.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def normalize_tensor(in_feat, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
|
||||
return in_feat / (norm_factor + eps)
|
||||
|
||||
|
||||
def cross_normalize(input, target, eps=1e-10):
|
||||
norm_factor = torch.sqrt(torch.sum(target**2, dim=1, keepdim=True))
|
||||
return input / (norm_factor + eps), target / (norm_factor + eps)
|
||||
|
||||
|
||||
def remove_outliers(feat, down_f=1, opening=5, closing=3, m=100, quant=0.02):
|
||||
opening = int(np.ceil(opening / down_f))
|
||||
closing = int(np.ceil(closing / down_f))
|
||||
if opening == 2:
|
||||
opening = 3
|
||||
if closing == 2:
|
||||
closing = 1
|
||||
|
||||
# replace quantile with kth value here.
|
||||
feat_flat = feat.flatten(-2, -1)
|
||||
k1, k2 = int(feat_flat.shape[-1] * quant), int(feat_flat.shape[-1] * (1 - quant))
|
||||
q1 = feat_flat.kthvalue(k1, dim=-1).values[..., None, None]
|
||||
q2 = feat_flat.kthvalue(k2, dim=-1).values[..., None, None]
|
||||
|
||||
m = 2 * feat_flat.std(-1)[..., None, None].detach()
|
||||
mask = (q1 - m < feat) * (feat < q2 + m)
|
||||
|
||||
# dilate the mask.
|
||||
mask = nn.MaxPool2d(kernel_size=closing, stride=1, padding=(closing - 1) // 2)(mask.float()) # closing
|
||||
mask = (-nn.MaxPool2d(kernel_size=opening, stride=1, padding=(opening - 1) // 2)(-mask)).bool() # opening
|
||||
feat = feat * mask
|
||||
return mask, feat
|
||||
|
||||
|
||||
class LatentPerceptualLoss(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
loss_type="mse",
|
||||
grad_ckpt=True,
|
||||
pow_law=False,
|
||||
norm_type="default",
|
||||
num_mid_blocks=4,
|
||||
feature_type="feature",
|
||||
remove_outliers=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.vae = vae
|
||||
self.decoder = self.vae.decoder
|
||||
# Store scaling factors as tensors on the correct device
|
||||
device = next(self.vae.parameters()).device
|
||||
|
||||
# Get scaling factors with proper defaults and handle None values
|
||||
scale_factor = getattr(self.vae.config, "scaling_factor", None)
|
||||
shift_factor = getattr(self.vae.config, "shift_factor", None)
|
||||
|
||||
# Convert to tensors with proper defaults
|
||||
self.scale = torch.tensor(1.0 if scale_factor is None else scale_factor, device=device)
|
||||
self.shift = torch.tensor(0.0 if shift_factor is None else shift_factor, device=device)
|
||||
|
||||
self.gradient_checkpointing = grad_ckpt
|
||||
self.pow_law = pow_law
|
||||
self.norm_type = norm_type.lower()
|
||||
self.outlier_mask = remove_outliers
|
||||
self.last_feature_stats = [] # Store feature statistics for logging
|
||||
|
||||
assert feature_type in ["feature", "image"]
|
||||
self.feature_type = feature_type
|
||||
|
||||
assert self.norm_type in ["default", "shared", "batch"]
|
||||
assert num_mid_blocks >= 0 and num_mid_blocks <= 4
|
||||
self.n_blocks = num_mid_blocks
|
||||
|
||||
assert loss_type in ["mse", "l1"]
|
||||
if loss_type == "mse":
|
||||
self.loss_fn = nn.MSELoss(reduction="none")
|
||||
elif loss_type == "l1":
|
||||
self.loss_fn = nn.L1Loss(reduction="none")
|
||||
|
||||
def get_features(self, z, latent_embeds=None, disable_grads=False):
|
||||
with torch.set_grad_enabled(not disable_grads):
|
||||
if self.gradient_checkpointing and not disable_grads:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
features = []
|
||||
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
|
||||
sample = z
|
||||
sample = self.decoder.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(self.decoder.mid_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
sample = sample.to(upscale_dtype)
|
||||
features.append(sample)
|
||||
|
||||
# up
|
||||
for up_block in self.decoder.up_blocks[: self.n_blocks]:
|
||||
sample = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(up_block),
|
||||
sample,
|
||||
latent_embeds,
|
||||
use_reentrant=False,
|
||||
)
|
||||
features.append(sample)
|
||||
return features
|
||||
else:
|
||||
features = []
|
||||
upscale_dtype = next(iter(self.decoder.up_blocks.parameters())).dtype
|
||||
sample = z
|
||||
sample = self.decoder.conv_in(sample)
|
||||
|
||||
# middle
|
||||
sample = self.decoder.mid_block(sample, latent_embeds)
|
||||
sample = sample.to(upscale_dtype)
|
||||
features.append(sample)
|
||||
|
||||
# up
|
||||
for up_block in self.decoder.up_blocks[: self.n_blocks]:
|
||||
sample = up_block(sample, latent_embeds)
|
||||
features.append(sample)
|
||||
return features
|
||||
|
||||
def get_loss(self, input, target, get_hist=False):
|
||||
if self.feature_type == "feature":
|
||||
inp_f = self.get_features(self.shift + input / self.scale)
|
||||
tar_f = self.get_features(self.shift + target / self.scale, disable_grads=True)
|
||||
losses = []
|
||||
self.last_feature_stats = [] # Reset feature stats
|
||||
|
||||
for i, (x, y) in enumerate(zip(inp_f, tar_f, strict=False)):
|
||||
my = torch.ones_like(y).bool()
|
||||
outlier_ratio = 0.0
|
||||
|
||||
if self.outlier_mask:
|
||||
with torch.no_grad():
|
||||
if i == 2:
|
||||
my, y = remove_outliers(y, down_f=2)
|
||||
outlier_ratio = 1.0 - my.float().mean().item()
|
||||
elif i in [3, 4, 5]:
|
||||
my, y = remove_outliers(y, down_f=1)
|
||||
outlier_ratio = 1.0 - my.float().mean().item()
|
||||
|
||||
# Store feature statistics before normalization
|
||||
with torch.no_grad():
|
||||
stats = {
|
||||
"mean": y.mean().item(),
|
||||
"std": y.std().item(),
|
||||
"outlier_ratio": outlier_ratio,
|
||||
}
|
||||
self.last_feature_stats.append(stats)
|
||||
|
||||
# normalize feature tensors
|
||||
if self.norm_type == "default":
|
||||
x = normalize_tensor(x)
|
||||
y = normalize_tensor(y)
|
||||
elif self.norm_type == "shared":
|
||||
x, y = cross_normalize(x, y, eps=1e-6)
|
||||
|
||||
term_loss = self.loss_fn(x, y) * my
|
||||
# reduce loss term
|
||||
loss_f = 2 ** (-min(i, 3)) if self.pow_law else 1.0
|
||||
term_loss = term_loss.sum((2, 3)) * loss_f / my.sum((2, 3))
|
||||
losses.append(term_loss.mean((1,)))
|
||||
|
||||
if get_hist:
|
||||
return losses
|
||||
else:
|
||||
loss = sum(losses)
|
||||
return loss / len(inp_f)
|
||||
elif self.feature_type == "image":
|
||||
inp_f = self.vae.decode(input / self.scale).sample
|
||||
tar_f = self.vae.decode(target / self.scale).sample
|
||||
return F.mse_loss(inp_f, tar_f)
|
||||
|
||||
def get_first_conv(self, z):
|
||||
sample = self.decoder.conv_in(z)
|
||||
return sample
|
||||
|
||||
def get_first_block(self, z):
|
||||
sample = self.decoder.conv_in(z)
|
||||
sample = self.decoder.mid_block(sample)
|
||||
for resnet in self.decoder.up_blocks[0].resnets:
|
||||
sample = resnet(sample, None)
|
||||
return sample
|
||||
|
||||
def get_first_layer(self, input, target, target_layer="conv"):
|
||||
if target_layer == "conv":
|
||||
feat_in = self.get_first_conv(input)
|
||||
with torch.no_grad():
|
||||
feat_tar = self.get_first_conv(target)
|
||||
else:
|
||||
feat_in = self.get_first_block(input)
|
||||
with torch.no_grad():
|
||||
feat_tar = self.get_first_block(target)
|
||||
|
||||
feat_in, feat_tar = cross_normalize(feat_in, feat_tar)
|
||||
|
||||
return F.mse_loss(feat_in, feat_tar, reduction="mean")
|
||||
1622
examples/research_projects/lpl/train_sdxl_lpl.py
Normal file
1622
examples/research_projects/lpl/train_sdxl_lpl.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -7,16 +7,12 @@ import torch
|
||||
from diffusers.utils import logging
|
||||
|
||||
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
|
||||
from .wrappers import ThreadSafeImageProcessorWrapper, ThreadSafeTokenizerWrapper, ThreadSafeVAEWrapper
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def safe_tokenize(tokenizer, *args, lock, **kwargs):
|
||||
with lock:
|
||||
return tokenizer(*args, **kwargs)
|
||||
|
||||
|
||||
class RequestScopedPipeline:
|
||||
DEFAULT_MUTABLE_ATTRS = [
|
||||
"_all_hooks",
|
||||
@@ -38,23 +34,40 @@ class RequestScopedPipeline:
|
||||
wrap_scheduler: bool = True,
|
||||
):
|
||||
self._base = pipeline
|
||||
|
||||
self.unet = getattr(pipeline, "unet", None)
|
||||
self.vae = getattr(pipeline, "vae", None)
|
||||
self.text_encoder = getattr(pipeline, "text_encoder", None)
|
||||
self.components = getattr(pipeline, "components", None)
|
||||
|
||||
self.transformer = getattr(pipeline, "transformer", None)
|
||||
|
||||
if wrap_scheduler and hasattr(pipeline, "scheduler") and pipeline.scheduler is not None:
|
||||
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
|
||||
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
|
||||
|
||||
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
|
||||
|
||||
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
|
||||
|
||||
self._vae_lock = threading.Lock()
|
||||
self._image_lock = threading.Lock()
|
||||
|
||||
self._auto_detect_mutables = bool(auto_detect_mutables)
|
||||
self._tensor_numel_threshold = int(tensor_numel_threshold)
|
||||
|
||||
self._auto_detected_attrs: List[str] = []
|
||||
|
||||
def _detect_kernel_pipeline(self, pipeline) -> bool:
|
||||
kernel_indicators = [
|
||||
"text_encoding_cache",
|
||||
"memory_manager",
|
||||
"enable_optimizations",
|
||||
"_create_request_context",
|
||||
"get_optimization_stats",
|
||||
]
|
||||
|
||||
return any(hasattr(pipeline, attr) for attr in kernel_indicators)
|
||||
|
||||
def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str] = None, **clone_kwargs):
|
||||
base_sched = getattr(self._base, "scheduler", None)
|
||||
if base_sched is None:
|
||||
@@ -70,11 +83,21 @@ class RequestScopedPipeline:
|
||||
num_inference_steps=num_inference_steps, device=device, **clone_kwargs
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
|
||||
logger.debug(f"clone_for_request failed: {e}; trying shallow copy fallback")
|
||||
try:
|
||||
return copy.deepcopy(wrapped_scheduler)
|
||||
except Exception as e:
|
||||
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
|
||||
if hasattr(wrapped_scheduler, "scheduler"):
|
||||
try:
|
||||
copied_scheduler = copy.copy(wrapped_scheduler.scheduler)
|
||||
return BaseAsyncScheduler(copied_scheduler)
|
||||
except Exception:
|
||||
return wrapped_scheduler
|
||||
else:
|
||||
copied_scheduler = copy.copy(wrapped_scheduler)
|
||||
return BaseAsyncScheduler(copied_scheduler)
|
||||
except Exception as e2:
|
||||
logger.warning(
|
||||
f"Shallow copy of scheduler also failed: {e2}. Using original scheduler (*thread-unsafe but functional*)."
|
||||
)
|
||||
return wrapped_scheduler
|
||||
|
||||
def _autodetect_mutables(self, max_attrs: int = 40):
|
||||
@@ -86,6 +109,7 @@ class RequestScopedPipeline:
|
||||
|
||||
candidates: List[str] = []
|
||||
seen = set()
|
||||
|
||||
for name in dir(self._base):
|
||||
if name.startswith("__"):
|
||||
continue
|
||||
@@ -93,6 +117,7 @@ class RequestScopedPipeline:
|
||||
continue
|
||||
if name in ("to", "save_pretrained", "from_pretrained"):
|
||||
continue
|
||||
|
||||
try:
|
||||
val = getattr(self._base, name)
|
||||
except Exception:
|
||||
@@ -100,11 +125,9 @@ class RequestScopedPipeline:
|
||||
|
||||
import types
|
||||
|
||||
# skip callables and modules
|
||||
if callable(val) or isinstance(val, (types.ModuleType, types.FunctionType, types.MethodType)):
|
||||
continue
|
||||
|
||||
# containers -> candidate
|
||||
if isinstance(val, (dict, list, set, tuple, bytearray)):
|
||||
candidates.append(name)
|
||||
seen.add(name)
|
||||
@@ -205,6 +228,9 @@ class RequestScopedPipeline:
|
||||
|
||||
return has_tokenizer_methods and (has_tokenizer_in_name or has_tokenizer_attrs)
|
||||
|
||||
def _should_wrap_tokenizers(self) -> bool:
|
||||
return True
|
||||
|
||||
def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] = None, **kwargs):
|
||||
local_scheduler = self._make_local_scheduler(num_inference_steps=num_inference_steps, device=device)
|
||||
|
||||
@@ -214,6 +240,25 @@ class RequestScopedPipeline:
|
||||
logger.warning(f"copy.copy(self._base) failed: {e}. Falling back to deepcopy (may increase memory).")
|
||||
local_pipe = copy.deepcopy(self._base)
|
||||
|
||||
try:
|
||||
if (
|
||||
hasattr(local_pipe, "vae")
|
||||
and local_pipe.vae is not None
|
||||
and not isinstance(local_pipe.vae, ThreadSafeVAEWrapper)
|
||||
):
|
||||
local_pipe.vae = ThreadSafeVAEWrapper(local_pipe.vae, self._vae_lock)
|
||||
|
||||
if (
|
||||
hasattr(local_pipe, "image_processor")
|
||||
and local_pipe.image_processor is not None
|
||||
and not isinstance(local_pipe.image_processor, ThreadSafeImageProcessorWrapper)
|
||||
):
|
||||
local_pipe.image_processor = ThreadSafeImageProcessorWrapper(
|
||||
local_pipe.image_processor, self._image_lock
|
||||
)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not wrap vae/image_processor: {e}")
|
||||
|
||||
if local_scheduler is not None:
|
||||
try:
|
||||
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
|
||||
@@ -231,47 +276,42 @@ class RequestScopedPipeline:
|
||||
|
||||
self._clone_mutable_attrs(self._base, local_pipe)
|
||||
|
||||
# 4) wrap tokenizers on the local pipe with the lock wrapper
|
||||
tokenizer_wrappers = {} # name -> original_tokenizer
|
||||
try:
|
||||
# a) wrap direct tokenizer attributes (tokenizer, tokenizer_2, ...)
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
tokenizer_wrappers[name] = tok
|
||||
setattr(
|
||||
local_pipe,
|
||||
name,
|
||||
lambda *args, tok=tok, **kwargs: safe_tokenize(
|
||||
tok, *args, lock=self._tokenizer_lock, **kwargs
|
||||
),
|
||||
)
|
||||
original_tokenizers = {}
|
||||
|
||||
# b) wrap tokenizers in components dict
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
if self._should_wrap_tokenizers():
|
||||
try:
|
||||
for name in dir(local_pipe):
|
||||
if "tokenizer" in name and not name.startswith("_"):
|
||||
tok = getattr(local_pipe, name, None)
|
||||
if tok is not None and self._is_tokenizer_component(tok):
|
||||
if not isinstance(tok, ThreadSafeTokenizerWrapper):
|
||||
original_tokenizers[name] = tok
|
||||
wrapped_tokenizer = ThreadSafeTokenizerWrapper(tok, self._tokenizer_lock)
|
||||
setattr(local_pipe, name, wrapped_tokenizer)
|
||||
|
||||
if self._is_tokenizer_component(val):
|
||||
tokenizer_wrappers[f"components[{key}]"] = val
|
||||
local_pipe.components[key] = lambda *args, tokenizer=val, **kwargs: safe_tokenize(
|
||||
tokenizer, *args, lock=self._tokenizer_lock, **kwargs
|
||||
)
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
for key, val in local_pipe.components.items():
|
||||
if val is None:
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
if self._is_tokenizer_component(val):
|
||||
if not isinstance(val, ThreadSafeTokenizerWrapper):
|
||||
original_tokenizers[f"components[{key}]"] = val
|
||||
wrapped_tokenizer = ThreadSafeTokenizerWrapper(val, self._tokenizer_lock)
|
||||
local_pipe.components[key] = wrapped_tokenizer
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Tokenizer wrapping step encountered an error: {e}")
|
||||
|
||||
result = None
|
||||
cm = getattr(local_pipe, "model_cpu_offload_context", None)
|
||||
|
||||
try:
|
||||
if callable(cm):
|
||||
try:
|
||||
with cm():
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
except TypeError:
|
||||
# cm might be a context manager instance rather than callable
|
||||
try:
|
||||
with cm:
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
@@ -279,18 +319,18 @@ class RequestScopedPipeline:
|
||||
logger.debug(f"model_cpu_offload_context usage failed: {e}. Proceeding without it.")
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
else:
|
||||
# no offload context available — call directly
|
||||
result = local_pipe(*args, num_inference_steps=num_inference_steps, **kwargs)
|
||||
|
||||
return result
|
||||
|
||||
finally:
|
||||
try:
|
||||
for name, tok in tokenizer_wrappers.items():
|
||||
for name, tok in original_tokenizers.items():
|
||||
if name.startswith("components["):
|
||||
key = name[len("components[") : -1]
|
||||
local_pipe.components[key] = tok
|
||||
if hasattr(local_pipe, "components") and isinstance(local_pipe.components, dict):
|
||||
local_pipe.components[key] = tok
|
||||
else:
|
||||
setattr(local_pipe, name, tok)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error restoring wrapped tokenizers: {e}")
|
||||
logger.debug(f"Error restoring original tokenizers: {e}")
|
||||
|
||||
86
examples/server-async/utils/wrappers.py
Normal file
86
examples/server-async/utils/wrappers.py
Normal file
@@ -0,0 +1,86 @@
|
||||
class ThreadSafeTokenizerWrapper:
|
||||
def __init__(self, tokenizer, lock):
|
||||
self._tokenizer = tokenizer
|
||||
self._lock = lock
|
||||
|
||||
self._thread_safe_methods = {
|
||||
"__call__",
|
||||
"encode",
|
||||
"decode",
|
||||
"tokenize",
|
||||
"encode_plus",
|
||||
"batch_encode_plus",
|
||||
"batch_decode",
|
||||
}
|
||||
|
||||
def __getattr__(self, name):
|
||||
attr = getattr(self._tokenizer, name)
|
||||
|
||||
if name in self._thread_safe_methods and callable(attr):
|
||||
|
||||
def wrapped_method(*args, **kwargs):
|
||||
with self._lock:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
return wrapped_method
|
||||
|
||||
return attr
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
with self._lock:
|
||||
return self._tokenizer(*args, **kwargs)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_"):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self._tokenizer, name, value)
|
||||
|
||||
def __dir__(self):
|
||||
return dir(self._tokenizer)
|
||||
|
||||
|
||||
class ThreadSafeVAEWrapper:
|
||||
def __init__(self, vae, lock):
|
||||
self._vae = vae
|
||||
self._lock = lock
|
||||
|
||||
def __getattr__(self, name):
|
||||
attr = getattr(self._vae, name)
|
||||
if name in {"decode", "encode", "forward"} and callable(attr):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
with self._lock:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
return attr
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_"):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self._vae, name, value)
|
||||
|
||||
|
||||
class ThreadSafeImageProcessorWrapper:
|
||||
def __init__(self, proc, lock):
|
||||
self._proc = proc
|
||||
self._lock = lock
|
||||
|
||||
def __getattr__(self, name):
|
||||
attr = getattr(self._proc, name)
|
||||
if name in {"postprocess", "preprocess"} and callable(attr):
|
||||
|
||||
def wrapped(*args, **kwargs):
|
||||
with self._lock:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
return attr
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
if name.startswith("_"):
|
||||
super().__setattr__(name, value)
|
||||
else:
|
||||
setattr(self._proc, name, value)
|
||||
886
scripts/convert_ltx2_to_diffusers.py
Normal file
886
scripts/convert_ltx2_to_diffusers.py
Normal file
@@ -0,0 +1,886 @@
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Input Patchify Projections
|
||||
"patchify_proj": "proj_in",
|
||||
"audio_patchify_proj": "audio_proj_in",
|
||||
# Modulation Parameters
|
||||
# Handle adaln_single --> time_embed and audio_adaln_single --> audio_time_embed separately as the original keys are
|
||||
# substrings of the other modulation parameters below
|
||||
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
|
||||
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
|
||||
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
|
||||
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
|
||||
# Transformer Blocks
|
||||
# Per-Block Cross Attention Modulation Parameters
|
||||
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# Decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
# Common
|
||||
# For all 3D ResNets
|
||||
"res_blocks": "resnets",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"ups": "upsamplers",
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if key.startswith("adaln_single."):
|
||||
new_key = key.replace("adaln_single.", "time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
if key.startswith("audio_adaln_single."):
|
||||
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
if key.startswith("per_channel_statistics"):
|
||||
new_key = ".".join(["decoder", key])
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
"text_proj_in",
|
||||
)
|
||||
|
||||
transformer_state_dict, connector_state_dict = {}, {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith(connector_prefixes):
|
||||
connector_state_dict[key] = value
|
||||
else:
|
||||
transformer_state_dict[key] = value
|
||||
|
||||
return transformer_state_dict, connector_state_dict
|
||||
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 4,
|
||||
"audio_out_channels": 4,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 2,
|
||||
"audio_attention_head_dim": 4,
|
||||
"audio_cross_attention_dim": 8,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 2,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 16,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 16,
|
||||
"text_proj_in_factor": 3,
|
||||
"video_connector_num_attention_heads": 4,
|
||||
"video_connector_attention_head_dim": 8,
|
||||
"video_connector_num_layers": 1,
|
||||
"video_connector_num_learnable_registers": None,
|
||||
"audio_connector_num_attention_heads": 4,
|
||||
"audio_connector_attention_head_dim": 8,
|
||||
"audio_connector_num_layers": 1,
|
||||
"audio_connector_num_learnable_registers": None,
|
||||
"connector_rope_base_seq_len": 32,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(transformer_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, transformer_state_dict)
|
||||
|
||||
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
if len(connector_state_dict) == 0:
|
||||
raise ValueError("No connector weights found in the provided state dict.")
|
||||
|
||||
with init_empty_weights():
|
||||
connectors = LTX2TextConnectors.from_config(diffusers_config)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(connector_state_dict, key, new_key)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, connector_state_dict)
|
||||
|
||||
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vocoder
|
||||
|
||||
|
||||
def get_ltx2_spatial_latent_upsampler_config(version: str):
    if version == "2.0":
        config = {
            "in_channels": 128,
            "mid_channels": 1024,
            "num_blocks_per_stage": 4,
            "dims": 3,
            "spatial_upsample": True,
            "temporal_upsample": False,
            "rational_spatial_scale": 2.0,
        }
    else:
        raise ValueError(f"Unsupported version: {version}")
    return config


def convert_ltx2_spatial_latent_upsampler(
    original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
):
    with init_empty_weights():
        latent_upsampler = LTX2LatentUpsamplerModel(**config)

    latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
    latent_upsampler.to(dtype)
    return latent_upsampler


def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
    if args.original_state_dict_repo_id is not None:
        ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
    elif args.checkpoint_path is not None:
        ckpt_path = args.checkpoint_path
    else:
        raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")

    original_state_dict = safetensors.torch.load_file(ckpt_path)
    return original_state_dict


def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
    if repo_id is None and filename is None:
        raise ValueError("Please supply at least one of `repo_id` or `filename`")

    if repo_id is not None:
        if filename is None:
            raise ValueError("If repo_id is specified, filename must also be specified.")
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
    else:
        ckpt_path = filename

    _, ext = os.path.splitext(ckpt_path)
    if ext in [".safetensors", ".sft"]:
        state_dict = safetensors.torch.load_file(ckpt_path)
    else:
        state_dict = torch.load(ckpt_path, map_location="cpu")

    return state_dict

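
# Example usage of the loaders above (an illustrative sketch only; the repo id and the
# combined filename mirror the argparse defaults further below, and the local path is
# hypothetical):
#
#     combined = load_hub_or_local_checkpoint(
#         repo_id="Lightricks/LTX-2", filename="ltx-2-19b-dev.safetensors"
#     )
#     local_sd = load_hub_or_local_checkpoint(filename="/path/to/checkpoint.safetensors")
#
# Hub checkpoints are fetched with `hf_hub_download`; local `.safetensors`/`.sft` files go
# through `safetensors.torch.load_file`, and anything else through
# `torch.load(..., map_location="cpu")`.
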
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
    # Ensure that the key prefix ends with a dot (.)
    if not prefix.endswith("."):
        prefix = prefix + "."

    model_state_dict = {}
    for param_name, param in combined_ckpt.items():
        if param_name.startswith(prefix):
            model_state_dict[param_name.replace(prefix, "")] = param

    if prefix == "model.diffusion_model.":
        # Some checkpoints store the text connector projection outside the diffusion model prefix.
        connector_key = "text_embedding_projection.aggregate_embed.weight"
        if connector_key in combined_ckpt and connector_key not in model_state_dict:
            model_state_dict[connector_key] = combined_ckpt[connector_key]

    return model_state_dict

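
# Illustrative sketch of the prefix filtering above (the key names are hypothetical):
#
#     combined = {"vae.encoder.conv_in.weight": w1, "model.diffusion_model.proj_in.weight": w2}
#     vae_sd = get_model_state_dict_from_combined_ckpt(combined, "vae")
#     # -> {"encoder.conv_in.weight": w1}
#
# The trailing dot is appended automatically, only keys under the requested prefix are kept
# (with the prefix stripped), and the DiT prefix additionally pulls in the text connector
# projection key when it is stored at the top level of the checkpoint.
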
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id",
|
||||
default="Lightricks/LTX-2",
|
||||
type=str,
|
||||
help="HF Hub repo id with LTX 2.0 checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--combined_filename",
|
||||
default="ltx-2-19b-dev.safetensors",
|
||||
type=str,
|
||||
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
|
||||
)
|
||||
parser.add_argument("--vae_prefix", default="vae.", type=str)
|
||||
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
|
||||
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
|
||||
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
|
||||
|
||||
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_model_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 base text encoder model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_upsampler_filename",
|
||||
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
|
||||
type=str,
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
|
||||
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
|
||||
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
|
||||
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
|
||||
parser.add_argument(
|
||||
"--full_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upsample_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
|
||||
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
|
||||
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
|
||||
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
|
||||
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
|
||||
|
||||
combined_ckpt = None
|
||||
load_combined_models = any(
|
||||
[
|
||||
args.vae,
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.text_encoder,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
)
|
||||
if args.combined_filename is not None and load_combined_models:
|
||||
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
|
||||
|
||||
if args.vae or args.full_pipeline or args.upsample_pipeline:
|
||||
if args.vae_filename is not None:
|
||||
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
|
||||
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
if args.audio_vae or args.full_pipeline:
|
||||
if args.audio_vae_filename is not None:
|
||||
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
|
||||
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
|
||||
|
||||
if args.dit or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
if args.connectors or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
|
||||
|
||||
if args.vocoder or args.full_pipeline:
|
||||
if args.vocoder_filename is not None:
|
||||
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
|
||||
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.text_encoder or args.full_pipeline:
|
||||
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
|
||||
original_latent_upsampler_ckpt,
|
||||
latent_upsampler_config,
|
||||
dtype=vae_dtype,
|
||||
)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
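
# Example invocation (illustrative only; the script filename is assumed, and the repo id and
# combined filename simply restate the argparse defaults above):
#
#     python convert_ltx2_to_diffusers.py \
#         --original_state_dict_repo_id Lightricks/LTX-2 \
#         --combined_filename ltx-2-19b-dev.safetensors \
#         --full_pipeline \
#         --output_path ./ltx2-diffusers
#
# Individual components can also be converted in isolation, e.g. `--vae --output_path ...`
# saves only the converted video VAE under `<output_path>/vae`.
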
@@ -193,6 +193,8 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
@@ -236,6 +238,7 @@ else:
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LongCatImageTransformer2DModel",
|
||||
"LTX2VideoTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"Lumina2Transformer2DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
@@ -353,6 +356,7 @@ else:
|
||||
"KDPM2AncestralDiscreteScheduler",
|
||||
"KDPM2DiscreteScheduler",
|
||||
"LCMScheduler",
|
||||
"LTXEulerAncestralRFScheduler",
|
||||
"PNDMScheduler",
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
@@ -417,6 +421,8 @@ else:
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
@@ -537,7 +543,11 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LTX2LatentUpsamplePipeline",
|
||||
"LTX2Pipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
@@ -937,6 +947,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -980,6 +992,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
@@ -1088,6 +1101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
LCMScheduler,
|
||||
LTXEulerAncestralRFScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
@@ -1135,6 +1149,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
QwenImageLayeredModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
@@ -1251,7 +1267,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTXConditionPipeline,
|
||||
LTXI2VLongMultiPromptPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
|
||||
@@ -63,6 +63,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
|
||||
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
|
||||
|
||||
@@ -41,6 +41,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
@@ -104,6 +106,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
@@ -153,6 +156,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -212,6 +217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
|
||||
@@ -1106,6 +1106,51 @@ def _sage_attention_backward_op(
|
||||
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
|
||||
|
||||
|
||||
def _npu_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# Not implemented yet.
|
||||
def _npu_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
|
||||
|
||||
|
||||
# ===== Context parallel =====
|
||||
|
||||
|
||||
@@ -1420,6 +1465,7 @@ def _flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -1427,6 +1473,9 @@ def _flash_attention(
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
if _parallel_config is None:
|
||||
out = flash_attn_func(
|
||||
q=query,
|
||||
@@ -1469,6 +1518,7 @@ def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -1476,6 +1526,9 @@ def _flash_attention_hub(
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
@@ -1612,11 +1665,15 @@ def _flash_attention_3(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
out, lse = _wrapped_flash_attn_3(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -1636,6 +1693,7 @@ def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
@@ -1646,6 +1704,8 @@ def _flash_attention_3_hub(
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
@@ -1785,12 +1845,16 @@ def _aiter_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for aiter attention")
|
||||
|
||||
if not return_lse and torch.is_grad_enabled():
|
||||
# aiter requires return_lse=True by assertion when gradients are enabled.
|
||||
out, lse, *_ = aiter_flash_attn_func(
|
||||
@@ -2028,6 +2092,7 @@ def _native_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -2035,6 +2100,9 @@ def _native_flash_attention(
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for aiter attention")
|
||||
|
||||
lse = None
|
||||
if _parallel_config is None and not return_lse:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
@@ -2108,34 +2176,52 @@ def _native_math_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._NATIVE_NPU,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_npu_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for NPU attention")
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(1), # num_heads
|
||||
input_layout="BNSD",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
out = out.transpose(1, 2).contiguous()
|
||||
if _parallel_config is None:
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
None,
|
||||
scale,
|
||||
None,
|
||||
return_lse,
|
||||
forward_op=_npu_attention_forward_op,
|
||||
backward_op=_npu_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@@ -2148,10 +2234,13 @@ def _native_xla_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for XLA attention")
|
||||
if return_lse:
|
||||
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
@@ -2175,11 +2264,14 @@ def _sage_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
lse = None
|
||||
if _parallel_config is None:
|
||||
out = sageattn(
|
||||
@@ -2223,11 +2315,14 @@ def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
@@ -2309,11 +2404,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2333,11 +2431,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda_sm90(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2357,11 +2458,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2381,11 +2485,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_triton(
|
||||
q=query,
|
||||
k=key,
|
||||
|
||||
@@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
|
||||
src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py (new file, 1521 lines; diff suppressed because it is too large)
src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py (new file, 804 lines)
@@ -0,0 +1,804 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
class LTX2AudioCausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution that pads asymmetrically along the causal axis.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
if self.causality_axis == "none":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis in {"width", "width-compatibility"}:
|
||||
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis == "height":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
self.padding = padding
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
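
# Padding arithmetic for the causal convolution above, shown for the default
# kernel_size=3, dilation=1, causality_axis="height" case (a sketch, not part of the
# model definition): pad_h = pad_w = 2, so `padding` becomes (1, 1, 2, 0) in F.pad's
# (left, right, top, bottom) order. All height padding is applied before the sequence,
# so with stride=1 each output row depends only on current and earlier rows (no
# look-ahead along the causal axis) and the height is preserved, e.g.:
#
#     conv = LTX2AudioCausalConv2d(2, 4, kernel_size=3, causality_axis="height")
#     out = conv(torch.randn(1, 2, 10, 64))  # -> shape (1, 4, 10, 64)
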
class LTX2AudioPixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
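
# The pixel norm above computes, per spatial location, y = x / sqrt(mean_c(x^2) + eps),
# with the mean taken over the channel dimension (dim=1 by default), i.e. an RMS
# normalization without a learnable scale. Illustrative check:
#
#     norm = LTX2AudioPixelNorm()
#     y = norm(torch.randn(1, 8, 16, 16))
#     # torch.mean(y**2, dim=1) is ~1 everywhere (up to eps).
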
class LTX2AudioAttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
norm_type: str = "group",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(x)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
batch, channels, height, width = q.shape
|
||||
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
|
||||
k = k.reshape(batch, channels, height * width).contiguous()
|
||||
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
|
||||
attn = torch.nn.functional.softmax(attn, dim=2)
|
||||
|
||||
v = v.reshape(batch, channels, height * width)
|
||||
attn = attn.permute(0, 2, 1).contiguous()
|
||||
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
return x + h_
|
||||
|
||||
|
||||
class LTX2AudioResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
norm_type: str = "group",
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
if causality_axis is not None:
|
||||
self.conv1 = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
if norm_type == "group":
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
if causality_axis is not None:
|
||||
self.conv2 = LTX2AudioCausalConv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
if causality_axis is not None:
|
||||
self.conv_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
if causality_axis is not None:
|
||||
self.nin_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
h = self.norm1(x)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class LTX2AudioDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.with_conv:
|
||||
# Padding tuple is in the order: (left, right, top, bottom).
|
||||
if self.causality_axis == "none":
|
||||
pad = (0, 1, 0, 1)
|
||||
elif self.causality_axis == "width":
|
||||
pad = (2, 0, 0, 1)
|
||||
elif self.causality_axis == "height":
|
||||
pad = (0, 1, 2, 0)
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pad = (1, 0, 0, 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
|
||||
f" and `width-compatibility`."
|
||||
)
|
||||
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# with_conv=False implies that causality_axis is "none"
|
||||
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioUpsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
if causality_axis is not None:
|
||||
self.conv = LTX2AudioCausalConv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
if self.causality_axis is None or self.causality_axis == "none":
|
||||
pass
|
||||
elif self.causality_axis == "height":
|
||||
x = x[:, :, 1:, :]
|
||||
elif self.causality_axis == "width":
|
||||
x = x[:, :, :, 1:]
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioAudioPatchifier:
|
||||
"""
|
||||
Patchifier for spectrogram/audio latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
is_causal: bool = True,
|
||||
):
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
|
||||
batch, channels, time, freq = audio_latents.shape
|
||||
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
|
||||
batch, time, _ = audio_latents.shape
|
||||
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
return self._patch_size
|
||||
|
||||
|
||||
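
# Shape walk-through for the patchifier above (illustrative): `patchify` maps a latent of
# shape (batch, channels, time, mel_bins) to (batch, time, channels * mel_bins) by moving
# the time axis forward and flattening channels with the mel axis; `unpatchify(x, channels,
# mel_bins)` inverts this exactly. E.g. a (1, 8, 32, 64) latent becomes (1, 32, 512) and
# round-trips back to (1, 8, 32, 64).
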
class LTX2AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels
|
||||
base_resolution = resolution
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != self.num_resolutions - 1:
|
||||
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
final_block_channels = block_in
|
||||
z_channels = 2 * latent_channels if double_z else latent_channels
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_states = block(hidden_states, temb=None)
|
||||
if stage.attn:
|
||||
hidden_states = stage.attn[block_idx](hidden_states)
|
||||
|
||||
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
|
||||
hidden_states = stage.downsample(hidden_states)
|
||||
|
||||
hidden_states = self.mid.block_1(hidden_states, temb=None)
|
||||
hidden_states = self.mid.attn_1(hidden_states)
|
||||
hidden_states = self.mid.block_2(hidden_states, temb=None)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
|
||||
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
|
||||
convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
self.patchifier = LTX2AudioAudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels * self.channel_multipliers[-1]
|
||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != 0:
|
||||
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res *= 2
|
||||
|
||||
self.up.insert(0, stage)
|
||||
|
||||
final_block_channels = block_in
|
||||
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_, _, frames, mel_bins = sample.shape
|
||||
|
||||
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
if self.causality_axis is not None:
|
||||
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||
|
||||
target_channels = self.out_ch
|
||||
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
|
||||
|
||||
hidden_features = self.conv_in(sample)
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = self.up[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_features = block(hidden_features, temb=None)
|
||||
if stage.attn:
|
||||
hidden_features = stage.attn[block_idx](hidden_features)
|
||||
|
||||
if level != 0 and hasattr(stage, "upsample"):
|
||||
hidden_features = stage.upsample(hidden_features)
|
||||
|
||||
if self.give_pre_end:
|
||||
return hidden_features
|
||||
|
||||
hidden = self.norm_out(hidden_features)
|
||||
hidden = self.non_linearity(hidden)
|
||||
decoded_output = self.conv_out(hidden)
|
||||
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
|
||||
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
target_time = target_frames
|
||||
target_freq = target_mel_bins
|
||||
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0),
|
||||
0,
|
||||
max(time_padding_needed, 0),
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX2 audio VAE for encoding and decoding audio latent representations.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 2,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
norm_type: str = "pixel",
|
||||
causality_axis: Optional[str] = "height",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
|
||||
if causality_axis not in supported_causality_axes:
|
||||
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
|
||||
|
||||
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
|
||||
|
||||
self.encoder = LTX2AudioEncoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
double_z=double_z,
|
||||
)
|
||||
|
||||
self.decoder = LTX2AudioDecoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
# Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over
|
||||
# the entire dataset and stored in model's checkpoint under AudioVAE state_dict
|
||||
latents_std = torch.zeros((base_channels,))
|
||||
latents_mean = torch.ones((base_channels,))
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
# TODO: calculate programmatically instead of hardcoding
|
||||
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
|
||||
# TODO: confirm whether the mel compression ratio below is correct
|
||||
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.encoder(x)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(z)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
posterior = self.encode(sample).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
if not return_dict:
|
||||
return (dec.sample,)
|
||||
return dec
|
||||
@@ -1,115 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate
|
||||
from .controlnets.controlnet import ( # noqa
|
||||
ControlNetConditioningEmbedding,
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
zero_module,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetOutput(ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ControlNetModel(ControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
mid_block_type=mid_block_type,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
class_embed_type=class_embed_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
num_class_embeds=num_class_embeds,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxControlNetOutput(FluxControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class FluxControlNetModel(FluxControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
guidance_embeds=guidance_embeds,
|
||||
axes_dims_rope=axes_dims_rope,
|
||||
num_mode=num_mode,
|
||||
conditioning_embedding_channels=conditioning_embedding_channels,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(FluxMultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1,68 +0,0 @@
|
||||
# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SD3ControlNetOutput(SD3ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SD3ControlNetModel(SD3ControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 18,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 18,
|
||||
joint_attention_dim: int = 4096,
|
||||
caption_projection_dim: int = 1152,
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
sample_size=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
caption_projection_dim=caption_projection_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
out_channels=out_channels,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
extra_conditioning_channels=extra_conditioning_channels,
|
||||
)
|
||||
|
||||
|
||||
class SD3MultiControlNetModel(SD3MultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1,116 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sparsectrl import ( # noqa
|
||||
SparseControlNetConditioningEmbedding,
|
||||
SparseControlNetModel,
|
||||
SparseControlNetOutput,
|
||||
zero_module,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SparseControlNetOutput(SparseControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
|
||||
deprecate(
|
||||
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetModel(SparseControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
motion_num_attention_heads: int = 8,
|
||||
concat_conditioning_mask: bool = True,
|
||||
use_simplified_condition_embedding: bool = True,
|
||||
):
|
||||
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
motion_max_seq_length=motion_max_seq_length,
|
||||
motion_num_attention_heads=motion_num_attention_heads,
|
||||
concat_conditioning_mask=concat_conditioning_mask,
|
||||
use_simplified_condition_embedding=use_simplified_condition_embedding,
|
||||
)
|
||||
@@ -47,6 +47,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.distributed_utils import is_torch_dist_rank_zero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
|
||||
if not is_torch_dist_rank_zero():
|
||||
tqdm_kwargs["disable"] = True
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
|
||||
with logging.tqdm(**tqdm_kwargs) as pbar:
|
||||
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
|
||||
@@ -59,11 +59,8 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.hub_utils import (
|
||||
PushToHubMixin,
|
||||
load_or_create_model_card,
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.distributed_utils import is_torch_dist_rank_zero
|
||||
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
|
||||
from .model_loading_utils import (
|
||||
@@ -1672,7 +1669,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
shard_files = resolved_model_file
|
||||
if len(resolved_model_file) > 1:
|
||||
shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
|
||||
shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
|
||||
if not is_torch_dist_rank_zero():
|
||||
shard_tqdm_kwargs["disable"] = True
|
||||
shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
|
||||
|
||||
for shard_file in shard_files:
|
||||
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
|
||||
|
||||
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
from .transformer_kandinsky import Kandinsky5Transformer3DModel
|
||||
from .transformer_longcat_image import LongCatImageTransformer2DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_ltx2 import LTX2VideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -717,11 +717,7 @@ class FluxTransformer2DModel(
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
if is_torch_npu_available():
|
||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -835,14 +835,8 @@ class Flux2Transformer2DModel(
|
||||
if txt_ids.ndim == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
|
||||
if is_torch_npu_available():
|
||||
freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
|
||||
image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
|
||||
freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
|
||||
text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(img_ids)
|
||||
text_rotary_emb = self.pos_embed(txt_ids)
|
||||
image_rotary_emb = self.pos_embed(img_ids)
|
||||
text_rotary_emb = self.pos_embed(txt_ids)
|
||||
concat_rotary_emb = (
|
||||
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
|
||||
@@ -312,7 +312,6 @@ class HunyuanVideoConditionEmbedding(nn.Module):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype)) # (N, D)
|
||||
pooled_projections = self.text_embedder(pooled_projection)
|
||||
conditioning = timesteps_emb + pooled_projections
|
||||
|
||||
token_replace_emb = None
|
||||
if self.image_condition_type == "token_replace":
|
||||
@@ -324,8 +323,9 @@ class HunyuanVideoConditionEmbedding(nn.Module):
|
||||
if self.guidance_embedder is not None:
|
||||
guidance_proj = self.time_proj(guidance)
|
||||
guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))
|
||||
conditioning = conditioning + guidance_emb
|
||||
|
||||
conditioning = timesteps_emb + guidance_emb + pooled_projections
|
||||
else:
|
||||
conditioning = timesteps_emb + pooled_projections
|
||||
return conditioning, token_replace_emb
|
||||
|
||||
|
||||
|
||||
@@ -165,9 +165,8 @@ class Kandinsky5TimeEmbeddings(nn.Module):
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
||||
def forward(self, time):
|
||||
args = torch.outer(time, self.freqs.to(device=time.device))
|
||||
args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
@@ -269,7 +268,6 @@ class Kandinsky5Modulation(nn.Module):
|
||||
self.out_layer.weight.data.zero_()
|
||||
self.out_layer.bias.data.zero_()
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
@@ -525,6 +523,7 @@ class Kandinsky5Transformer3DModel(
|
||||
"Kandinsky5TransformerEncoderBlock",
|
||||
"Kandinsky5TransformerDecoderBlock",
|
||||
]
|
||||
_keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import is_torch_npu_available, logging
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -499,11 +499,7 @@ class LongCatImageTransformer2DModel(
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
if is_torch_npu_available():
|
||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
|
||||
|
||||
1350
src/diffusers/models/transformers/transformer_ltx2.py
Normal file
1350
src/diffusers/models/transformers/transformer_ltx2.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import is_torch_npu_available, logging
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -530,11 +530,7 @@ class OvisImageTransformer2DModel(
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
if is_torch_npu_available():
|
||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -63,6 +63,8 @@ else:
|
||||
"QwenImageEditAutoBlocks",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
]
|
||||
_import_structure["z_image"] = [
|
||||
"ZImageAutoBlocks",
|
||||
@@ -96,6 +98,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
QwenImageLayeredModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -160,7 +160,10 @@ class AutoOffloadStrategy:
|
||||
if len(hooks) == 0:
|
||||
return []
|
||||
|
||||
current_module_size = model.get_memory_footprint()
|
||||
try:
|
||||
current_module_size = model.get_memory_footprint()
|
||||
except AttributeError:
|
||||
raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}.")
|
||||
|
||||
device_type = execution_device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
@@ -703,7 +706,20 @@ class ComponentsManager:
|
||||
if not is_accelerate_available():
|
||||
raise ImportError("Make sure to install accelerate to use auto_cpu_offload")
|
||||
|
||||
# TODO: add a warning if mem_get_info isn't available on `device`.
|
||||
if device is None:
|
||||
device = get_device()
|
||||
if not isinstance(device, torch.device):
|
||||
device = torch.device(device)
|
||||
|
||||
device_type = device.type
|
||||
device_module = getattr(torch, device_type, torch.cuda)
|
||||
if not hasattr(device_module, "mem_get_info"):
|
||||
raise NotImplementedError(
|
||||
f"`enable_auto_cpu_offload() relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
|
||||
)
|
||||
|
||||
if device.index is None:
|
||||
device = torch.device(f"{device.type}:{0}")
|
||||
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
|
||||
@@ -711,11 +727,7 @@ class ComponentsManager:
|
||||
|
||||
self.disable_auto_cpu_offload()
|
||||
offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)
|
||||
if device is None:
|
||||
device = get_device()
|
||||
device = torch.device(device)
|
||||
if device.index is None:
|
||||
device = torch.device(f"{device.type}:{0}")
|
||||
|
||||
all_hooks = []
|
||||
for name, component in self.components.items():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
|
||||
@@ -121,7 +121,7 @@ class FluxTextInputStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# Adapted from `QwenImageInputsDynamicStep`
|
||||
# Adapted from `QwenImageAdditionalInputsStep`
|
||||
class FluxInputsDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
|
||||
@@ -68,6 +68,10 @@ class MellonParam:
|
||||
def image_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="image_latents", label="Image Latents", type="latents", display=display)
|
||||
|
||||
@classmethod
|
||||
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_latents_with_strength(cls) -> "MellonParam":
|
||||
return cls(
|
||||
@@ -89,6 +93,10 @@ class MellonParam:
|
||||
def embeddings(cls, display: str = "output") -> "MellonParam":
|
||||
return cls(name="embeddings", label="Text Embeddings", type="embeddings", display=display)
|
||||
|
||||
@classmethod
|
||||
def image_embeds(cls, display: str = "output") -> "MellonParam":
|
||||
return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display)
|
||||
|
||||
@classmethod
|
||||
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
|
||||
return cls(
|
||||
@@ -168,6 +176,18 @@ class MellonParam:
|
||||
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
|
||||
|
||||
@classmethod
|
||||
def layers(cls, default: int = 4) -> "MellonParam":
|
||||
return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider")
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
return cls(name="videos", label="Videos", type="video", display="output")
|
||||
|
||||
@classmethod
|
||||
def vae(cls) -> "MellonParam":
|
||||
"""
|
||||
@@ -178,6 +198,16 @@ class MellonParam:
|
||||
"""
|
||||
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def image_encoder(cls) -> "MellonParam":
|
||||
"""
|
||||
Image Encoder model info dict.
|
||||
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input")
|
||||
|
||||
@classmethod
|
||||
def unet(cls) -> "MellonParam":
|
||||
"""
|
||||
|
||||
@@ -62,6 +62,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
|
||||
("z-image", "ZImageModularPipeline"),
|
||||
]
|
||||
)
|
||||
@@ -231,7 +232,7 @@ class BlockState:
|
||||
|
||||
class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
|
||||
Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks,
|
||||
LoopSequentialPipelineBlocks
|
||||
|
||||
[`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
|
||||
@@ -527,9 +528,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the inputs.
|
||||
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
|
||||
`select_block` method to define the logic for selecting the block.
|
||||
|
||||
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipeline blocks (such as loading or saving etc.)
|
||||
@@ -539,12 +541,13 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
Attributes:
|
||||
block_classes: List of block classes to be used
|
||||
block_names: List of prefixes for each block
|
||||
block_trigger_inputs: List of input names that trigger specific blocks, with None for default
|
||||
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
|
||||
"""
|
||||
|
||||
block_classes = []
|
||||
block_names = []
|
||||
block_trigger_inputs = []
|
||||
default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
@@ -554,26 +557,15 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
else:
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
if not (len(self.block_classes) == len(self.block_names)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same."
|
||||
)
|
||||
default_blocks = [t for t in self.block_trigger_inputs if t is None]
|
||||
# can only have 1 or 0 default block, and has to put in the last
|
||||
# the order of blocks matters here because the first block with matching trigger will be dispatched
|
||||
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
|
||||
# as long as mask is provided, it is inpaint; if only image is provided, it is img2img
|
||||
if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
|
||||
if self.default_block_name is not None and self.default_block_name not in self.block_names:
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
|
||||
"in block_trigger_inputs."
|
||||
f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}"
|
||||
)
|
||||
|
||||
# Map trigger inputs to block objects
|
||||
self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
|
||||
self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
|
||||
self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
return next(iter(self.sub_blocks.values())).model_name
|
||||
@@ -602,8 +594,10 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
if None not in self.block_trigger_inputs:
|
||||
# no default block means this conditional block can be skipped entirely
|
||||
if self.default_block_name is None:
|
||||
return []
|
||||
|
||||
first_block = next(iter(self.sub_blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_inputs", set()))
|
||||
|
||||
@@ -614,7 +608,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
@@ -639,36 +632,9 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
# Find default block first (if any)
|
||||
|
||||
block = self.trigger_to_block_map.get(None)
|
||||
for input_name in self.block_trigger_inputs:
|
||||
if input_name is not None and state.get(input_name) is not None:
|
||||
block = self.trigger_to_block_map[input_name]
|
||||
break
|
||||
|
||||
if block is None:
|
||||
logger.info(f"skipping auto block: {self.__class__.__name__}")
|
||||
return pipeline, state
|
||||
|
||||
try:
|
||||
logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
|
||||
return block(pipeline, state)
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"\nError in block: {block.__class__.__name__}\n"
|
||||
f"Error details: {str(e)}\n"
|
||||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def _get_trigger_inputs(self):
|
||||
def _get_trigger_inputs(self) -> set:
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
|
||||
block_trigger_inputs values
|
||||
Returns a set of all unique trigger input values found in this block and nested blocks.
|
||||
"""
|
||||
|
||||
def fn_recursive_get_trigger(blocks):
|
||||
@@ -676,9 +642,8 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
if blocks is not None:
|
||||
for name, block in blocks.items():
|
||||
# Check if current block has trigger inputs(i.e. auto block)
|
||||
# Check if current block has block_trigger_inputs
|
||||
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||
# Add all non-None values from the trigger inputs list
|
||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||
|
||||
# If block has sub_blocks, recursively check them
|
||||
@@ -688,15 +653,57 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return trigger_values
|
||||
|
||||
trigger_inputs = set(self.block_trigger_inputs)
|
||||
trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
|
||||
# Start with this block's block_trigger_inputs
|
||||
all_triggers = {t for t in self.block_trigger_inputs if t is not None}
|
||||
# Add nested triggers
|
||||
all_triggers.update(fn_recursive_get_trigger(self.sub_blocks))
|
||||
|
||||
return trigger_inputs
|
||||
return all_triggers
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
"""All trigger inputs including from nested blocks."""
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
|
||||
for selecting the block.
|
||||
|
||||
Args:
|
||||
**kwargs: Trigger input names and their values from the state.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The name of the block to run, or None to use default/skip.
|
||||
"""
|
||||
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None}
|
||||
block_name = self.select_block(**trigger_kwargs)
|
||||
|
||||
if block_name is None:
|
||||
block_name = self.default_block_name
|
||||
|
||||
if block_name is None:
|
||||
logger.info(f"skipping conditional block: {self.__class__.__name__}")
|
||||
return pipeline, state
|
||||
|
||||
block = self.sub_blocks[block_name]
|
||||
|
||||
try:
|
||||
logger.info(f"Running block: {block.__class__.__name__}")
|
||||
return block(pipeline, state)
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"\nError in block: {block.__class__.__name__}\n"
|
||||
f"Error details: {str(e)}\n"
|
||||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
@@ -708,7 +715,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
header += "\n"
|
||||
header += " " + "=" * 100 + "\n"
|
||||
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -729,31 +736,20 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs = getattr(self, "expected_configs", [])
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
|
||||
# Blocks section - moved to the end with simplified format
|
||||
# Blocks section
|
||||
blocks_str = " Sub-Blocks:\n"
|
||||
for i, (name, block) in enumerate(self.sub_blocks.items()):
|
||||
# Get trigger input for this block
|
||||
trigger = None
|
||||
if hasattr(self, "block_to_trigger_map"):
|
||||
trigger = self.block_to_trigger_map.get(name)
|
||||
# Format the trigger info
|
||||
if trigger is None:
|
||||
trigger_str = "[default]"
|
||||
elif isinstance(trigger, (list, tuple)):
|
||||
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
|
||||
else:
|
||||
trigger_str = f"[trigger: {trigger}]"
|
||||
# For AutoPipelineBlocks, add bullet points
|
||||
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
|
||||
if name == self.default_block_name:
|
||||
addtional_str = " [default]"
|
||||
else:
|
||||
# For SequentialPipelineBlocks, show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
addtional_str = ""
|
||||
blocks_str += f" • {name}{addtional_str} ({block.__class__.__name__})\n"
|
||||
|
||||
# Add block description
|
||||
desc_lines = block.description.split("\n")
|
||||
indented_desc = desc_lines[0]
|
||||
if len(desc_lines) > 1:
|
||||
indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
|
||||
block_desc_lines = block.description.split("\n")
|
||||
indented_desc = block_desc_lines[0]
|
||||
if len(block_desc_lines) > 1:
|
||||
indented_desc += "\n" + "\n".join(" " + line for line in block_desc_lines[1:])
|
||||
blocks_str += f" Description: {indented_desc}\n\n"
|
||||
|
||||
# Build the representation with conditional sections
|
||||
@@ -784,6 +780,35 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ConditionalPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_block_name(self) -> Optional[str]:
|
||||
"""Derive default_block_name from block_trigger_inputs (None entry)."""
|
||||
if None in self.block_trigger_inputs:
|
||||
idx = self.block_trigger_inputs.index(None)
|
||||
return self.block_names[idx]
|
||||
return None
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""Select block based on which trigger input is present (not None)."""
|
||||
for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names):
|
||||
if trigger_input is not None and kwargs.get(trigger_input) is not None:
|
||||
return block_name
|
||||
return None
|
||||
|
||||
|
||||
class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
|
||||
@@ -885,7 +910,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None:
|
||||
# ConditionalPipelineBlocks without default can be skipped
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
@@ -948,8 +974,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
def _get_trigger_inputs(self):
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
|
||||
block_trigger_inputs values
|
||||
Returns a set of all unique trigger input values found in the blocks.
|
||||
"""
|
||||
|
||||
def fn_recursive_get_trigger(blocks):
|
||||
@@ -957,9 +982,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
if blocks is not None:
|
||||
for name, block in blocks.items():
|
||||
# Check if current block has trigger inputs(i.e. auto block)
|
||||
# Check if current block has block_trigger_inputs (ConditionalPipelineBlocks)
|
||||
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||
# Add all non-None values from the trigger inputs list
|
||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||
|
||||
# If block has sub_blocks, recursively check them
|
||||
@@ -975,82 +999,84 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def trigger_inputs(self):
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def _traverse_trigger_blocks(self, trigger_inputs):
|
||||
# Convert trigger_inputs to a set for easier manipulation
|
||||
active_triggers = set(trigger_inputs)
|
||||
def _traverse_trigger_blocks(self, active_inputs):
|
||||
"""
|
||||
Traverse blocks and select which ones would run given the active inputs.
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_triggers):
|
||||
Args:
|
||||
active_inputs: Dict of input names to values that are "present"
|
||||
|
||||
Returns:
|
||||
OrderedDict of block_name -> block that would execute
|
||||
"""
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_inputs):
|
||||
result_blocks = OrderedDict()
|
||||
|
||||
# sequential(include loopsequential) or PipelineBlock
|
||||
if not hasattr(block, "block_trigger_inputs"):
|
||||
if block.sub_blocks:
|
||||
# sequential or LoopSequentialPipelineBlocks (keep traversing)
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
|
||||
if isinstance(block, ConditionalPipelineBlocks):
|
||||
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
|
||||
selected_block_name = block.select_block(**trigger_kwargs)
|
||||
|
||||
if selected_block_name is None:
|
||||
selected_block_name = block.default_block_name
|
||||
|
||||
if selected_block_name is None:
|
||||
return result_blocks
|
||||
|
||||
selected_block = block.sub_blocks[selected_block_name]
|
||||
|
||||
if selected_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
|
||||
else:
|
||||
# PipelineBlock
|
||||
result_blocks[block_name] = block
|
||||
# Add this block's output names to active triggers if defined
|
||||
if hasattr(block, "outputs"):
|
||||
active_triggers.update(out.name for out in block.outputs)
|
||||
result_blocks[block_name] = selected_block
|
||||
if hasattr(selected_block, "outputs"):
|
||||
for out in selected_block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
# auto
|
||||
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
|
||||
if block.sub_blocks:
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
else:
|
||||
# Find first block_trigger_input that matches any value in our active_triggers
|
||||
this_block = None
|
||||
for trigger_input in block.block_trigger_inputs:
|
||||
if trigger_input is not None and trigger_input in active_triggers:
|
||||
this_block = block.trigger_to_block_map[trigger_input]
|
||||
break
|
||||
|
||||
# If no matches found, try to get the default (None) block
|
||||
if this_block is None and None in block.block_trigger_inputs:
|
||||
this_block = block.trigger_to_block_map[None]
|
||||
|
||||
if this_block is not None:
|
||||
# sequential/auto (keep traversing)
|
||||
if this_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
|
||||
else:
|
||||
# PipelineBlock
|
||||
result_blocks[block_name] = this_block
|
||||
# Add this block's output names to active triggers if defined
|
||||
# YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
|
||||
if hasattr(this_block, "outputs"):
|
||||
active_triggers.update(out.name for out in this_block.outputs)
|
||||
result_blocks[block_name] = block
|
||||
if hasattr(block, "outputs"):
|
||||
for out in block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
all_blocks = OrderedDict()
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
|
||||
all_blocks.update(blocks_to_update)
|
||||
return all_blocks
|
||||
|
||||
def get_execution_blocks(self, *trigger_inputs):
|
||||
trigger_inputs_all = self.trigger_inputs
|
||||
def get_execution_blocks(self, **kwargs):
|
||||
"""
|
||||
Get the blocks that would execute given the specified inputs.
|
||||
|
||||
if trigger_inputs is not None:
|
||||
if not isinstance(trigger_inputs, (list, tuple, set)):
|
||||
trigger_inputs = [trigger_inputs]
|
||||
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
|
||||
if invalid_inputs:
|
||||
logger.warning(
|
||||
f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
|
||||
)
|
||||
trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
Pass any inputs that would be non-None at runtime.
|
||||
|
||||
if trigger_inputs is None:
|
||||
if None in trigger_inputs_all:
|
||||
trigger_inputs = [None]
|
||||
else:
|
||||
trigger_inputs = [trigger_inputs_all[0]]
|
||||
blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
|
||||
Returns:
SequentialPipelineBlocks containing only the blocks that would execute

Example:
# Get blocks for inpainting workflow
blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)

# Get blocks for text2image workflow
blocks = pipeline.get_execution_blocks(prompt="a cat")
"""
# Filter out None values
active_inputs = {k: v for k, v in kwargs.items() if v is not None}

blocks_triggered = self._traverse_trigger_blocks(active_inputs)
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)

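A minimal usage sketch of the keyword-based selection, assuming `blocks` is a SequentialPipelineBlocks whose conditional branches trigger on `image` and `mask_image`; `init_image` and `mask` are placeholder objects:

# only the inputs that would be non-None at runtime need to be passed
t2i_blocks = blocks.get_execution_blocks(prompt="a cat")
inpaint_blocks = blocks.get_execution_blocks(prompt="a cat", image=init_image, mask_image=mask)
print(list(inpaint_blocks.sub_blocks.keys()))  # blocks that would actually run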
def __repr__(self):
|
||||
@@ -1067,7 +1093,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
# Get first trigger input as example
|
||||
example_input = next(t for t in self.trigger_inputs if t is not None)
|
||||
header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
|
||||
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -1091,22 +1117,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# Blocks section - moved to the end with simplified format
|
||||
blocks_str = " Sub-Blocks:\n"
|
||||
for i, (name, block) in enumerate(self.sub_blocks.items()):
|
||||
# Get trigger input for this block
|
||||
trigger = None
|
||||
if hasattr(self, "block_to_trigger_map"):
|
||||
trigger = self.block_to_trigger_map.get(name)
|
||||
# Format the trigger info
|
||||
if trigger is None:
|
||||
trigger_str = "[default]"
|
||||
elif isinstance(trigger, (list, tuple)):
|
||||
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
|
||||
else:
|
||||
trigger_str = f"[trigger: {trigger}]"
|
||||
# For AutoPipelineBlocks, add bullet points
|
||||
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
|
||||
else:
|
||||
# For SequentialPipelineBlocks, show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
# show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
|
||||
# Add block description
|
||||
desc_lines = block.description.split("\n")
|
||||
@@ -1230,15 +1242,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
if inp.name not in outputs and inp not in inputs:
|
||||
inputs.append(inp)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
|
||||
for input_param in inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
@@ -1295,6 +1301,14 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
|
||||
# Validate that sub_blocks are only leaf blocks
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if block.sub_blocks:
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). "
|
||||
f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
|
||||
"""
|
||||
|
||||
@@ -1,661 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..image_processor import PipelineImageInput
|
||||
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from .modular_pipeline_utils import InputParam
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# YiYi Notes: this is actually for SDXL, put it here for now
|
||||
SDXL_INPUTS_SCHEMA = {
|
||||
"prompt": InputParam(
|
||||
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
|
||||
),
|
||||
"prompt_2": InputParam(
|
||||
"prompt_2",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
|
||||
),
|
||||
"negative_prompt": InputParam(
|
||||
"negative_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The prompt or prompts not to guide the image generation",
|
||||
),
|
||||
"negative_prompt_2": InputParam(
|
||||
"negative_prompt_2",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The negative prompt or prompts for text_encoder_2",
|
||||
),
|
||||
"cross_attention_kwargs": InputParam(
|
||||
"cross_attention_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="Kwargs dictionary passed to the AttentionProcessor",
|
||||
),
|
||||
"clip_skip": InputParam(
|
||||
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
|
||||
),
|
||||
"image": InputParam(
|
||||
"image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="The image(s) to modify for img2img or inpainting",
|
||||
),
|
||||
"mask_image": InputParam(
|
||||
"mask_image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="Mask image for inpainting, white pixels will be repainted",
|
||||
),
|
||||
"generator": InputParam(
|
||||
"generator",
|
||||
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
|
||||
description="Generator(s) for deterministic generation",
|
||||
),
|
||||
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
|
||||
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
|
||||
"num_images_per_prompt": InputParam(
|
||||
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
|
||||
),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
|
||||
),
|
||||
"timesteps": InputParam(
|
||||
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
|
||||
),
|
||||
"sigmas": InputParam(
|
||||
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
|
||||
),
|
||||
"denoising_end": InputParam(
|
||||
"denoising_end",
|
||||
type_hint=Optional[float],
|
||||
description="Fraction of denoising process to complete before termination",
|
||||
),
|
||||
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
|
||||
"strength": InputParam(
|
||||
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
|
||||
),
|
||||
"denoising_start": InputParam(
|
||||
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
|
||||
),
|
||||
"padding_mask_crop": InputParam(
|
||||
"padding_mask_crop",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Size of margin in crop for image and mask",
|
||||
),
|
||||
"original_size": InputParam(
|
||||
"original_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Original size of the image for SDXL's micro-conditioning",
|
||||
),
|
||||
"target_size": InputParam(
|
||||
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
|
||||
),
|
||||
"negative_original_size": InputParam(
|
||||
"negative_original_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Negative conditioning based on image resolution",
|
||||
),
|
||||
"negative_target_size": InputParam(
|
||||
"negative_target_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Negative conditioning based on target resolution",
|
||||
),
|
||||
"crops_coords_top_left": InputParam(
|
||||
"crops_coords_top_left",
|
||||
type_hint=Tuple[int, int],
|
||||
default=(0, 0),
|
||||
description="Top-left coordinates for SDXL's micro-conditioning",
|
||||
),
|
||||
"negative_crops_coords_top_left": InputParam(
|
||||
"negative_crops_coords_top_left",
|
||||
type_hint=Tuple[int, int],
|
||||
default=(0, 0),
|
||||
description="Negative conditioning crop coordinates",
|
||||
),
|
||||
"aesthetic_score": InputParam(
|
||||
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
|
||||
),
|
||||
"negative_aesthetic_score": InputParam(
|
||||
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
|
||||
),
|
||||
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
|
||||
"output_type": InputParam(
|
||||
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
|
||||
),
|
||||
"ip_adapter_image": InputParam(
|
||||
"ip_adapter_image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="Image(s) to be used as IP adapter",
|
||||
),
|
||||
"control_image": InputParam(
|
||||
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
|
||||
),
|
||||
"control_guidance_start": InputParam(
|
||||
"control_guidance_start",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=0.0,
|
||||
description="When ControlNet starts applying",
|
||||
),
|
||||
"control_guidance_end": InputParam(
|
||||
"control_guidance_end",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=1.0,
|
||||
description="When ControlNet stops applying",
|
||||
),
|
||||
"controlnet_conditioning_scale": InputParam(
|
||||
"controlnet_conditioning_scale",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=1.0,
|
||||
description="Scale factor for ControlNet outputs",
|
||||
),
|
||||
"guess_mode": InputParam(
|
||||
"guess_mode",
|
||||
type_hint=bool,
|
||||
default=False,
|
||||
description="Enables ControlNet encoder to recognize input without prompts",
|
||||
),
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
required=True,
|
||||
description="Text embeddings used to guide image generation",
|
||||
),
|
||||
"negative_prompt_embeds": InputParam(
|
||||
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
|
||||
),
|
||||
"pooled_prompt_embeds": InputParam(
|
||||
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
|
||||
),
|
||||
"negative_pooled_prompt_embeds": InputParam(
|
||||
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
|
||||
),
|
||||
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
|
||||
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
"image_latents": InputParam(
|
||||
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
|
||||
),
|
||||
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
|
||||
"masked_image_latents": InputParam(
|
||||
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
|
||||
),
|
||||
"add_time_ids": InputParam(
|
||||
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
|
||||
),
|
||||
"negative_add_time_ids": InputParam(
|
||||
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
|
||||
),
|
||||
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
|
||||
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
|
||||
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
|
||||
"ip_adapter_embeds": InputParam(
|
||||
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
|
||||
),
|
||||
"negative_ip_adapter_embeds": InputParam(
|
||||
"negative_ip_adapter_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
description="Negative image embeddings for IP-Adapter",
|
||||
),
|
||||
"images": InputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
required=True,
|
||||
description="Generated images",
|
||||
),
|
||||
}
|
||||
|
||||
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
|
||||
|
||||
|
||||
DEFAULT_PARAM_MAPS = {
|
||||
"prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"default": "a bear sitting in a chair drinking a milkshake",
|
||||
"display": "textarea",
|
||||
},
|
||||
"negative_prompt": {
|
||||
"label": "Negative Prompt",
|
||||
"type": "string",
|
||||
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||
"display": "textarea",
|
||||
},
|
||||
"num_inference_steps": {
|
||||
"label": "Steps",
|
||||
"type": "int",
|
||||
"default": 25,
|
||||
"min": 1,
|
||||
"max": 1000,
|
||||
},
|
||||
"seed": {
|
||||
"label": "Seed",
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"display": "random",
|
||||
},
|
||||
"width": {
|
||||
"label": "Width",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"height": {
|
||||
"label": "Height",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"images": {
|
||||
"label": "Images",
|
||||
"type": "image",
|
||||
"display": "output",
|
||||
},
|
||||
"image": {
|
||||
"label": "Image",
|
||||
"type": "image",
|
||||
"display": "input",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_TYPE_MAPS = {
|
||||
"int": {
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
},
|
||||
"float": {
|
||||
"type": "float",
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
},
|
||||
"str": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
"bool": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
},
|
||||
"image": {
|
||||
"type": "image",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
|
||||
DEFAULT_CATEGORY = "Modular Diffusers"
|
||||
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
|
||||
DEFAULT_PARAMS_GROUPS_KEYS = {
|
||||
"text_encoders": ["text_encoder", "tokenizer"],
|
||||
"ip_adapter_embeds": ["ip_adapter_embeds"],
|
||||
"prompt_embeddings": ["prompt_embeds"],
|
||||
}
|
||||
|
||||
|
||||
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
"""
Get the group name for a given parameter name; if it is not part of a group, return None. e.g. "prompt_embeds" ->
"prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None
"""
if name is None:
return None
for group_name, group_keys in group_params_keys.items():
for group_key in group_keys:
if group_key in name:
return group_name
return None

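For illustration, a few calls against the default groups defined above (behavior inferred from the function body; substring matching is used):

assert get_group_name("text_encoder_2") == "text_encoders"              # "text_encoder" is a substring
assert get_group_name("negative_ip_adapter_embeds") == "ip_adapter_embeds"
assert get_group_name("prompt") is None                                  # not part of any group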
|
||||
|
||||
class ModularNode(ConfigMixin):
"""
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
around a ModularPipelineBlocks object.

> [!WARNING]
> This is an experimental feature and is likely to change in the future.
"""

config_name = "node_config.json"

@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
blocks = ModularPipelineBlocks.from_pretrained(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
return cls(blocks, **kwargs)
|
||||
|
||||
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
|
||||
self.blocks = blocks
|
||||
|
||||
if label is None:
|
||||
label = self.blocks.__class__.__name__
|
||||
# blocks param name -> mellon param name
|
||||
self.name_mapping = {}
|
||||
|
||||
input_params = {}
|
||||
# pass or create a default param dict for each input
|
||||
# e.g. for prompt,
|
||||
# prompt = {
|
||||
# "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
|
||||
# "label": "Prompt",
|
||||
# "type": "string",
|
||||
# "default": "a bear sitting in a chair drinking a milkshake",
|
||||
# "display": "textarea"}
|
||||
# if type is not specified, it'll be a "custom" param of its own type
|
||||
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
|
||||
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
|
||||
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
|
||||
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
|
||||
for inp in inputs:
|
||||
param = kwargs.pop(inp.name, None)
|
||||
if param:
|
||||
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
|
||||
input_params[inp.name] = param
|
||||
mellon_name = param.pop("name", inp.name)
|
||||
if mellon_name != inp.name:
|
||||
self.name_mapping[inp.name] = mellon_name
|
||||
continue
|
||||
|
||||
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
|
||||
continue
|
||||
|
||||
if inp.name in DEFAULT_PARAM_MAPS:
|
||||
# first check if it's in the default param map, if so, directly use that
|
||||
param = DEFAULT_PARAM_MAPS[inp.name].copy()
|
||||
elif get_group_name(inp.name):
|
||||
param = get_group_name(inp.name)
|
||||
if inp.name not in self.name_mapping:
|
||||
self.name_mapping[inp.name] = param
|
||||
else:
|
||||
# if not, check if it's in the SDXL input schema, if so,
|
||||
# 1. use the type hint to determine the type
|
||||
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
|
||||
if inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).lower()
|
||||
else:
|
||||
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
|
||||
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
|
||||
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
|
||||
if type_key in type_str:
|
||||
param = type_param.copy()
|
||||
param["label"] = inp.name
|
||||
param["display"] = "input"
|
||||
break
|
||||
else:
|
||||
param = inp.name
|
||||
# add the param dict to the inp_params dict
|
||||
input_params[inp.name] = param
|
||||
|
||||
component_params = {}
|
||||
for comp in self.blocks.expected_components:
|
||||
param = kwargs.pop(comp.name, None)
|
||||
if param:
|
||||
component_params[comp.name] = param
|
||||
mellon_name = param.pop("name", comp.name)
|
||||
if mellon_name != comp.name:
|
||||
self.name_mapping[comp.name] = mellon_name
|
||||
continue
|
||||
|
||||
to_exclude = False
|
||||
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
|
||||
if exclude_key in comp.name:
|
||||
to_exclude = True
|
||||
break
|
||||
if to_exclude:
|
||||
continue
|
||||
|
||||
if get_group_name(comp.name):
|
||||
param = get_group_name(comp.name)
|
||||
if comp.name not in self.name_mapping:
|
||||
self.name_mapping[comp.name] = param
|
||||
elif comp.name in DEFAULT_MODEL_KEYS:
|
||||
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
|
||||
else:
|
||||
param = comp.name
|
||||
# add the param dict to the model_params dict
|
||||
component_params[comp.name] = param
|
||||
|
||||
output_params = {}
|
||||
if isinstance(self.blocks, SequentialPipelineBlocks):
|
||||
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
|
||||
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
|
||||
else:
|
||||
outputs = self.blocks.intermediate_outputs
|
||||
|
||||
for out in outputs:
|
||||
param = kwargs.pop(out.name, None)
|
||||
if param:
|
||||
output_params[out.name] = param
|
||||
mellon_name = param.pop("name", out.name)
|
||||
if mellon_name != out.name:
|
||||
self.name_mapping[out.name] = mellon_name
|
||||
continue
|
||||
|
||||
if out.name in DEFAULT_PARAM_MAPS:
|
||||
param = DEFAULT_PARAM_MAPS[out.name].copy()
|
||||
param["display"] = "output"
|
||||
else:
|
||||
group_name = get_group_name(out.name)
|
||||
if group_name:
|
||||
param = group_name
|
||||
if out.name not in self.name_mapping:
|
||||
self.name_mapping[out.name] = param
|
||||
else:
|
||||
param = out.name
|
||||
# add the param dict to the outputs dict
|
||||
output_params[out.name] = param
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unused kwargs: {kwargs}")
|
||||
|
||||
register_dict = {
|
||||
"category": category,
|
||||
"label": label,
|
||||
"input_params": input_params,
|
||||
"component_params": component_params,
|
||||
"output_params": output_params,
|
||||
"name_mapping": self.name_mapping,
|
||||
}
|
||||
self.register_to_config(**register_dict)
|
||||
|
||||
def setup(self, components_manager, collection=None):
|
||||
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
|
||||
self._components_manager = components_manager
|
||||
|
||||
@property
|
||||
def mellon_config(self):
|
||||
return self._convert_to_mellon_config()
|
||||
|
||||
def _convert_to_mellon_config(self):
|
||||
node = {}
|
||||
node["label"] = self.config.label
|
||||
node["category"] = self.config.category
|
||||
|
||||
node_param = {}
|
||||
for inp_name, inp_param in self.config.input_params.items():
|
||||
if inp_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[inp_name]
|
||||
else:
|
||||
mellon_name = inp_name
|
||||
if isinstance(inp_param, str):
|
||||
param = {
|
||||
"label": inp_param,
|
||||
"type": inp_param,
|
||||
"display": "input",
|
||||
}
|
||||
else:
|
||||
param = inp_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
|
||||
|
||||
for comp_name, comp_param in self.config.component_params.items():
|
||||
if comp_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[comp_name]
|
||||
else:
|
||||
mellon_name = comp_name
|
||||
if isinstance(comp_param, str):
|
||||
param = {
|
||||
"label": comp_param,
|
||||
"type": comp_param,
|
||||
"display": "input",
|
||||
}
|
||||
else:
|
||||
param = comp_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
|
||||
|
||||
for out_name, out_param in self.config.output_params.items():
|
||||
if out_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[out_name]
|
||||
else:
|
||||
mellon_name = out_name
|
||||
if isinstance(out_param, str):
|
||||
param = {
|
||||
"label": out_param,
|
||||
"type": out_param,
|
||||
"display": "output",
|
||||
}
|
||||
else:
|
||||
param = out_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
|
||||
node["params"] = node_param
|
||||
return node
|
||||
|
||||
def save_mellon_config(self, file_path):
|
||||
"""
|
||||
Save the Mellon configuration to a JSON file.
|
||||
|
||||
Args:
|
||||
file_path (str or Path): Path where the JSON file will be saved
|
||||
|
||||
Returns:
|
||||
Path: Path to the saved config file
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(file_path.parent, exist_ok=True)
|
||||
|
||||
# Create a combined dictionary with module definition and name mapping
|
||||
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
|
||||
|
||||
# Save the config to file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
logger.info(f"Mellon config and name mapping saved to {file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
def load_mellon_config(cls, file_path):
|
||||
"""
|
||||
Load a Mellon configuration from a JSON file.
|
||||
|
||||
Args:
|
||||
file_path (str or Path): Path to the JSON file containing Mellon config
|
||||
|
||||
Returns:
|
||||
dict: The loaded combined configuration containing 'module' and 'name_mapping'
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {file_path}")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
logger.info(f"Mellon config loaded from {file_path}")
|
||||
|
||||
return config
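A small round-trip sketch of the two methods above, assuming `node` is a ModularNode instance and the file path is hypothetical:

path = node.save_mellon_config("configs/node_config.json")
config = ModularNode.load_mellon_config(path)
module_def = config["module"]           # node definition (label, category, params)
name_mapping = config["name_mapping"]   # diffusers name -> mellon name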
|
||||
|
||||
def process_inputs(self, **kwargs):
|
||||
params_components = {}
|
||||
for comp_name, comp_param in self.config.component_params.items():
|
||||
logger.debug(f"component: {comp_name}")
|
||||
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
|
||||
if mellon_comp_name in kwargs:
|
||||
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
|
||||
comp = kwargs[mellon_comp_name].pop(comp_name)
|
||||
else:
|
||||
comp = kwargs.pop(mellon_comp_name)
|
||||
if comp:
|
||||
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
|
||||
|
||||
params_run = {}
|
||||
for inp_name, inp_param in self.config.input_params.items():
|
||||
logger.debug(f"input: {inp_name}")
|
||||
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
|
||||
if mellon_inp_name in kwargs:
|
||||
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
|
||||
inp = kwargs[mellon_inp_name].pop(inp_name)
|
||||
else:
|
||||
inp = kwargs.pop(mellon_inp_name)
|
||||
if inp is not None:
|
||||
params_run[inp_name] = inp
|
||||
|
||||
return_output_names = list(self.config.output_params.keys())
|
||||
|
||||
return params_components, params_run, return_output_names
|
||||
|
||||
def execute(self, **kwargs):
|
||||
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
|
||||
|
||||
self.pipeline.update_components(**params_components)
|
||||
output = self.pipeline(**params_run, output=return_output_names)
|
||||
return output
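An end-to-end sketch of the node flow described above; the repo id, the components manager, and the kwarg names ("text_input", "steps") are placeholders standing in for whatever the node definition and name mapping produce:

node = ModularNode.from_pretrained("some-org/some-modular-blocks", trust_remote_code=True)
node.setup(components_manager)                        # builds the underlying modular pipeline
print(node.mellon_config["params"].keys())            # inspect the generated node definition
images = node.execute(text_input="a cat", steps=25)   # kwargs use the mellon-side names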
|
||||
@@ -21,27 +21,27 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["encoders"] = ["QwenImageTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
_import_structure["modular_blocks_qwenimage"] = [
|
||||
"AUTO_BLOCKS",
|
||||
"CONTROLNET_BLOCKS",
|
||||
"EDIT_AUTO_BLOCKS",
|
||||
"EDIT_BLOCKS",
|
||||
"EDIT_INPAINT_BLOCKS",
|
||||
"EDIT_PLUS_AUTO_BLOCKS",
|
||||
"EDIT_PLUS_BLOCKS",
|
||||
"IMAGE2IMAGE_BLOCKS",
|
||||
"INPAINT_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"QwenImageAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_edit"] = [
|
||||
"EDIT_AUTO_BLOCKS",
|
||||
"QwenImageEditAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_edit_plus"] = [
|
||||
"EDIT_PLUS_AUTO_BLOCKS",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_layered"] = [
|
||||
"LAYERED_AUTO_BLOCKS",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -51,28 +51,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .encoders import (
|
||||
QwenImageTextEncoderStep,
|
||||
)
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
from .modular_blocks_qwenimage import (
|
||||
AUTO_BLOCKS,
|
||||
CONTROLNET_BLOCKS,
|
||||
EDIT_AUTO_BLOCKS,
|
||||
EDIT_BLOCKS,
|
||||
EDIT_INPAINT_BLOCKS,
|
||||
EDIT_PLUS_AUTO_BLOCKS,
|
||||
EDIT_PLUS_BLOCKS,
|
||||
IMAGE2IMAGE_BLOCKS,
|
||||
INPAINT_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
QwenImageAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_edit import (
|
||||
EDIT_AUTO_BLOCKS,
|
||||
QwenImageEditAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_edit_plus import (
|
||||
EDIT_PLUS_AUTO_BLOCKS,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_layered import (
|
||||
LAYERED_AUTO_BLOCKS,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
)
|
||||
from .modular_pipeline import (
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
QwenImageLayeredModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
@@ -113,7 +113,9 @@ def get_timesteps(scheduler, num_inference_steps, strength):
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
# Prepare Latents steps
|
||||
# ====================
|
||||
# 1. PREPARE LATENTS
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
@@ -207,6 +209,98 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents"),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="layers", default=4),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="generator"),
|
||||
InputParam(
|
||||
name="batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs, can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
vae_scale_factor=components.vae_scale_factor,
|
||||
)
|
||||
|
||||
device = components._execution_device
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
|
||||
# we can update the height and width here since they are used to generate the initial latents
block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if block_state.latents is None:
|
||||
block_state.latents = randn_tensor(
|
||||
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
|
||||
)
|
||||
block_state.latents = components.pachifier.pack_latents(block_state.latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
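A quick shape check for the latent preparation above, using hypothetical values (1024x1024, vae_scale_factor=8, 16 latent channels, layers=4, batch_size=1):

height, width, vae_scale_factor, layers = 1024, 1024, 8, 4
latent_height = 2 * (height // (vae_scale_factor * 2))    # 128
latent_width = 2 * (width // (vae_scale_factor * 2))      # 128
shape = (1, layers + 1, 16, latent_height, latent_width)  # (1, 5, 16, 128, 128) before packing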
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -351,7 +445,9 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# Set Timesteps steps
|
||||
# ====================
|
||||
# 2. SET TIMESTEPS
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
@@ -420,6 +516,64 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=50, type_hint=int),
|
||||
InputParam("sigmas", type_hint=List[float]),
|
||||
InputParam("image_latents", required=True, type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="timesteps", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
# Layered-specific mu calculation
|
||||
base_seqlen = 256 * 256 / 16 / 16 # = 256
|
||||
mu = (block_state.image_latents.shape[1] / base_seqlen) ** 0.5
|
||||
|
||||
# Default sigmas if not provided
|
||||
sigmas = (
|
||||
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
|
||||
if block_state.sigmas is None
|
||||
else block_state.sigmas
|
||||
)
|
||||
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
components.scheduler,
|
||||
block_state.num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
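A worked example of the layered mu calculation, assuming packed image latents with a sequence length of 4096 (e.g. a 1024x1024 condition image after 2x2 packing):

base_seqlen = 256 * 256 / 16 / 16          # 256
image_seq_len = 4096                        # image_latents.shape[1], assumed
mu = (image_seq_len / base_seqlen) ** 0.5   # (4096 / 256) ** 0.5 = 4.0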
|
||||
|
||||
|
||||
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -493,7 +647,9 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# other inputs for denoiser
|
||||
# ====================
|
||||
# 3. OTHER INPUTS FOR DENOISER
|
||||
# ====================
|
||||
|
||||
## RoPE inputs for denoiser
|
||||
|
||||
@@ -522,6 +678,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
@@ -589,6 +746,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
@@ -639,19 +797,64 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
|
||||
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n"
|
||||
"Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n"
|
||||
"Should be placed after prepare_latents step."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="image_height", required=True, type_hint=List[int]),
|
||||
InputParam(name="image_width", required=True, type_hint=List[int]),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the image latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
|
||||
# Edit Plus: image_height and image_width are lists
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
|
||||
*[
|
||||
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
|
||||
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
|
||||
(1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2)
|
||||
for img_height, img_width in zip(block_state.image_height, block_state.image_width)
|
||||
],
|
||||
]
|
||||
] * block_state.batch_size
|
||||
@@ -670,6 +873,87 @@ class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
|
||||
return components, state
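For intuition, the img_shapes structure produced above with two reference images; all numbers are hypothetical and vae_scale_factor=8 is assumed:

vae_scale_factor = 8
height, width = 1024, 1024
image_height, image_width = [512, 768], [512, 768]  # two reference images
img_shapes = [
    [
        (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),  # (1, 64, 64) target
        *[
            (1, h // vae_scale_factor // 2, w // vae_scale_factor // 2)        # (1, 32, 32), (1, 48, 48)
            for h, w in zip(image_height, image_width)
        ],
    ]
] * 2  # batch_size = 2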
|
||||
|
||||
|
||||
class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
model_name = "qwenimage-layered"

@property
def description(self) -> str:
return (
"Step that prepares the RoPE inputs for the denoising process. Should be placed after the prepare_latents step."
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="layers", required=True),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The shapes of the image latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
type_hint=List[int],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
type_hint=List[int],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="additional_t_cond",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The additional t cond, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
# All shapes are the same for Layered
|
||||
shape = (
|
||||
1,
|
||||
block_state.height // components.vae_scale_factor // 2,
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
)
|
||||
|
||||
# layers+1 output shapes + 1 condition shape (all same)
|
||||
block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size
|
||||
|
||||
# txt_seq_lens
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
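The same idea for the layered variant: one shared latent shape repeated layers+2 times per batch item. A sketch with hypothetical numbers (1024x1024, vae_scale_factor=8):

batch_size, layers = 2, 4
shape = (1, 1024 // 8 // 2, 1024 // 8 // 2)          # (1, 64, 64)
img_shapes = [[shape] * (layers + 2)] * batch_size    # layers+1 output frames + 1 condition frame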
|
||||
|
||||
|
||||
## ControlNet inputs for denoiser
|
||||
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -24,12 +24,13 @@ from ...models import AutoencoderKLQwenImage
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# after denoising loop (unpack latents)
|
||||
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -71,6 +72,46 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("height", required=True, type_hint=int),
|
||||
InputParam("width", required=True, type_hint=int),
|
||||
InputParam("layers", required=True, type_hint=int),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W)
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.layers,
|
||||
components.vae_scale_factor,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
# decode step
|
||||
class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -135,6 +176,81 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Decode unpacked latents (B, C, layers+1, H, W) into layer images."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("output_type", default="pil", type_hint=str),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
|
||||
# 1. VAE normalization
|
||||
latents = latents.to(components.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
|
||||
# 2. Reshape for batch decoding: (B, C, layers+1, H, W) -> (B*layers, C, 1, H, W)
|
||||
b, c, f, h, w = latents.shape
|
||||
# 3. Remove first frame (composite), keep layers frames
|
||||
latents = latents[:, :, 1:]
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
|
||||
|
||||
# 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W)
|
||||
image = components.vae.decode(latents, return_dict=False)[0]
|
||||
image = image.squeeze(2)
|
||||
|
||||
# 5. Postprocess - returns flat list of B*layers images
|
||||
image = components.image_processor.postprocess(image, output_type=block_state.output_type)
|
||||
|
||||
# 6. Chunk into list per batch item
|
||||
images = []
|
||||
for bidx in range(b):
|
||||
images.append(image[bidx * f : (bidx + 1) * f])
|
||||
|
||||
block_state.images = images
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
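A shape-only walkthrough of the reshaping above, using zero tensors and hypothetical sizes (B=2, C=16, layers=4, 128x128 latents); the VAE decode itself is not run here:

import torch

b, c, layers, h, w = 2, 16, 4, 128, 128
latents = torch.zeros(b, c, layers + 1, h, w)                       # (2, 16, 5, 128, 128) unpacked
latents = latents[:, :, 1:]                                         # drop composite frame -> (2, 16, 4, 128, 128)
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)    # (8, 16, 1, 128, 128), one frame per decode call
decoded_like = latents.squeeze(2)                                   # stand-in for decoded output -> (8, 16, 128, 128)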
|
||||
|
||||
|
||||
# postprocess the decoded images
|
||||
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
@@ -28,7 +29,12 @@ from .modular_pipeline import QwenImageModularPipeline
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# ====================
|
||||
# 1. LOOP STEPS (run at each denoising step)
|
||||
# ====================
|
||||
|
||||
|
||||
# loop step:before denoiser
|
||||
class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -60,7 +66,7 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
|
||||
|
||||
class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
model_name = "qwenimage-edit"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -185,6 +191,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# loop step:denoiser
|
||||
class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -253,6 +260,13 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
}
|
||||
|
||||
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
|
||||
additional_cond_kwargs = {}
|
||||
for field_name, field_value in block_state.denoiser_input_fields.items():
|
||||
if field_name in transformer_args and field_name not in guider_inputs:
|
||||
additional_cond_kwargs[field_name] = field_value
|
||||
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
@@ -264,7 +278,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
@@ -284,7 +297,7 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
|
||||
class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
model_name = "qwenimage-edit"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -351,6 +364,13 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
}
|
||||
|
||||
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
|
||||
additional_cond_kwargs = {}
|
||||
for field_name, field_value in block_state.denoiser_input_fields.items():
|
||||
if field_name in transformer_args and field_name not in guider_inputs:
|
||||
additional_cond_kwargs[field_name] = field_value
|
||||
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
@@ -362,7 +382,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
@@ -384,6 +403,7 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# loop step:after denoiser
|
||||
class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -481,6 +501,9 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE LOOP WRAPPER: define the denoising loop logic
|
||||
# ====================
|
||||
class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -537,8 +560,15 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# composing the denoising loops
|
||||
# ====================
|
||||
# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps
|
||||
# ====================
|
||||
|
||||
|
||||
# Qwen Image (text2image, image2image)
|
||||
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopDenoiser,
|
||||
@@ -559,8 +589,9 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the inpainting denoising loops
|
||||
# Qwen Image (inpainting)
|
||||
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopDenoiser,
|
||||
@@ -583,8 +614,9 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the controlnet denoising loops
|
||||
# Qwen Image (text2image, image2image) with controlnet
|
||||
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
@@ -607,8 +639,9 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the controlnet denoising loops
|
||||
# Qwen Image (inpainting) with controlnet
|
||||
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
@@ -639,8 +672,9 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the denoising loops
|
||||
# Qwen Image Edit (image2image)
|
||||
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
@@ -661,7 +695,9 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Edit (inpainting)
|
||||
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
@@ -682,3 +718,26 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
" - `QwenImageLoopAfterDenoiserInpaint`\n"
|
||||
"This block supports inpainting tasks for QwenImage Edit."
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Layered (image2image)
|
||||
class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageEditLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageEditLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
"This block supports QwenImage Layered."
|
||||
)
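# --- Conceptual sketch (not part of the diff): the composition that the description above
# refers to. The wrapper owns the per-timestep loop and calls its sub-blocks in order;
# `run_loop` and the three toy blocks are made-up names for illustration only, not the real
# LoopSequentialPipelineBlocks API.
def before_block(state, i, t):
    state["log"].append(f"step {i}: prepare latent_model_input at t={t}")

def denoiser_block(state, i, t):
    state["latents"] *= 0.9  # stand-in for the transformer forward + guidance
    state["log"].append(f"step {i}: noise prediction applied")

def after_block(state, i, t):
    state["log"].append(f"step {i}: scheduler.step")

def run_loop(sub_blocks, timesteps):
    state = {"latents": 1.0, "log": []}
    for i, t in enumerate(timesteps):      # loop logic lives in the wrapper's __call__
        for block in sub_blocks:           # sub_blocks run sequentially at every iteration
            block(state, i, t)
    return state

final_state = run_loop([before_block, denoiser_block, after_block], timesteps=[1000, 750, 500])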
|
||||
|
||||
File diff suppressed because it is too large
@@ -19,7 +19,7 @@ import torch
|
||||
from ...models import QwenImageMultiControlNetModel
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
def repeat_tensor_to_batch_size(
|
||||
@@ -221,37 +221,16 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage: update height/width, expand batch, patchify."""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
|
||||
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to []. Examples: ["processed_mask_image"]
|
||||
|
||||
Examples:
|
||||
# Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
)
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
@@ -263,14 +242,12 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
@@ -279,11 +256,16 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
@@ -293,11 +275,9 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
@@ -306,26 +286,28 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
OutputParam(
|
||||
name="image_height",
|
||||
type_hint=int,
|
||||
description="The image height calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(
|
||||
name="image_width",
|
||||
type_hint=int,
|
||||
description="The image width calculated from the image latents dimension",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
# Process image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents
|
||||
# 1. Calculate height/width from latents and update if not provided
|
||||
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
@@ -335,7 +317,7 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
if not hasattr(block_state, "image_width"):
|
||||
block_state.image_width = width
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
# 2. Patchify
|
||||
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
@@ -354,7 +336,6 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
@@ -368,63 +349,270 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
return components, state
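# --- Illustrative sketch (not part of the diff): what "expand batch size" means for the inputs
# handled above. `expand_to_batch` is a toy stand-in that mirrors how `repeat_tensor_to_batch_size`
# is used here; the real helper (defined earlier in this file) also validates shapes per input name.
import torch

def expand_to_batch(tensor: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    final_batch_size = batch_size * num_images_per_prompt
    if tensor.shape[0] == final_batch_size:
        return tensor
    # assume a per-prompt tensor (shape[0] == batch_size) and repeat each entry
    return tensor.repeat_interleave(final_batch_size // tensor.shape[0], dim=0)

packed_latents = torch.randn(2, 1024, 64)  # 2 prompts, packed image latents
expanded = expand_to_batch(packed_latents, batch_size=2, num_images_per_prompt=3)
assert expanded.shape == (6, 1024, 64)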
|
||||
|
||||
|
||||
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
|
||||
class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Input processing step for Edit Plus that:\n"
|
||||
" 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size\n"
|
||||
" Height/width defaults to last image in the list."
|
||||
)
|
||||
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
|
||||
OutputParam(
|
||||
name="image_height",
|
||||
type_hint=List[int],
|
||||
description="The image heights calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(
|
||||
name="image_width",
|
||||
type_hint=List[int],
|
||||
description="The image widths calculated from the image latents dimension",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
# Process image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# Each image latent can have a different size in QwenImage Edit Plus.
|
||||
is_list = isinstance(image_latent_tensor, list)
|
||||
if not is_list:
|
||||
image_latent_tensor = [image_latent_tensor]
|
||||
|
||||
image_heights = []
|
||||
image_widths = []
|
||||
packed_image_latent_tensors = []
|
||||
|
||||
for img_latent_tensor in image_latent_tensor:
|
||||
for i, img_latent_tensor in enumerate(image_latent_tensor):
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
|
||||
image_heights.append(height)
|
||||
image_widths.append(width)
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
# 2. Patchify
|
||||
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
img_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_name=f"{image_latent_input_name}[{i}]",
|
||||
input_tensor=img_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
packed_image_latent_tensors.append(img_latent_tensor)
|
||||
|
||||
# Concatenate all packed latents along dim=1
|
||||
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
|
||||
|
||||
# Output lists of heights/widths
|
||||
block_state.image_height = image_heights
|
||||
block_state.image_width = image_widths
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
# Default height/width from last image
|
||||
block_state.height = block_state.height or image_heights[-1]
|
||||
block_state.width = block_state.width or image_widths[-1]
|
||||
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
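# --- Illustrative sketch (not part of the diff): packing several image latents of different
# sizes and concatenating them along the sequence dimension, as the Edit Plus input step above
# does. `toy_pack_latents` is a stand-in for `components.pachifier.pack_latents`; the 2x2-patch
# layout is an assumption made only for this illustration.
import torch

def toy_pack_latents(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5).reshape(b, (h // 2) * (w // 2), c * 4)
    return latents  # (B, seq_len, C * 4)

latent_list = [torch.randn(1, 16, 64, 64), torch.randn(1, 16, 96, 64)]  # per-image sizes differ
packed = torch.cat([toy_pack_latents(x) for x in latent_list], dim=1)
assert packed.shape == (1, 32 * 32 + 48 * 32, 64)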
|
||||
|
||||
|
||||
# YiYi TODO: support define config default component from the ModularPipeline level.
|
||||
# it is same as QwenImageAdditionalInputsStep, but with layered pachifier.
|
||||
class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Input processing step for Layered that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
]
|
||||
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="image_height",
|
||||
type_hint=int,
|
||||
description="The image height calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(
|
||||
name="image_width",
|
||||
type_hint=int,
|
||||
description="The image width calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(name="height", type_hint=int, description="The height of the image output"),
|
||||
OutputParam(name="width", type_hint=int, description="The width of the image output"),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents and update if not provided
|
||||
# Layered latents are (B, layers, C, H, W)
|
||||
height = image_latent_tensor.shape[3] * components.vae_scale_factor
|
||||
width = image_latent_tensor.shape[4] * components.vae_scale_factor
|
||||
block_state.height = height
|
||||
block_state.width = width
|
||||
|
||||
if not hasattr(block_state, "image_height"):
|
||||
block_state.image_height = height
|
||||
if not hasattr(block_state, "image_width"):
|
||||
block_state.image_width = width
|
||||
|
||||
# 2. Patchify with layered pachifier
|
||||
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
|
||||
File diff suppressed because it is too large
@@ -0,0 +1,469 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageControlNetBeforeDenoiserStep,
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageRoPEInputsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageControlNetDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
QwenImageInpaintControlNetDenoiseStep,
|
||||
QwenImageInpaintDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageControlNetVaeEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
QwenImageProcessImagesInputStep,
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageAdditionalInputsStep,
|
||||
QwenImageControlNetInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
|
||||
" - Resizes the image to the target size, based on `height` and `width`.\n"
|
||||
" - Processes and updates `image` and `mask_image`.\n"
|
||||
" - Creates `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# Auto VAE encoder
|
||||
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
|
||||
+ " - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# optional controlnet vae encoder
|
||||
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetVaeEncoderStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
|
||||
+ " - if `control_image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
# assemble prepare latents steps
|
||||
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the pachified latents `mask` based on the processedmask image.\n"
|
||||
)
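# --- Illustrative sketch (not part of the diff): one common flow-matching way to "add noise to
# the image latents" for a given strength, as described above. The actual formula is owned by
# QwenImagePrepareLatentsWithStrengthStep and the scheduler; this linear interpolation at a
# single sigma is only a conceptual stand-in.
import torch

def add_noise_for_strength(image_latents: torch.Tensor, noise: torch.Tensor, sigma: float) -> torch.Tensor:
    # sigma in [0, 1]; larger strength -> larger sigma -> closer to pure noise
    return (1.0 - sigma) * image_latents + sigma * noise

image_latents = torch.randn(1, 1024, 64)
noisy_latents = add_noise_for_strength(image_latents, torch.randn_like(image_latents), sigma=0.6)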
|
||||
|
||||
|
||||
# assemble denoising steps
|
||||
|
||||
|
||||
# Qwen Image (text2image)
|
||||
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
|
||||
# Qwen Image (inpainting)
|
||||
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageInpaintPrepareLatentsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageInpaintDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# Qwen Image (image2image)
|
||||
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImagePrepareLatentsWithStrengthStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_img2img_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
# Qwen Image (text2image) with controlnet
|
||||
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
|
||||
# Qwen Image (inpainting) with controlnet
|
||||
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageInpaintPrepareLatentsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageInpaintControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# Qwen Image (image2image) with controlnet
|
||||
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImagePrepareLatentsWithStrengthStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_img2img_latents",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
# Auto denoise step for QwenImage
|
||||
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageCoreDenoiseStep,
|
||||
QwenImageInpaintCoreDenoiseStep,
|
||||
QwenImageImg2ImgCoreDenoiseStep,
|
||||
QwenImageControlNetCoreDenoiseStep,
|
||||
QwenImageControlNetInpaintCoreDenoiseStep,
|
||||
QwenImageControlNetImg2ImgCoreDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"text2image",
|
||||
"inpaint",
|
||||
"img2img",
|
||||
"controlnet_text2image",
|
||||
"controlnet_inpaint",
|
||||
"controlnet_img2img",
|
||||
]
|
||||
block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"]
|
||||
default_block_name = "text2image"
|
||||
|
||||
def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None):
|
||||
if control_image_latents is not None:
|
||||
if processed_mask_image is not None:
|
||||
return "controlnet_inpaint"
|
||||
elif image_latents is not None:
|
||||
return "controlnet_img2img"
|
||||
else:
|
||||
return "controlnet_text2image"
|
||||
else:
|
||||
if processed_mask_image is not None:
|
||||
return "inpaint"
|
||||
elif image_latents is not None:
|
||||
return "img2img"
|
||||
else:
|
||||
return "text2image"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
|
||||
+ " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n"
|
||||
+ " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n"
|
||||
+ " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n"
|
||||
+ " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n"
|
||||
+ " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n"
|
||||
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
|
||||
+ " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
|
||||
)
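# --- Illustrative sketch (not part of the diff): the dispatch rule implemented by
# `QwenImageAutoCoreDenoiseStep.select_block` above, restated as a standalone function so the
# trigger-input combinations can be checked at a glance.
def select_core_denoise(control_image_latents=None, processed_mask_image=None, image_latents=None):
    if control_image_latents is not None:
        if processed_mask_image is not None:
            return "controlnet_inpaint"
        if image_latents is not None:
            return "controlnet_img2img"
        return "controlnet_text2image"
    if processed_mask_image is not None:
        return "inpaint"
    if image_latents is not None:
        return "img2img"
    return "text2image"

assert select_core_denoise() == "text2image"
assert select_core_denoise(image_latents="x") == "img2img"
assert select_core_denoise(control_image_latents="c", processed_mask_image="m") == "controlnet_inpaint"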
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
# standard decode step works for most tasks except for inpaint
|
||||
class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
# Inpaint decode step
|
||||
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
|
||||
|
||||
|
||||
# Auto decode step for QwenImage
|
||||
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images. \n"
|
||||
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
|
||||
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("denoise", QwenImageAutoCoreDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
|
||||
+ "- for image-to-image generation, you need to provide `image`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
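# --- Illustrative sketch (not part of the diff): assembling a custom preset the same way
# AUTO_BLOCKS / QwenImageAutoBlocks are built above, but swapping in the plain decode step.
# Only constructs visible in this file are used (InsertableDict and a SequentialPipelineBlocks
# subclass); whether such a preset is useful in practice is left to the user.
CUSTOM_BLOCKS = InsertableDict(
    [
        ("text_encoder", QwenImageTextEncoderStep()),
        ("vae_encoder", QwenImageAutoVaeEncoderStep()),
        ("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
        ("denoise", QwenImageAutoCoreDenoiseStep()),
        ("decode", QwenImageDecodeStep()),  # plain decode instead of QwenImageAutoDecodeStep
    ]
)


class QwenImageCustomBlocks(SequentialPipelineBlocks):
    model_name = "qwenimage"
    block_classes = CUSTOM_BLOCKS.values()
    block_names = CUSTOM_BLOCKS.keys()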
|
||||
@@ -0,0 +1,336 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImageEditRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageEditDenoiseStep,
|
||||
QwenImageEditInpaintDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditInpaintProcessImagesInputStep,
|
||||
QwenImageEditProcessImagesInputStep,
|
||||
QwenImageEditResizeStep,
|
||||
QwenImageEditTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageAdditionalInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""VL encoder that takes both image and text prompts."""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
QwenImageEditTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit VL encoder step that encode the image and text prompts together."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# Edit VAE encoder
|
||||
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
QwenImageEditProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# Edit Inpaint VAE encoder
|
||||
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
QwenImageEditInpaintProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
|
||||
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
|
||||
" - process the resized image and mask image.\n"
|
||||
" - create image latents."
|
||||
)
|
||||
|
||||
|
||||
# Auto VAE encoder
|
||||
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
"This is an auto pipeline block.\n"
|
||||
" - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
|
||||
" - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
|
||||
" - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit inpaint denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# assemble prepare latents steps
|
||||
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the patchified latents `mask` based on the processed mask image.\n"
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Edit (image2image) core denoise step
|
||||
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageEditRoPEInputsStep(),
|
||||
QwenImageEditDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
|
||||
|
||||
|
||||
# Qwen Image Edit (inpainting) core denoise step
|
||||
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageEditInpaintPrepareLatentsStep(),
|
||||
QwenImageEditRoPEInputsStep(),
|
||||
QwenImageEditInpaintDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit inpaint task."
|
||||
|
||||
|
||||
# Auto core denoise step for QwenImage Edit
|
||||
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintCoreDenoiseStep,
|
||||
QwenImageEditCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
default_block_name = "edit"
|
||||
|
||||
def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]:
|
||||
if processed_mask_image is not None:
|
||||
return "edit_inpaint"
|
||||
elif image_latents is not None:
|
||||
return "edit"
|
||||
return None
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto core denoising step that selects the appropriate workflow based on inputs.\n"
|
||||
" - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n"
|
||||
" - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n"
|
||||
"Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
# Decode step (standard)
|
||||
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
# Inpaint decode step
|
||||
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image."
|
||||
|
||||
|
||||
# Auto decode step
|
||||
class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images.\n"
|
||||
"This is an auto pipeline block.\n"
|
||||
" - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
" - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
|
||||
)
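# --- Illustrative sketch (not part of the diff): how the trigger inputs of an AutoPipelineBlocks
# such as `QwenImageEditAutoDecodeStep` above map to a chosen sub-block. `pick_block` is a toy
# stand-in written only to make the "first matching trigger wins, `None` is the fallback"
# convention from the descriptions explicit; it is not the real AutoPipelineBlocks logic.
def pick_block(block_names, block_trigger_inputs, provided_inputs):
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None or provided_inputs.get(trigger) is not None:
            return name
    return None

names = ["inpaint_decode", "decode"]
triggers = ["mask", None]
assert pick_block(names, triggers, {"mask": "latent_mask"}) == "inpaint_decode"
assert pick_block(names, triggers, {}) == "decode"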
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
EDIT_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
|
||||
("denoise", QwenImageEditAutoCoreDenoiseStep()),
|
||||
("decode", QwenImageEditAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = EDIT_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
|
||||
"- for edit (img2img) generation, you need to provide `image`\n"
|
||||
"- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
|
||||
)
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageEditPlusRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageEditDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditPlusProcessImagesInputStep,
|
||||
QwenImageEditPlusResizeStep,
|
||||
QwenImageEditPlusTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageEditPlusAdditionalInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""VL encoder that takes both image and text prompts. Uses 384x384 target area."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
|
||||
QwenImageEditPlusTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
|
||||
QwenImageEditPlusProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"VAE encoder step that encodes image inputs into latent representations.\n"
|
||||
"Each image is resized independently based on its own aspect ratio to 1024x1024 target area."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the Edit Plus denoising step. It:\n"
|
||||
" - Standardizes text embeddings batch size.\n"
|
||||
" - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n"
|
||||
" - Outputs lists of image_height/image_width for RoPE calculation.\n"
|
||||
" - Defaults height/width from last image in the list."
|
||||
)
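# Hedged sketch of the "patchify, then concatenate along dim=1" handling of
# multiple image latents described above. `patchify` here is a plain 2x2
# space-to-depth written for illustration; the pipeline uses its own pachifier.
import torch


def patchify(latents, patch_size=2):
    # (B, C, H, W) -> (B, (H/p) * (W/p), C * p * p)
    b, c, h, w = latents.shape
    p = patch_size
    latents = latents.view(b, c, h // p, p, w // p, p).permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // p) * (w // p), c * p * p)


# Two reference images with different latent sizes become one packed sequence:
image_latents = [torch.randn(1, 16, 64, 64), torch.randn(1, 16, 32, 96)]
packed = torch.cat([patchify(lat) for lat in image_latents], dim=1)
print(packed.shape)  # torch.Size([1, 1792, 64]) -> 1024 + 768 tokens, 16 * 4 channels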
|
||||
|
||||
|
||||
# Qwen Image Edit Plus (image2image) core denoise step
|
||||
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageEditPlusRoPEInputsStep(),
|
||||
QwenImageEditDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocesses the generated image."
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditPlusVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
|
||||
("denoise", QwenImageEditPlusCoreDenoiseStep()),
|
||||
("decode", QwenImageEditPlusDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
|
||||
"- `image` is required input (can be single image or list of images).\n"
|
||||
"- Each image is resized independently based on its own aspect ratio.\n"
|
||||
"- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
|
||||
)
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageLayeredPrepareLatentsStep,
|
||||
QwenImageLayeredRoPEInputsStep,
|
||||
QwenImageLayeredSetTimestepsStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageLayeredAfterDenoiseStep,
|
||||
QwenImageLayeredDecoderStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageLayeredDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditProcessImagesInputStep,
|
||||
QwenImageLayeredGetImagePromptStep,
|
||||
QwenImageLayeredPermuteLatentsStep,
|
||||
QwenImageLayeredResizeStep,
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageLayeredAdditionalInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
|
||||
"""Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredResizeStep(),
|
||||
QwenImageLayeredGetImagePromptStep(),
|
||||
QwenImageTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "get_image_prompt", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# Edit VAE encoder
|
||||
class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredResizeStep(),
|
||||
QwenImageEditProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
QwenImageLayeredPermuteLatentsStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode", "permute"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageLayeredInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the layered denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Layered (image2image) core denoise step
|
||||
class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredInputStep(),
|
||||
QwenImageLayeredPrepareLatentsStep(),
|
||||
QwenImageLayeredSetTimestepsStep(),
|
||||
QwenImageLayeredRoPEInputsStep(),
|
||||
QwenImageLayeredDenoiseStep(),
|
||||
QwenImageLayeredAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Layered img2img task."
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
LAYERED_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageLayeredTextEncoderStep()),
|
||||
("vae_encoder", QwenImageLayeredVaeEncoderStep()),
|
||||
("denoise", QwenImageLayeredCoreDenoiseStep()),
|
||||
("decode", QwenImageLayeredDecoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = LAYERED_AUTO_BLOCKS.values()
|
||||
block_names = LAYERED_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered."
|
||||
@@ -90,6 +90,88 @@ class QwenImagePachifier(ConfigMixin):
|
||||
return latents
|
||||
|
||||
|
||||
class QwenImageLayeredPachifier(ConfigMixin):
|
||||
"""
|
||||
A class to pack and unpack latents for QwenImage Layered.
|
||||
|
||||
Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W).
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, patch_size: int = 2):
|
||||
super().__init__()
|
||||
|
||||
def pack_latents(self, latents):
|
||||
"""
|
||||
Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4).
|
||||
"""
|
||||
|
||||
if latents.ndim != 5:
|
||||
raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}")
|
||||
|
||||
batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
if latent_height % patch_size != 0 or latent_width % patch_size != 0:
|
||||
raise ValueError(
|
||||
f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
|
||||
)
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
layers,
|
||||
num_channels_latents,
|
||||
latent_height // patch_size,
|
||||
patch_size,
|
||||
latent_width // patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
layers * (latent_height // patch_size) * (latent_width // patch_size),
|
||||
num_channels_latents * patch_size * patch_size,
|
||||
)
|
||||
return latents
|
||||
|
||||
def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8):
|
||||
"""
|
||||
Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W).
|
||||
"""
|
||||
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
|
||||
|
||||
batch_size, _, channels = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
layers + 1,
|
||||
height // patch_size,
|
||||
width // patch_size,
|
||||
channels // (patch_size * patch_size),
|
||||
patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3, 6)
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
layers + 1,
|
||||
channels // (patch_size * patch_size),
|
||||
height,
|
||||
width,
|
||||
)
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w)
|
||||
|
||||
return latents
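# Hedged shape walkthrough for the pack/unpack pair above (patch_size = 2). This
# standalone snippet replays the same reshapes on a dummy tensor to make the
# sequence-length and channel bookkeeping explicit; it does not instantiate
# QwenImageLayeredPachifier, and the sizes below are arbitrary examples.
import torch

B, L, C, H_lat, W_lat, p = 1, 5, 16, 64, 64, 2  # e.g. 4 layers + 1 composite

x = torch.randn(B, L, C, H_lat, W_lat)
packed = (
    x.view(B, L, C, H_lat // p, p, W_lat // p, p)
    .permute(0, 1, 3, 5, 2, 4, 6)
    .reshape(B, L * (H_lat // p) * (W_lat // p), C * p * p)
)
print(packed.shape)  # torch.Size([1, 5120, 64]) == (B, 5 * 32 * 32, 16 * 4)

unpacked = (
    packed.reshape(B, L, H_lat // p, W_lat // p, C, p, p)
    .permute(0, 1, 4, 2, 5, 3, 6)
    .reshape(B, L, C, H_lat, W_lat)
    .permute(0, 2, 1, 3, 4)  # -> (B, C, layers + 1, H, W), as unpack_latents returns
)
print(torch.equal(unpacked, x.permute(0, 2, 1, 3, 4)))  # True: exact round trip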
|
||||
|
||||
|
||||
class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for QwenImage.
|
||||
@@ -203,3 +285,13 @@ class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline):
|
||||
"""
|
||||
|
||||
default_blocks_name = "QwenImageEditPlusAutoBlocks"
|
||||
|
||||
|
||||
class QwenImageLayeredModularPipeline(QwenImageModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for QwenImage-Layered.
|
||||
|
||||
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "QwenImageLayeredAutoBlocks"
|
||||
|
||||
121
src/diffusers/modular_pipelines/qwenimage/prompt_templates.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Prompt templates for QwenImage pipelines.
|
||||
|
||||
This module centralizes all prompt templates used across different QwenImage pipeline variants:
|
||||
- QwenImage (base): Text-only encoding for text-to-image generation
|
||||
- QwenImage Edit: VL encoding with single image for image editing
|
||||
- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing
|
||||
- QwenImage Layered: Auto-captioning for image decomposition
|
||||
"""
|
||||
|
||||
# ============================================
|
||||
# QwenImage Base (text-only encoding)
|
||||
# ============================================
|
||||
# Used for text-to-image generation where only text prompt is encoded
|
||||
|
||||
QWENIMAGE_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\n"
|
||||
"Describe the image by detailing the color, shape, size, texture, quantity, text, "
|
||||
"spatial relationships of the objects and background:<|im_end|>\n"
|
||||
"<|im_start|>user\n{}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34
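# Hedged illustration of how the template and its start index fit together: the
# user prompt is spliced into the chat template before encoding, and the first
# QWENIMAGE_PROMPT_TEMPLATE_START_IDX tokens (the fixed system/user preamble) are
# dropped from the encoder hidden states afterwards. The slicing shown in the
# comment is schematic; the real handling lives in the text-encoder step.
prompt = "A corgi wearing sunglasses on the beach"
full_text = QWENIMAGE_PROMPT_TEMPLATE.format(prompt)
print(full_text.startswith("<|im_start|>system"))  # True
# prompt_embeds = hidden_states[:, QWENIMAGE_PROMPT_TEMPLATE_START_IDX:]  # (assumed usage)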
|
||||
|
||||
|
||||
# ============================================
|
||||
# QwenImage Edit (VL encoding with single image)
|
||||
# ============================================
|
||||
# Used for single-image editing where both image and text are encoded together
|
||||
|
||||
QWENIMAGE_EDIT_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\n"
|
||||
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
|
||||
"then explain how the user's text instruction should alter or modify the image. "
|
||||
"Generate a new image that meets the user's requirements while maintaining consistency "
|
||||
"with the original input where appropriate.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
"<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64
|
||||
|
||||
|
||||
# ============================================
|
||||
# QwenImage Edit Plus (VL encoding with multiple images)
|
||||
# ============================================
|
||||
# Used for multi-reference editing where multiple images and text are encoded together
|
||||
# The img_template is used to format each image in the prompt
|
||||
|
||||
QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\n"
|
||||
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
|
||||
"then explain how the user's text instruction should alter or modify the image. "
|
||||
"Generate a new image that meets the user's requirements while maintaining consistency "
|
||||
"with the original input where appropriate.<|im_end|>\n"
|
||||
"<|im_start|>user\n{}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64
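# Hedged sketch of how the per-image template is presumably combined with the
# Edit Plus prompt template for several reference images; the exact assembly in
# QwenImageEditPlusTextEncoderStep may differ.
num_images = 2
user_text = "Blend the style of Picture 1 into Picture 2"
image_prefix = "".join(QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE.format(i + 1) for i in range(num_images))
full_text = QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE.format(image_prefix + user_text)
print(full_text.count("<|image_pad|>"))  # 2 -> one vision slot per reference image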
|
||||
|
||||
|
||||
# ============================================
|
||||
# QwenImage Layered (auto-captioning)
|
||||
# ============================================
|
||||
# Used for image decomposition where the VL model generates a caption from the input image
|
||||
# if no prompt is provided. These prompts instruct the model to describe the image in detail.
|
||||
|
||||
QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = (
|
||||
"<|im_start|>system\n"
|
||||
"You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
"# Image Annotator\n"
|
||||
"You are a professional image annotator. Please write an image caption based on the input image:\n"
|
||||
"1. Write the caption using natural, descriptive language without structured formats or rich text.\n"
|
||||
"2. Enrich caption details by including:\n"
|
||||
" - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n"
|
||||
" - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, "
|
||||
"attachment relations, action relations, comparative relations, causal relations, and so on\n"
|
||||
" - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n"
|
||||
" - Identify the text clearly visible in the image, without translation or explanation, "
|
||||
"and highlight it in the caption with quotation marks\n"
|
||||
"3. Maintain authenticity and accuracy:\n"
|
||||
" - Avoid generalizations\n"
|
||||
" - Describe all visible information in the image, while do not add information not explicitly shown in the image\n"
|
||||
"<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = (
|
||||
"<|im_start|>system\n"
|
||||
"You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
"# 图像标注器\n"
|
||||
"你是一个专业的图像标注器。请基于输入图像,撰写图注:\n"
|
||||
"1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n"
|
||||
"2. 通过加入以下内容,丰富图注细节:\n"
|
||||
" - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n"
|
||||
" - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n"
|
||||
" - 环境细节:例如天气、光照、颜色、纹理、气氛等\n"
|
||||
" - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n"
|
||||
"3. 保持真实性与准确性:\n"
|
||||
" - 不要使用笼统的描述\n"
|
||||
" - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n"
|
||||
"<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
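# Hedged sketch: the layered get-image-prompt step presumably picks one of the two
# caption templates above when auto-captioning. The `language` flag here is an
# assumption for illustration, not necessarily the pipeline's actual parameter name.
def select_caption_prompt(language="en"):
    return QWENIMAGE_LAYERED_CAPTION_PROMPT_CN if language == "zh" else QWENIMAGE_LAYERED_CAPTION_PROMPT_EN


print(select_caption_prompt("en").count("<|image_pad|>"))  # 1 vision slot for the input image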
|
||||
@@ -84,7 +84,7 @@ class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
class WanImage2VideoVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "vae_image_encoder"]
|
||||
block_names = ["image_resize", "vae_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -142,7 +142,7 @@ class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
class WanFLF2VVaeImageEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "wan"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameVaeImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_image_encoder"]
|
||||
block_names = ["image_resize", "last_image_resize", "vae_encoder"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
@@ -203,7 +203,7 @@ class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
## vae encoder
|
||||
class WanAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [WanFLF2VVaeImageEncoderStep, WanImage2VideoVaeImageEncoderStep]
|
||||
block_names = ["flf2v_vae_image_encoder", "image2video_vae_image_encoder"]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
|
||||
@property
|
||||
@@ -251,7 +251,7 @@ class WanAutoBlocks(SequentialPipelineBlocks):
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"image_encoder",
|
||||
"vae_image_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
@@ -353,7 +353,7 @@ class Wan22AutoBlocks(SequentialPipelineBlocks):
|
||||
]
|
||||
block_names = [
|
||||
"text_encoder",
|
||||
"vae_image_encoder",
|
||||
"vae_encoder",
|
||||
"denoise",
|
||||
"decode",
|
||||
]
|
||||
@@ -384,7 +384,7 @@ IMAGE2VIDEO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("image_encoder", WanImage2VideoImageEncoderStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
@@ -401,7 +401,7 @@ FLF2V_BLOCKS = InsertableDict(
|
||||
("image_resize", WanImageResizeStep),
|
||||
("last_image_resize", WanImageCropResizeStep),
|
||||
("image_encoder", WanFLF2VImageEncoderStep),
|
||||
("vae_image_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("vae_encoder", WanFLF2VVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("additional_inputs", WanAdditionalInputsStep(image_latent_inputs=["first_last_frame_latents"])),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
@@ -416,7 +416,7 @@ AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("image_encoder", WanAutoImageEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", WanAutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
@@ -438,7 +438,7 @@ TEXT2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("image_resize", WanImageResizeStep),
|
||||
("vae_image_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("vae_encoder", WanImage2VideoVaeImageEncoderStep),
|
||||
("input", WanTextInputStep),
|
||||
("set_timesteps", WanSetTimestepsStep),
|
||||
("prepare_latents", WanPrepareLatentsStep),
|
||||
@@ -450,7 +450,7 @@ IMAGE2VIDEO_BLOCKS_WAN22 = InsertableDict(
|
||||
AUTO_BLOCKS_WAN22 = InsertableDict(
|
||||
[
|
||||
("text_encoder", WanTextEncoderStep),
|
||||
("vae_image_encoder", WanAutoVaeImageEncoderStep),
|
||||
("vae_encoder", WanAutoVaeImageEncoderStep),
|
||||
("denoise", Wan22AutoDenoiseStep),
|
||||
("decode", WanImageVaeDecoderStep),
|
||||
]
|
||||
|
||||
@@ -288,7 +288,9 @@ else:
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
]
|
||||
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
@@ -729,7 +731,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .ltx import (
|
||||
LTXConditionPipeline,
|
||||
LTXI2VLongMultiPromptPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
)
|
||||
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
|
||||
@@ -887,7 +887,13 @@ class AllegroPipeline(DiffusionPipeline):
|
||||
prompt_embeds = prompt_embeds.unsqueeze(1) # b l d -> b 1 l d
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
if XLA_AVAILABLE:
|
||||
timestep_device = "cpu"
|
||||
else:
|
||||
timestep_device = device
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, timestep_device, timesteps
|
||||
)
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
# 5. Prepare latents.
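# The XLA branch added above recurs across several pipelines in this diff; a tiny
# helper captures the pattern. This is a hedged refactoring sketch, not code from
# the diff: on XLA backends the timestep schedule is retrieved on CPU (presumably
# to keep schedule construction off the XLA device), while the scheduler itself is
# still configured for the target device via `set_timesteps` as shown above.
def _timestep_device(device, xla_available):
    return "cpu" if xla_available else device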
|
||||
|
||||
@@ -897,16 +897,20 @@ class AnimateDiffVideoToVideoPipeline(
|
||||
dtype = self.dtype
|
||||
|
||||
# 3. Prepare timesteps
|
||||
if XLA_AVAILABLE:
|
||||
timestep_device = "cpu"
|
||||
else:
|
||||
timestep_device = device
|
||||
if not enforce_inference_steps:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
else:
|
||||
denoising_inference_steps = int(num_inference_steps / strength)
|
||||
timesteps, denoising_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
|
||||
self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas
|
||||
)
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
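# Hedged numeric illustration of the `enforce_inference_steps` branch above: the
# schedule is built for num_inference_steps / strength steps and only the last
# num_inference_steps entries are kept, so denoising still starts roughly
# `strength` of the way into the schedule. The toy schedule below is illustrative.
num_inference_steps, strength = 20, 0.5
denoising_inference_steps = int(num_inference_steps / strength)  # 40
step = 1000 // denoising_inference_steps  # 25
toy_timesteps = list(range(1000, 0, -step))  # 1000, 975, ..., 25
toy_timesteps = toy_timesteps[-num_inference_steps:]
print(len(toy_timesteps), toy_timesteps[0])  # 20 500 -> 20 steps, starting mid-schedule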
|
||||
|
||||
@@ -1100,16 +1100,20 @@ class AnimateDiffVideoToVideoControlNetPipeline(
|
||||
dtype = self.dtype
|
||||
|
||||
# 3. Prepare timesteps
|
||||
if XLA_AVAILABLE:
|
||||
timestep_device = "cpu"
|
||||
else:
|
||||
timestep_device = device
|
||||
if not enforce_inference_steps:
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
self.scheduler, num_inference_steps, timestep_device, timesteps, sigmas
|
||||
)
|
||||
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, timesteps, strength, device)
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
else:
|
||||
denoising_inference_steps = int(num_inference_steps / strength)
|
||||
timesteps, denoising_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, denoising_inference_steps, device, timesteps, sigmas
|
||||
self.scheduler, denoising_inference_steps, timestep_device, timesteps, sigmas
|
||||
)
|
||||
timesteps = timesteps[-num_inference_steps:]
|
||||
latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt)
|
||||
|
||||
@@ -586,7 +586,13 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
|
||||
# 4. Prepare timesteps
|
||||
|
||||
# sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
|
||||
if XLA_AVAILABLE:
|
||||
timestep_device = "cpu"
|
||||
else:
|
||||
timestep_device = device
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, timestep_device, sigmas=sigmas
|
||||
)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
|
||||
@@ -99,6 +99,7 @@ from .qwenimage import (
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImageLayeredPipeline,
|
||||
QwenImagePipeline,
|
||||
)
|
||||
from .sana import SanaPipeline
|
||||
@@ -202,6 +203,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("qwenimage", QwenImageImg2ImgPipeline),
|
||||
("qwenimage-edit", QwenImageEditPipeline),
|
||||
("qwenimage-edit-plus", QwenImageEditPlusPipeline),
|
||||
("qwenimage-layered", QwenImageLayeredPipeline),
|
||||
("z-image", ZImageImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -664,7 +664,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
|
||||
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
|
||||
if XLA_AVAILABLE:
|
||||
timestep_device = "cpu"
|
||||
else:
|
||||
timestep_device = device
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, timestep_device, timesteps
|
||||
)
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
# 5. Prepare latents
|
||||
|
||||
Some files were not shown because too many files have changed in this diff.