mirror of
https://github.com/huggingface/diffusers.git
synced 2026-04-16 20:57:10 +08:00
Compare commits
27 Commits
paulinebm-
...
modular-cu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d5080d7c6 | ||
|
|
d36564f06a | ||
|
|
441b69eabf | ||
|
|
d568c9773f | ||
|
|
3981c955ce | ||
|
|
1903383e94 | ||
|
|
08f8b7af9a | ||
|
|
2f66edc880 | ||
|
|
be38f41f9f | ||
|
|
91e5134175 | ||
|
|
a812c87465 | ||
|
|
8b9f817ef5 | ||
|
|
b1f06b780a | ||
|
|
8600b4c10d | ||
|
|
c10bdd9b73 | ||
|
|
dab000e88b | ||
|
|
9fb6b89d49 | ||
|
|
6fb4c99f5a | ||
|
|
961b9b27d3 | ||
|
|
8f30bfff1f | ||
|
|
b4be29bda2 | ||
|
|
98479a94c2 | ||
|
|
ade1059ae2 | ||
|
|
41a6e86faf | ||
|
|
9b5a244653 | ||
|
|
417f6b2d33 | ||
|
|
e46354d2d0 |
4
.github/workflows/benchmark.yml
vendored
4
.github/workflows/benchmark.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -58,7 +58,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: benchmark_test_reports
|
||||
path: benchmarks/${{ env.BASE_PATH }}
|
||||
|
||||
4
.github/workflows/build_docker_images.yml
vendored
4
.github/workflows/build_docker_images.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Check out code
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Find Changed Dockerfiles
|
||||
id: file_changes
|
||||
@@ -99,7 +99,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
- name: Login to Docker Hub
|
||||
|
||||
4
.github/workflows/build_pr_documentation.yml
vendored
4
.github/workflows/build_pr_documentation.yml
vendored
@@ -17,10 +17,10 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
|
||||
22
.github/workflows/codeql.yml
vendored
Normal file
22
.github/workflows/codeql.yml
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
name: CodeQL Security Analysis For Github Actions
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: ["main"]
|
||||
workflow_dispatch:
|
||||
# pull_request:
|
||||
|
||||
jobs:
|
||||
codeql:
|
||||
name: CodeQL Analysis
|
||||
uses: huggingface/security-workflows/.github/workflows/codeql-reusable.yml@v1
|
||||
permissions:
|
||||
security-events: write
|
||||
packages: read
|
||||
actions: read
|
||||
contents: read
|
||||
with:
|
||||
languages: '["actions","python"]'
|
||||
queries: 'security-extended,security-and-quality'
|
||||
runner: 'ubuntu-latest' #optional if need custom runner
|
||||
@@ -38,6 +38,7 @@ jobs:
|
||||
# If ref is 'refs/heads/main' => set 'main'
|
||||
# Else it must be a tag => set {tag}
|
||||
- name: Set checkout_ref and path_in_repo
|
||||
env:
|
||||
EVENT_NAME: ${{ github.event_name }}
|
||||
EVENT_INPUT_REF: ${{ github.event.inputs.ref }}
|
||||
GITHUB_REF: ${{ github.ref }}
|
||||
@@ -65,13 +66,13 @@ jobs:
|
||||
run: |
|
||||
echo "CHECKOUT_REF: ${{ env.CHECKOUT_REF }}"
|
||||
echo "PATH_IN_REPO: ${{ env.PATH_IN_REPO }}"
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ env.CHECKOUT_REF }}
|
||||
|
||||
# Setup + install dependencies
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
|
||||
46
.github/workflows/nightly_tests.yml
vendored
46
.github/workflows/nightly_tests.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -97,7 +97,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -119,7 +119,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others, single_file, examples]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -167,7 +167,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_${{ matrix.module }}_cuda_test_reports
|
||||
path: reports
|
||||
@@ -184,7 +184,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -211,7 +211,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -263,7 +263,7 @@ jobs:
|
||||
cat reports/tests_big_gpu_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_cuda_big_gpu_test_reports
|
||||
path: reports
|
||||
@@ -280,7 +280,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -321,7 +321,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
@@ -355,7 +355,7 @@ jobs:
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -391,7 +391,7 @@ jobs:
|
||||
cat reports/tests_${{ matrix.config.backend }}_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_cuda_${{ matrix.config.backend }}_reports
|
||||
path: reports
|
||||
@@ -408,7 +408,7 @@ jobs:
|
||||
options: --shm-size "20gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -441,7 +441,7 @@ jobs:
|
||||
cat reports/tests_pipeline_level_quant_torch_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_cuda_pipeline_level_quant_reports
|
||||
path: reports
|
||||
@@ -466,7 +466,7 @@ jobs:
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -474,7 +474,7 @@ jobs:
|
||||
run: mkdir -p combined_reports
|
||||
|
||||
- name: Download all test reports
|
||||
uses: actions/download-artifact@v4
|
||||
uses: actions/download-artifact@v7
|
||||
with:
|
||||
path: artifacts
|
||||
|
||||
@@ -500,7 +500,7 @@ jobs:
|
||||
cat $CONSOLIDATED_REPORT_PATH >> $GITHUB_STEP_SUMMARY
|
||||
|
||||
- name: Upload consolidated report
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: consolidated_test_report
|
||||
path: ${{ env.CONSOLIDATED_REPORT_PATH }}
|
||||
@@ -514,7 +514,7 @@ jobs:
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v3
|
||||
# uses: actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
@@ -554,7 +554,7 @@ jobs:
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v4
|
||||
# uses: actions/upload-artifact@v6
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
@@ -570,7 +570,7 @@ jobs:
|
||||
#
|
||||
# steps:
|
||||
# - name: Checkout diffusers
|
||||
# uses: actions/checkout@v3
|
||||
# uses: actions/checkout@v6
|
||||
# with:
|
||||
# fetch-depth: 2
|
||||
#
|
||||
@@ -610,7 +610,7 @@ jobs:
|
||||
#
|
||||
# - name: Test suite reports artifacts
|
||||
# if: ${{ always() }}
|
||||
# uses: actions/upload-artifact@v4
|
||||
# uses: actions/upload-artifact@v6
|
||||
# with:
|
||||
# name: torch_mps_test_reports
|
||||
# path: reports
|
||||
|
||||
@@ -10,10 +10,10 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.8'
|
||||
|
||||
|
||||
4
.github/workflows/pr_dependency_test.yml
vendored
4
.github/workflows/pr_dependency_test.yml
vendored
@@ -18,9 +18,9 @@ jobs:
|
||||
check_dependencies:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
|
||||
38
.github/workflows/pr_modular_tests.yml
vendored
38
.github/workflows/pr_modular_tests.yml
vendored
@@ -1,3 +1,4 @@
|
||||
|
||||
name: Fast PR tests for Modular
|
||||
|
||||
on:
|
||||
@@ -35,9 +36,9 @@ jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
@@ -55,9 +56,9 @@ jobs:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.10"
|
||||
- name: Install dependencies
|
||||
@@ -77,23 +78,13 @@ jobs:
|
||||
|
||||
run_fast_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
config:
|
||||
- name: Fast PyTorch Modular Pipeline CPU tests
|
||||
framework: pytorch_pipelines
|
||||
runner: aws-highmemory-32-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_modular_pipelines
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
name: Fast PyTorch Modular Pipeline CPU tests
|
||||
|
||||
runs-on:
|
||||
group: ${{ matrix.config.runner }}
|
||||
group: aws-highmemory-32-plus
|
||||
|
||||
container:
|
||||
image: ${{ matrix.config.image }}
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
|
||||
|
||||
defaults:
|
||||
@@ -102,7 +93,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -118,22 +109,19 @@ jobs:
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run fast PyTorch Pipeline CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_pipelines' }}
|
||||
run: |
|
||||
pytest -n 8 --max-worker-restart=0 --dist=loadfile \
|
||||
-k "not Flax and not Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
--make-reports=tests_torch_cpu_modular_pipelines \
|
||||
tests/modular_pipelines
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: cat reports/tests_${{ matrix.config.report }}_failures_short.txt
|
||||
run: cat reports/tests_torch_cpu_modular_pipelines_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
name: pr_pytorch_pipelines_torch_cpu_modular_pipelines_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
|
||||
12
.github/workflows/pr_test_fetcher.yml
vendored
12
.github/workflows/pr_test_fetcher.yml
vendored
@@ -28,7 +28,7 @@ jobs:
|
||||
test_map: ${{ steps.set_matrix.outputs.test_map }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Install dependencies
|
||||
@@ -42,7 +42,7 @@ jobs:
|
||||
run: |
|
||||
python utils/tests_fetcher.py | tee test_preparation.txt
|
||||
- name: Report fetched tests
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: test_fetched
|
||||
path: test_preparation.txt
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -109,7 +109,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v3
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: ${{ matrix.modules }}_test_reports
|
||||
path: reports
|
||||
@@ -138,7 +138,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -164,7 +164,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
20
.github/workflows/pr_tests.yml
vendored
20
.github/workflows/pr_tests.yml
vendored
@@ -31,9 +31,9 @@ jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -51,9 +51,9 @@ jobs:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -108,7 +108,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -153,7 +153,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_${{ matrix.config.framework }}_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
@@ -185,7 +185,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -211,7 +211,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
@@ -236,7 +236,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -273,7 +273,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_main_test_reports
|
||||
path: reports
|
||||
|
||||
24
.github/workflows/pr_tests_gpu.yml
vendored
24
.github/workflows/pr_tests_gpu.yml
vendored
@@ -32,9 +32,9 @@ jobs:
|
||||
check_code_quality:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -52,9 +52,9 @@ jobs:
|
||||
needs: check_code_quality
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
@@ -83,7 +83,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -100,7 +100,7 @@ jobs:
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -170,7 +170,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -193,7 +193,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -239,7 +239,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
path: reports
|
||||
@@ -255,7 +255,7 @@ jobs:
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -287,7 +287,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
|
||||
@@ -18,9 +18,9 @@ jobs:
|
||||
check_torch_dependencies:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.8"
|
||||
- name: Install dependencies
|
||||
|
||||
24
.github/workflows/push_tests.yml
vendored
24
.github/workflows/push_tests.yml
vendored
@@ -29,7 +29,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -46,7 +46,7 @@ jobs:
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -66,7 +66,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -98,7 +98,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -120,7 +120,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others, single_file]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -155,7 +155,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
path: reports
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -199,7 +199,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
@@ -216,7 +216,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -240,7 +240,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_xformers_test_reports
|
||||
path: reports
|
||||
@@ -256,7 +256,7 @@ jobs:
|
||||
options: --gpus all --shm-size "16gb" --ipc host
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -286,7 +286,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
|
||||
4
.github/workflows/push_tests_fast.yml
vendored
4
.github/workflows/push_tests_fast.yml
vendored
@@ -54,7 +54,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -88,7 +88,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_${{ matrix.config.report }}_test_reports
|
||||
path: reports
|
||||
|
||||
4
.github/workflows/push_tests_mps.yml
vendored
4
.github/workflows/push_tests_mps.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -65,7 +65,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_torch_mps_test_reports
|
||||
path: reports
|
||||
|
||||
8
.github/workflows/pypi_publish.yaml
vendored
8
.github/workflows/pypi_publish.yaml
vendored
@@ -15,10 +15,10 @@ jobs:
|
||||
latest_branch: ${{ steps.set_latest_branch.outputs.latest_branch }}
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: '3.8'
|
||||
|
||||
@@ -40,12 +40,12 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout Repo
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: ${{ needs.find-and-checkout-latest-branch.outputs.latest_branch }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v4
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.8"
|
||||
|
||||
|
||||
28
.github/workflows/release_tests_fast.yml
vendored
28
.github/workflows/release_tests_fast.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
pipeline_test_matrix: ${{ steps.fetch_pipeline_matrix.outputs.pipeline_test_matrix }}
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: Install dependencies
|
||||
@@ -44,7 +44,7 @@ jobs:
|
||||
echo "pipeline_test_matrix=$matrix" >> $GITHUB_OUTPUT
|
||||
- name: Pipeline Tests Artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: test-pipelines.json
|
||||
path: reports
|
||||
@@ -64,7 +64,7 @@ jobs:
|
||||
options: --shm-size "16gb" --ipc host --gpus all
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
- name: NVIDIA-SMI
|
||||
@@ -94,7 +94,7 @@ jobs:
|
||||
cat reports/tests_pipeline_${{ matrix.module }}_cuda_failures_short.txt
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pipeline_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -116,7 +116,7 @@ jobs:
|
||||
module: [models, schedulers, lora, others, single_file]
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -149,7 +149,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_cuda_${{ matrix.module }}_test_reports
|
||||
path: reports
|
||||
@@ -166,7 +166,7 @@ jobs:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -205,7 +205,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
@@ -222,7 +222,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -247,7 +247,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_compile_test_reports
|
||||
path: reports
|
||||
@@ -264,7 +264,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -288,7 +288,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: torch_xformers_test_reports
|
||||
path: reports
|
||||
@@ -305,7 +305,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
@@ -336,7 +336,7 @@ jobs:
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
|
||||
2
.github/workflows/run_tests_from_a_pr.yml
vendored
2
.github/workflows/run_tests_from_a_pr.yml
vendored
@@ -57,7 +57,7 @@ jobs:
|
||||
shell: bash -e {0}
|
||||
|
||||
- name: Checkout PR branch
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
ref: refs/pull/${{ inputs.pr_number }}/head
|
||||
|
||||
|
||||
2
.github/workflows/ssh-pr-runner.yml
vendored
2
.github/workflows/ssh-pr-runner.yml
vendored
@@ -27,7 +27,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
|
||||
2
.github/workflows/ssh-runner.yml
vendored
2
.github/workflows/ssh-runner.yml
vendored
@@ -35,7 +35,7 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
|
||||
4
.github/workflows/stale.yml
vendored
4
.github/workflows/stale.yml
vendored
@@ -15,10 +15,10 @@ jobs:
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v1
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: 3.8
|
||||
|
||||
|
||||
2
.github/workflows/trufflehog.yml
vendored
2
.github/workflows/trufflehog.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- name: Secret Scanning
|
||||
|
||||
2
.github/workflows/typos.yml
vendored
2
.github/workflows/typos.yml
vendored
@@ -8,7 +8,7 @@ jobs:
|
||||
runs-on: ubuntu-22.04
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: typos-action
|
||||
uses: crate-ci/typos@v1.12.4
|
||||
|
||||
2
.github/workflows/update_metadata.yml
vendored
2
.github/workflows/update_metadata.yml
vendored
@@ -15,7 +15,7 @@ jobs:
|
||||
shell: bash -l {0}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- name: Setup environment
|
||||
run: |
|
||||
|
||||
@@ -367,6 +367,8 @@
|
||||
title: LatteTransformer3DModel
|
||||
- local: api/models/longcat_image_transformer2d
|
||||
title: LongCatImageTransformer2DModel
|
||||
- local: api/models/ltx2_video_transformer3d
|
||||
title: LTX2VideoTransformer3DModel
|
||||
- local: api/models/ltx_video_transformer3d
|
||||
title: LTXVideoTransformer3DModel
|
||||
- local: api/models/lumina2_transformer2d
|
||||
@@ -443,6 +445,10 @@
|
||||
title: AutoencoderKLHunyuanVideo
|
||||
- local: api/models/autoencoder_kl_hunyuan_video15
|
||||
title: AutoencoderKLHunyuanVideo15
|
||||
- local: api/models/autoencoderkl_audio_ltx_2
|
||||
title: AutoencoderKLLTX2Audio
|
||||
- local: api/models/autoencoderkl_ltx_2
|
||||
title: AutoencoderKLLTX2Video
|
||||
- local: api/models/autoencoderkl_ltx_video
|
||||
title: AutoencoderKLLTXVideo
|
||||
- local: api/models/autoencoderkl_magvit
|
||||
@@ -678,6 +684,8 @@
|
||||
title: Kandinsky 5.0 Video
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx2
|
||||
title: LTX-2
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/mochi
|
||||
|
||||
29
docs/source/en/api/models/autoencoderkl_audio_ltx_2.md
Normal file
29
docs/source/en/api/models/autoencoderkl_audio_ltx_2.md
Normal file
@@ -0,0 +1,29 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLLTX2Audio
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks. This is for encoding and decoding audio latent representations.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTX2Audio
|
||||
|
||||
vae = AutoencoderKLLTX2Audio.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTX2Audio
|
||||
|
||||
[[autodoc]] AutoencoderKLLTX2Audio
|
||||
- encode
|
||||
- decode
|
||||
- all
|
||||
29
docs/source/en/api/models/autoencoderkl_ltx_2.md
Normal file
29
docs/source/en/api/models/autoencoderkl_ltx_2.md
Normal file
@@ -0,0 +1,29 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# AutoencoderKLLTX2Video
|
||||
|
||||
The 3D variational autoencoder (VAE) model with KL loss used in [LTX-2](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderKLLTX2Video
|
||||
|
||||
vae = AutoencoderKLLTX2Video.from_pretrained("Lightricks/LTX-2", subfolder="vae", torch_dtype=torch.float32).to("cuda")
|
||||
```
|
||||
|
||||
## AutoencoderKLLTX2Video
|
||||
|
||||
[[autodoc]] AutoencoderKLLTX2Video
|
||||
- decode
|
||||
- encode
|
||||
- all
|
||||
@@ -42,4 +42,4 @@ pipe = FluxControlNetPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", co
|
||||
|
||||
## FluxControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnet_flux.FluxControlNetOutput
|
||||
[[autodoc]] models.controlnets.controlnet_flux.FluxControlNetOutput
|
||||
@@ -43,4 +43,4 @@ controlnet = SparseControlNetModel.from_pretrained("guoyww/animatediff-sparsectr
|
||||
|
||||
## SparseControlNetOutput
|
||||
|
||||
[[autodoc]] models.controlnet_sparsectrl.SparseControlNetOutput
|
||||
[[autodoc]] models.controlnets.controlnet_sparsectrl.SparseControlNetOutput
|
||||
|
||||
26
docs/source/en/api/models/ltx2_video_transformer3d.md
Normal file
26
docs/source/en/api/models/ltx2_video_transformer3d.md
Normal file
@@ -0,0 +1,26 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# LTX2VideoTransformer3DModel
|
||||
|
||||
A Diffusion Transformer model for 3D data from [LTX](https://huggingface.co/Lightricks/LTX-2) was introduced by Lightricks.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import LTX2VideoTransformer3DModel
|
||||
|
||||
transformer = LTX2VideoTransformer3DModel.from_pretrained("Lightricks/LTX-2", subfolder="transformer", torch_dtype=torch.bfloat16).to("cuda")
|
||||
```
|
||||
|
||||
## LTX2VideoTransformer3DModel
|
||||
|
||||
[[autodoc]] LTX2VideoTransformer3DModel
|
||||
43
docs/source/en/api/pipelines/ltx2.md
Normal file
43
docs/source/en/api/pipelines/ltx2.md
Normal file
@@ -0,0 +1,43 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
# LTX-2
|
||||
|
||||
LTX-2 is a DiT-based audio-video foundation model designed to generate synchronized video and audio within a single model. It brings together the core building blocks of modern video generation, with open weights and a focus on practical, local execution.
|
||||
|
||||
You can find all the original LTX-Video checkpoints under the [Lightricks](https://huggingface.co/Lightricks) organization.
|
||||
|
||||
The original codebase for LTX-2 can be found [here](https://github.com/Lightricks/LTX-2).
|
||||
|
||||
## LTX2Pipeline
|
||||
|
||||
[[autodoc]] LTX2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTX2ImageToVideoPipeline
|
||||
|
||||
[[autodoc]] LTX2ImageToVideoPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTX2LatentUpsamplePipeline
|
||||
|
||||
[[autodoc]] LTX2LatentUpsamplePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTX2PipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.ltx2.pipeline_output.LTX2PipelineOutput
|
||||
@@ -136,7 +136,7 @@ export_to_video(video, "output.mp4", fps=24)
|
||||
- The recommended dtype for the transformer, VAE, and text encoder is `torch.bfloat16`. The VAE and text encoder can also be `torch.float32` or `torch.float16`.
|
||||
- For guidance-distilled variants of LTX-Video, set `guidance_scale` to `1.0`. The `guidance_scale` for any other model should be set higher, like `5.0`, for good generation quality.
|
||||
- For timestep-aware VAE variants (LTX-Video 0.9.1 and above), set `decode_timestep` to `0.05` and `image_cond_noise_scale` to `0.025`.
|
||||
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitionts in the generated video.
|
||||
- For variants that support interpolation between multiple conditioning images and videos (LTX-Video 0.9.5 and above), use similar images and videos for the best results. Divergence from the conditioning inputs may lead to abrupt transitions in the generated video.
|
||||
|
||||
- LTX-Video 0.9.7 includes a spatial latent upscaler and a 13B parameter transformer. During inference, a low resolution video is quickly generated first and then upscaled and refined.
|
||||
|
||||
@@ -329,7 +329,7 @@ export_to_video(video, "output.mp4", fps=24)
|
||||
|
||||
<details>
|
||||
<summary>Show example code</summary>
|
||||
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import LTXConditionPipeline, LTXLatentUpsamplePipeline
|
||||
@@ -474,6 +474,12 @@ export_to_video(video, "output.mp4", fps=24)
|
||||
|
||||
</details>
|
||||
|
||||
## LTXI2VLongMultiPromptPipeline
|
||||
|
||||
[[autodoc]] LTXI2VLongMultiPromptPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LTXPipeline
|
||||
|
||||
[[autodoc]] LTXPipeline
|
||||
|
||||
@@ -37,7 +37,8 @@ The following SkyReels-V2 models are supported in Diffusers:
|
||||
- [SkyReels-V2 I2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-1.3B-540P-Diffusers)
|
||||
- [SkyReels-V2 I2V 14B - 540P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-540P-Diffusers)
|
||||
- [SkyReels-V2 I2V 14B - 720P](https://huggingface.co/Skywork/SkyReels-V2-I2V-14B-720P-Diffusers)
|
||||
- [SkyReels-V2 FLF2V 1.3B - 540P](https://huggingface.co/Skywork/SkyReels-V2-FLF2V-1.3B-540P-Diffusers)
|
||||
|
||||
This model was contributed by [M. Tolga Cangöz](https://github.com/tolgacangoz).
|
||||
|
||||
> [!TIP]
|
||||
> Click on the SkyReels-V2 models in the right sidebar for more examples of video generation.
|
||||
|
||||
@@ -250,9 +250,6 @@ The code snippets available in [this](https://github.com/huggingface/diffusers/p
|
||||
|
||||
The general rule of thumb to keep in mind when preparing inputs for the VACE pipeline is that the input images, or frames of a video that you want to use for conditioning, should have a corresponding mask that is black in color. The black mask signifies that the model will not generate new content for that area, and only use those parts for conditioning the generation process. For parts/frames that should be generated by the model, the mask should be white in color.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Wan-Animate: Unified Character Animation and Replacement with Holistic Replication
|
||||
|
||||
[Wan-Animate](https://huggingface.co/papers/2509.14055) by the Wan Team.
|
||||
|
||||
@@ -33,7 +33,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
@@ -50,7 +50,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
@@ -70,7 +70,7 @@ pipeline_quant_config = PipelineQuantizationConfig(
|
||||
)
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev",
|
||||
quantzation_config=pipeline_quant_config,
|
||||
quantization_config=pipeline_quant_config,
|
||||
torch_dtype=torch.bfloat16,
|
||||
device_map="cuda"
|
||||
)
|
||||
|
||||
@@ -98,6 +98,9 @@ Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take
|
||||
This way, the text encoder model is not loaded into memory during training.
|
||||
> [!NOTE]
|
||||
> to enable remote text encoding you must either be logged in to your HuggingFace account (`hf auth login`) OR pass a token with `--hub_token`.
|
||||
### FSDP Text Encoder
|
||||
Flux.2 uses Mistral Small 3.1 as text encoder which is quite large and can take up a lot of memory. To mitigate this, we can use the `--fsdp_text_encoder` flag to enable distributed computation of the prompt embeddings.
|
||||
This way, it distributes the memory cost across multiple nodes.
|
||||
### CPU Offloading
|
||||
To offload parts of the model to CPU memory, you can use `--offload` flag. This will offload the vae and text encoder to CPU memory and only move them to GPU when needed.
|
||||
### Latent Caching
|
||||
@@ -166,6 +169,26 @@ To better track our training experiments, we're using the following flags in the
|
||||
> [!NOTE]
|
||||
> If you want to train using long prompts with the T5 text encoder, you can use `--max_sequence_length` to set the token limit. The default is 77, but it can be increased to as high as 512. Note that this will use more resources and may slow down the training in some cases.
|
||||
|
||||
### FSDP on the transformer
|
||||
By setting the accelerate configuration with FSDP, the transformer block will be wrapped automatically. E.g. set the configuration to:
|
||||
|
||||
```shell
|
||||
distributed_type: FSDP
|
||||
fsdp_config:
|
||||
fsdp_version: 2
|
||||
fsdp_offload_params: false
|
||||
fsdp_sharding_strategy: HYBRID_SHARD
|
||||
fsdp_auto_wrap_policy: TRANSFOMER_BASED_WRAP
|
||||
fsdp_transformer_layer_cls_to_wrap: Flux2TransformerBlock, Flux2SingleTransformerBlock
|
||||
fsdp_forward_prefetch: true
|
||||
fsdp_sync_module_states: false
|
||||
fsdp_state_dict_type: FULL_STATE_DICT
|
||||
fsdp_use_orig_params: false
|
||||
fsdp_activation_checkpointing: true
|
||||
fsdp_reshard_after_forward: true
|
||||
fsdp_cpu_ram_efficient_loading: false
|
||||
```
|
||||
|
||||
## LoRA + DreamBooth
|
||||
|
||||
[LoRA](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a popular parameter-efficient fine-tuning technique that allows you to achieve full-finetuning like performance but with a fraction of learnable parameters.
|
||||
|
||||
@@ -44,6 +44,7 @@ import shutil
|
||||
import warnings
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -75,13 +76,16 @@ from diffusers import (
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_to_cpu_contiguous,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@@ -722,6 +729,7 @@ def parse_args(input_args=None):
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -1219,7 +1227,11 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
is_fsdp = accelerator.state.fsdp_plugin is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1263,17 +1275,42 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
transformer_cls = type(unwrap_model(transformer))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
# 1) Validate and pick the transformer model
|
||||
modules_to_save: dict[str, Any] = {}
|
||||
transformer_model = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(unwrap_model(model), transformer_cls):
|
||||
transformer_model = model
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
if transformer_model is None:
|
||||
raise ValueError("No transformer model found in 'models'")
|
||||
|
||||
# 2) Optionally gather FSDP state dict once
|
||||
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
|
||||
|
||||
# 3) Only main process materializes the LoRA state dict
|
||||
transformer_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
peft_kwargs = {}
|
||||
if is_fsdp:
|
||||
peft_kwargs["state_dict"] = state_dict
|
||||
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(
|
||||
unwrap_model(transformer_model) if is_fsdp else transformer_model,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
if is_fsdp:
|
||||
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
@@ -1285,13 +1322,20 @@ def main(args):
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
if not is_fsdp:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
else:
|
||||
transformer_ = Flux2Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
|
||||
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1507,6 +1551,21 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1536,6 +1595,8 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1777,7 +1838,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if accelerator.is_main_process or is_fsdp:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
@@ -1836,15 +1897,41 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
@@ -43,6 +43,7 @@ import random
|
||||
import shutil
|
||||
from contextlib import nullcontext
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -74,13 +75,16 @@ from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.flux2.image_processor import Flux2ImageProcessor
|
||||
from diffusers.training_utils import (
|
||||
_collate_lora_metadata,
|
||||
_to_cpu_contiguous,
|
||||
cast_training_params,
|
||||
compute_density_for_timestep_sampling,
|
||||
compute_loss_weighting_for_sd3,
|
||||
find_nearest_bucket,
|
||||
free_memory,
|
||||
get_fsdp_kwargs_from_accelerator,
|
||||
offload_models,
|
||||
parse_buckets_string,
|
||||
wrap_with_fsdp,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
@@ -93,6 +97,9 @@ from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
import torch.distributed as dist
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
@@ -339,7 +346,7 @@ def parse_args(input_args=None):
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
required=False,
|
||||
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -691,6 +698,7 @@ def parse_args(input_args=None):
|
||||
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument("--enable_npu_flash_attention", action="store_true", help="Enabla Flash Attention for NPU")
|
||||
parser.add_argument("--fsdp_text_encoder", action="store_true", help="Use FSDP for text encoder")
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -827,15 +835,28 @@ class DreamBoothDataset(Dataset):
|
||||
dest_image = self.cond_images[i]
|
||||
image_width, image_height = dest_image.size
|
||||
if image_width * image_height > 1024 * 1024:
|
||||
dest_image = Flux2ImageProcessor.image_processor._resize_to_target_area(dest_image, 1024 * 1024)
|
||||
dest_image = Flux2ImageProcessor._resize_to_target_area(dest_image, 1024 * 1024)
|
||||
image_width, image_height = dest_image.size
|
||||
|
||||
multiple_of = 2 ** (4 - 1) # 2 ** (len(vae.config.block_out_channels) - 1), temp!
|
||||
image_width = (image_width // multiple_of) * multiple_of
|
||||
image_height = (image_height // multiple_of) * multiple_of
|
||||
dest_image = Flux2ImageProcessor.image_processor.preprocess(
|
||||
image_processor = Flux2ImageProcessor()
|
||||
dest_image = image_processor.preprocess(
|
||||
dest_image, height=image_height, width=image_width, resize_mode="crop"
|
||||
)
|
||||
# Convert back to PIL
|
||||
dest_image = dest_image.squeeze(0)
|
||||
if dest_image.min() < 0:
|
||||
dest_image = (dest_image + 1) / 2
|
||||
dest_image = (torch.clamp(dest_image, 0, 1) * 255).byte().cpu()
|
||||
|
||||
if dest_image.shape[0] == 1:
|
||||
# Gray scale image
|
||||
dest_image = Image.fromarray(dest_image.squeeze().numpy(), mode="L")
|
||||
else:
|
||||
# RGB scale image: (C, H, W) -> (H, W, C)
|
||||
dest_image = TF.to_pil_image(dest_image)
|
||||
|
||||
dest_image = exif_transpose(dest_image)
|
||||
if not dest_image.mode == "RGB":
|
||||
@@ -1156,7 +1177,11 @@ def main(args):
|
||||
if args.bnb_quantization_config_path is not None
|
||||
else {"device": accelerator.device, "dtype": weight_dtype}
|
||||
)
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
is_fsdp = accelerator.state.fsdp_plugin is not None
|
||||
if not is_fsdp:
|
||||
transformer.to(**transformer_to_kwargs)
|
||||
|
||||
if args.do_fp8_training:
|
||||
convert_to_float8_training(
|
||||
transformer, module_filter_fn=module_filter_fn, config=Float8LinearConfig(pad_inner_dim=True)
|
||||
@@ -1200,17 +1225,42 @@ def main(args):
|
||||
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if accelerator.is_main_process:
|
||||
transformer_lora_layers_to_save = None
|
||||
modules_to_save = {}
|
||||
for model in models:
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(model)
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
transformer_cls = type(unwrap_model(transformer))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
# 1) Validate and pick the transformer model
|
||||
modules_to_save: dict[str, Any] = {}
|
||||
transformer_model = None
|
||||
|
||||
for model in models:
|
||||
if isinstance(unwrap_model(model), transformer_cls):
|
||||
transformer_model = model
|
||||
modules_to_save["transformer"] = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
|
||||
if transformer_model is None:
|
||||
raise ValueError("No transformer model found in 'models'")
|
||||
|
||||
# 2) Optionally gather FSDP state dict once
|
||||
state_dict = accelerator.get_state_dict(model) if is_fsdp else None
|
||||
|
||||
# 3) Only main process materializes the LoRA state dict
|
||||
transformer_lora_layers_to_save = None
|
||||
if accelerator.is_main_process:
|
||||
peft_kwargs = {}
|
||||
if is_fsdp:
|
||||
peft_kwargs["state_dict"] = state_dict
|
||||
|
||||
transformer_lora_layers_to_save = get_peft_model_state_dict(
|
||||
unwrap_model(transformer_model) if is_fsdp else transformer_model,
|
||||
**peft_kwargs,
|
||||
)
|
||||
|
||||
if is_fsdp:
|
||||
transformer_lora_layers_to_save = _to_cpu_contiguous(transformer_lora_layers_to_save)
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
if weights:
|
||||
weights.pop()
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
@@ -1222,13 +1272,20 @@ def main(args):
|
||||
def load_model_hook(models, input_dir):
|
||||
transformer_ = None
|
||||
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
if not is_fsdp:
|
||||
while len(models) > 0:
|
||||
model = models.pop()
|
||||
|
||||
if isinstance(model, type(unwrap_model(transformer))):
|
||||
transformer_ = model
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
|
||||
transformer_ = unwrap_model(model)
|
||||
else:
|
||||
raise ValueError(f"unexpected save model: {model.__class__}")
|
||||
else:
|
||||
transformer_ = Flux2Transformer2DModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="transformer",
|
||||
)
|
||||
transformer_.add_adapter(transformer_lora_config)
|
||||
|
||||
lora_state_dict = Flux2Pipeline.lora_state_dict(input_dir)
|
||||
|
||||
@@ -1419,9 +1476,9 @@ def main(args):
|
||||
args.instance_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
validation_image = load_image(args.validation_image_path).convert("RGB")
|
||||
validation_kwargs = {"image": validation_image}
|
||||
if args.validation_prompt is not None:
|
||||
validation_image = load_image(args.validation_image_path).convert("RGB")
|
||||
validation_kwargs = {"image": validation_image}
|
||||
if args.remote_text_encoder:
|
||||
validation_kwargs["prompt_embeds"] = compute_remote_text_embeddings(args.validation_prompt)
|
||||
else:
|
||||
@@ -1430,6 +1487,21 @@ def main(args):
|
||||
args.validation_prompt, text_encoding_pipeline
|
||||
)
|
||||
|
||||
# Init FSDP for text encoder
|
||||
if args.fsdp_text_encoder:
|
||||
fsdp_kwargs = get_fsdp_kwargs_from_accelerator(accelerator)
|
||||
text_encoder_fsdp = wrap_with_fsdp(
|
||||
model=text_encoding_pipeline.text_encoder,
|
||||
device=accelerator.device,
|
||||
offload=args.offload,
|
||||
limit_all_gathers=True,
|
||||
use_orig_params=True,
|
||||
fsdp_kwargs=fsdp_kwargs,
|
||||
)
|
||||
|
||||
text_encoding_pipeline.text_encoder = text_encoder_fsdp
|
||||
dist.barrier()
|
||||
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
@@ -1461,6 +1533,8 @@ def main(args):
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if args.remote_text_encoder:
|
||||
prompt_embeds, text_ids = compute_remote_text_embeddings(batch["prompts"])
|
||||
elif args.fsdp_text_encoder:
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
else:
|
||||
with offload_models(text_encoding_pipeline, device=accelerator.device, offload=args.offload):
|
||||
prompt_embeds, text_ids = compute_text_embeddings(batch["prompts"], text_encoding_pipeline)
|
||||
@@ -1700,7 +1774,7 @@ def main(args):
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if accelerator.is_main_process or is_fsdp:
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
@@ -1759,15 +1833,41 @@ def main(args):
|
||||
|
||||
# Save the lora layers
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_fsdp:
|
||||
transformer = unwrap_model(transformer)
|
||||
state_dict = accelerator.get_state_dict(transformer)
|
||||
if accelerator.is_main_process:
|
||||
modules_to_save = {}
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
if is_fsdp:
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
state_dict = {
|
||||
k: v.to(torch.float32) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
else:
|
||||
state_dict = {
|
||||
k: v.to(weight_dtype) if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
transformer_lora_layers = get_peft_model_state_dict(
|
||||
transformer,
|
||||
state_dict=state_dict,
|
||||
)
|
||||
transformer_lora_layers = {
|
||||
k: v.detach().cpu().contiguous() if isinstance(v, torch.Tensor) else v
|
||||
for k, v in transformer_lora_layers.items()
|
||||
}
|
||||
|
||||
else:
|
||||
transformer = unwrap_model(transformer)
|
||||
if args.bnb_quantization_config_path is None:
|
||||
if args.upcast_before_saving:
|
||||
transformer.to(torch.float32)
|
||||
else:
|
||||
transformer = transformer.to(weight_dtype)
|
||||
transformer_lora_layers = get_peft_model_state_dict(transformer)
|
||||
|
||||
modules_to_save["transformer"] = transformer
|
||||
|
||||
Flux2Pipeline.save_lora_weights(
|
||||
|
||||
886
scripts/convert_ltx2_to_diffusers.py
Normal file
886
scripts/convert_ltx2_to_diffusers.py
Normal file
@@ -0,0 +1,886 @@
|
||||
import argparse
|
||||
import os
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoTokenizer, Gemma3ForConditionalGeneration
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTX2VideoTransformer3DModel,
|
||||
)
|
||||
from diffusers.pipelines.ltx2 import LTX2LatentUpsamplerModel, LTX2TextConnectors, LTX2Vocoder
|
||||
from diffusers.utils.import_utils import is_accelerate_available
|
||||
|
||||
|
||||
CTX = init_empty_weights if is_accelerate_available() else nullcontext
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT = {
|
||||
# Input Patchify Projections
|
||||
"patchify_proj": "proj_in",
|
||||
"audio_patchify_proj": "audio_proj_in",
|
||||
# Modulation Parameters
|
||||
# Handle adaln_single --> time_embed, audioln_single --> audio_time_embed separately as the original keys are
|
||||
# substrings of the other modulation parameters below
|
||||
"av_ca_video_scale_shift_adaln_single": "av_cross_attn_video_scale_shift",
|
||||
"av_ca_a2v_gate_adaln_single": "av_cross_attn_video_a2v_gate",
|
||||
"av_ca_audio_scale_shift_adaln_single": "av_cross_attn_audio_scale_shift",
|
||||
"av_ca_v2a_gate_adaln_single": "av_cross_attn_audio_v2a_gate",
|
||||
# Transformer Blocks
|
||||
# Per-Block Cross Attention Modulatin Parameters
|
||||
"scale_shift_table_a2v_ca_video": "video_a2v_cross_attn_scale_shift_table",
|
||||
"scale_shift_table_a2v_ca_audio": "audio_a2v_cross_attn_scale_shift_table",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VIDEO_VAE_RENAME_DICT = {
|
||||
# Encoder
|
||||
"down_blocks.0": "down_blocks.0",
|
||||
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
||||
"down_blocks.2": "down_blocks.1",
|
||||
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
||||
"down_blocks.4": "down_blocks.2",
|
||||
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
||||
"down_blocks.6": "down_blocks.3",
|
||||
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
||||
"down_blocks.8": "mid_block",
|
||||
# Decoder
|
||||
"up_blocks.0": "mid_block",
|
||||
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
||||
"up_blocks.2": "up_blocks.0",
|
||||
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
||||
"up_blocks.4": "up_blocks.1",
|
||||
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
||||
"up_blocks.6": "up_blocks.2",
|
||||
# Common
|
||||
# For all 3D ResNets
|
||||
"res_blocks": "resnets",
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_RENAME_DICT = {
|
||||
"per_channel_statistics.mean-of-means": "latents_mean",
|
||||
"per_channel_statistics.std-of-means": "latents_std",
|
||||
}
|
||||
|
||||
LTX_2_0_VOCODER_RENAME_DICT = {
|
||||
"ups": "upsamplers",
|
||||
"resblocks": "resnets",
|
||||
"conv_pre": "conv_in",
|
||||
"conv_post": "conv_out",
|
||||
}
|
||||
|
||||
LTX_2_0_TEXT_ENCODER_RENAME_DICT = {
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
|
||||
def update_state_dict_inplace(state_dict: Dict[str, Any], old_key: str, new_key: str) -> None:
|
||||
state_dict[new_key] = state_dict.pop(old_key)
|
||||
|
||||
|
||||
def remove_keys_inplace(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
state_dict.pop(key)
|
||||
|
||||
|
||||
def convert_ltx2_transformer_adaln_single(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
# Skip if not a weight, bias
|
||||
if ".weight" not in key and ".bias" not in key:
|
||||
return
|
||||
|
||||
if key.startswith("adaln_single."):
|
||||
new_key = key.replace("adaln_single.", "time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
if key.startswith("audio_adaln_single."):
|
||||
new_key = key.replace("audio_adaln_single.", "audio_time_embed.")
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae_per_channel_statistics(key: str, state_dict: Dict[str, Any]) -> None:
|
||||
if key.startswith("per_channel_statistics"):
|
||||
new_key = ".".join(["decoder", key])
|
||||
param = state_dict.pop(key)
|
||||
state_dict[new_key] = param
|
||||
|
||||
return
|
||||
|
||||
|
||||
LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP = {
|
||||
"video_embeddings_connector": remove_keys_inplace,
|
||||
"audio_embeddings_connector": remove_keys_inplace,
|
||||
"adaln_single": convert_ltx2_transformer_adaln_single,
|
||||
}
|
||||
|
||||
LTX_2_0_CONNECTORS_KEYS_RENAME_DICT = {
|
||||
"connectors.": "",
|
||||
"video_embeddings_connector": "video_connector",
|
||||
"audio_embeddings_connector": "audio_connector",
|
||||
"transformer_1d_blocks": "transformer_blocks",
|
||||
"text_embedding_projection.aggregate_embed": "text_proj_in",
|
||||
# Attention QK Norms
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
}
|
||||
|
||||
LTX_2_0_VAE_SPECIAL_KEYS_REMAP = {
|
||||
"per_channel_statistics.channel": remove_keys_inplace,
|
||||
"per_channel_statistics.mean-of-stds": remove_keys_inplace,
|
||||
}
|
||||
|
||||
LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP = {}
|
||||
|
||||
|
||||
def split_transformer_and_connector_state_dict(state_dict: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
connector_prefixes = (
|
||||
"video_embeddings_connector",
|
||||
"audio_embeddings_connector",
|
||||
"transformer_1d_blocks",
|
||||
"text_embedding_projection.aggregate_embed",
|
||||
"connectors.",
|
||||
"video_connector",
|
||||
"audio_connector",
|
||||
"text_proj_in",
|
||||
)
|
||||
|
||||
transformer_state_dict, connector_state_dict = {}, {}
|
||||
for key, value in state_dict.items():
|
||||
if key.startswith(connector_prefixes):
|
||||
connector_state_dict[key] = value
|
||||
else:
|
||||
transformer_state_dict[key] = value
|
||||
|
||||
return transformer_state_dict, connector_state_dict
|
||||
|
||||
|
||||
def get_ltx2_transformer_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
# Produces a transformer of the same size as used in test_models_transformer_ltx2.py
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 4,
|
||||
"out_channels": 4,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 2,
|
||||
"attention_head_dim": 8,
|
||||
"cross_attention_dim": 16,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 4,
|
||||
"audio_out_channels": 4,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 2,
|
||||
"audio_attention_head_dim": 4,
|
||||
"audio_cross_attention_dim": 8,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 2,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 16,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"out_channels": 128,
|
||||
"patch_size": 1,
|
||||
"patch_size_t": 1,
|
||||
"num_attention_heads": 32,
|
||||
"attention_head_dim": 128,
|
||||
"cross_attention_dim": 4096,
|
||||
"vae_scale_factors": (8, 32, 32),
|
||||
"pos_embed_max_pos": 20,
|
||||
"base_height": 2048,
|
||||
"base_width": 2048,
|
||||
"audio_in_channels": 128,
|
||||
"audio_out_channels": 128,
|
||||
"audio_patch_size": 1,
|
||||
"audio_patch_size_t": 1,
|
||||
"audio_num_attention_heads": 32,
|
||||
"audio_attention_head_dim": 64,
|
||||
"audio_cross_attention_dim": 2048,
|
||||
"audio_scale_factor": 4,
|
||||
"audio_pos_embed_max_pos": 20,
|
||||
"audio_sampling_rate": 16000,
|
||||
"audio_hop_length": 160,
|
||||
"num_layers": 48,
|
||||
"activation_fn": "gelu-approximate",
|
||||
"qk_norm": "rms_norm_across_heads",
|
||||
"norm_elementwise_affine": False,
|
||||
"norm_eps": 1e-6,
|
||||
"caption_channels": 3840,
|
||||
"attention_bias": True,
|
||||
"attention_out_bias": True,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_offset": 1,
|
||||
"timestep_scale_multiplier": 1000,
|
||||
"cross_attn_timestep_scale_multiplier": 1000,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_TRANSFORMER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def get_ltx2_connectors_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 16,
|
||||
"text_proj_in_factor": 3,
|
||||
"video_connector_num_attention_heads": 4,
|
||||
"video_connector_attention_head_dim": 8,
|
||||
"video_connector_num_layers": 1,
|
||||
"video_connector_num_learnable_registers": None,
|
||||
"audio_connector_num_attention_heads": 4,
|
||||
"audio_connector_attention_head_dim": 8,
|
||||
"audio_connector_num_layers": 1,
|
||||
"audio_connector_num_learnable_registers": None,
|
||||
"connector_rope_base_seq_len": 32,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": False,
|
||||
"causal_temporal_positioning": False,
|
||||
},
|
||||
}
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"caption_channels": 3840,
|
||||
"text_proj_in_factor": 49,
|
||||
"video_connector_num_attention_heads": 30,
|
||||
"video_connector_attention_head_dim": 128,
|
||||
"video_connector_num_layers": 2,
|
||||
"video_connector_num_learnable_registers": 128,
|
||||
"audio_connector_num_attention_heads": 30,
|
||||
"audio_connector_attention_head_dim": 128,
|
||||
"audio_connector_num_layers": 2,
|
||||
"audio_connector_num_learnable_registers": 128,
|
||||
"connector_rope_base_seq_len": 4096,
|
||||
"rope_theta": 10000.0,
|
||||
"rope_double_precision": True,
|
||||
"causal_temporal_positioning": False,
|
||||
"rope_type": "split",
|
||||
},
|
||||
}
|
||||
|
||||
rename_dict = LTX_2_0_CONNECTORS_KEYS_RENAME_DICT
|
||||
special_keys_remap = {}
|
||||
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_transformer(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_transformer_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
transformer_state_dict, _ = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
|
||||
with init_empty_weights():
|
||||
transformer = LTX2VideoTransformer3DModel.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(transformer_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(transformer_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, transformer_state_dict)
|
||||
|
||||
transformer.load_state_dict(transformer_state_dict, strict=True, assign=True)
|
||||
return transformer
|
||||
|
||||
|
||||
def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -> LTX2TextConnectors:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_connectors_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
_, connector_state_dict = split_transformer_and_connector_state_dict(original_state_dict)
|
||||
if len(connector_state_dict) == 0:
|
||||
raise ValueError("No connector weights found in the provided state dict.")
|
||||
|
||||
with init_empty_weights():
|
||||
connectors = LTX2TextConnectors.from_config(diffusers_config)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(connector_state_dict, key, new_key)
|
||||
|
||||
for key in list(connector_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, connector_state_dict)
|
||||
|
||||
connectors.load_state_dict(connector_state_dict, strict=True, assign=True)
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
elif version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
"diffusers_config": {
|
||||
"in_channels": 3,
|
||||
"out_channels": 3,
|
||||
"latent_channels": 128,
|
||||
"block_out_channels": (256, 512, 1024, 2048),
|
||||
"down_block_types": (
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
"LTX2VideoDownBlock3D",
|
||||
),
|
||||
"decoder_block_out_channels": (256, 512, 1024),
|
||||
"layers_per_block": (4, 6, 6, 2, 2),
|
||||
"decoder_layers_per_block": (5, 5, 5, 5),
|
||||
"spatio_temporal_scaling": (True, True, True, True),
|
||||
"decoder_spatio_temporal_scaling": (True, True, True),
|
||||
"decoder_inject_noise": (False, False, False, False),
|
||||
"downsample_type": ("spatial", "temporal", "spatiotemporal", "spatiotemporal"),
|
||||
"upsample_residual": (True, True, True),
|
||||
"upsample_factor": (2, 2, 2),
|
||||
"timestep_conditioning": False,
|
||||
"patch_size": 4,
|
||||
"patch_size_t": 1,
|
||||
"resnet_norm_eps": 1e-6,
|
||||
"encoder_causal": True,
|
||||
"decoder_causal": False,
|
||||
"encoder_spatial_padding_mode": "zeros",
|
||||
"decoder_spatial_padding_mode": "reflect",
|
||||
"spatial_compression_ratio": 32,
|
||||
"temporal_compression_ratio": 8,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VIDEO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Video.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_audio_vae_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"base_channels": 128,
|
||||
"output_channels": 2,
|
||||
"ch_mult": (1, 2, 4),
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": None,
|
||||
"in_channels": 2,
|
||||
"resolution": 256,
|
||||
"latent_channels": 8,
|
||||
"norm_type": "pixel",
|
||||
"causality_axis": "height",
|
||||
"dropout": 0.0,
|
||||
"mid_block_add_attention": False,
|
||||
"sample_rate": 16000,
|
||||
"mel_hop_length": 160,
|
||||
"is_causal": True,
|
||||
"mel_bins": 64,
|
||||
"double_z": True,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_AUDIO_VAE_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_AUDIO_VAE_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_audio_vae(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_audio_vae_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vae = AutoencoderKLLTX2Audio.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vae.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vae
|
||||
|
||||
|
||||
def get_ltx2_vocoder_config(version: str) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/new-ltx-model",
|
||||
"diffusers_config": {
|
||||
"in_channels": 128,
|
||||
"hidden_channels": 1024,
|
||||
"out_channels": 2,
|
||||
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
|
||||
"upsample_factors": [6, 5, 2, 2, 2],
|
||||
"resnet_kernel_sizes": [3, 7, 11],
|
||||
"resnet_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
"leaky_relu_negative_slope": 0.1,
|
||||
"output_sampling_rate": 24000,
|
||||
},
|
||||
}
|
||||
rename_dict = LTX_2_0_VOCODER_RENAME_DICT
|
||||
special_keys_remap = LTX_2_0_VOCODER_SPECIAL_KEYS_REMAP
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_vocoder(original_state_dict: Dict[str, Any], version: str) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_vocoder_config(version)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
with init_empty_weights():
|
||||
vocoder = LTX2Vocoder.from_config(diffusers_config)
|
||||
|
||||
# Handle official code --> diffusers key remapping via the remap dict
|
||||
for key in list(original_state_dict.keys()):
|
||||
new_key = key[:]
|
||||
for replace_key, rename_key in rename_dict.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
update_state_dict_inplace(original_state_dict, key, new_key)
|
||||
|
||||
# Handle any special logic which can't be expressed by a simple 1:1 remapping with the handlers in
|
||||
# special_keys_remap
|
||||
for key in list(original_state_dict.keys()):
|
||||
for special_key, handler_fn_inplace in special_keys_remap.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, original_state_dict)
|
||||
|
||||
vocoder.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
return vocoder
|
||||
|
||||
|
||||
def get_ltx2_spatial_latent_upsampler_config(version: str):
|
||||
if version == "2.0":
|
||||
config = {
|
||||
"in_channels": 128,
|
||||
"mid_channels": 1024,
|
||||
"num_blocks_per_stage": 4,
|
||||
"dims": 3,
|
||||
"spatial_upsample": True,
|
||||
"temporal_upsample": False,
|
||||
"rational_spatial_scale": 2.0,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported version: {version}")
|
||||
return config
|
||||
|
||||
|
||||
def convert_ltx2_spatial_latent_upsampler(
|
||||
original_state_dict: Dict[str, Any], config: Dict[str, Any], dtype: torch.dtype
|
||||
):
|
||||
with init_empty_weights():
|
||||
latent_upsampler = LTX2LatentUpsamplerModel(**config)
|
||||
|
||||
latent_upsampler.load_state_dict(original_state_dict, strict=True, assign=True)
|
||||
latent_upsampler.to(dtype)
|
||||
return latent_upsampler
|
||||
|
||||
|
||||
def load_original_checkpoint(args, filename: Optional[str]) -> Dict[str, Any]:
|
||||
if args.original_state_dict_repo_id is not None:
|
||||
ckpt_path = hf_hub_download(repo_id=args.original_state_dict_repo_id, filename=filename)
|
||||
elif args.checkpoint_path is not None:
|
||||
ckpt_path = args.checkpoint_path
|
||||
else:
|
||||
raise ValueError("Please provide either `original_state_dict_repo_id` or a local `checkpoint_path`")
|
||||
|
||||
original_state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
return original_state_dict
|
||||
|
||||
|
||||
def load_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> Dict[str, Any]:
|
||||
if repo_id is None and filename is None:
|
||||
raise ValueError("Please supply at least one of `repo_id` or `filename`")
|
||||
|
||||
if repo_id is not None:
|
||||
if filename is None:
|
||||
raise ValueError("If repo_id is specified, filename must also be specified.")
|
||||
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
else:
|
||||
ckpt_path = filename
|
||||
|
||||
_, ext = os.path.splitext(ckpt_path)
|
||||
if ext in [".safetensors", ".sft"]:
|
||||
state_dict = safetensors.torch.load_file(ckpt_path)
|
||||
else:
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu")
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def get_model_state_dict_from_combined_ckpt(combined_ckpt: Dict[str, Any], prefix: str) -> Dict[str, Any]:
|
||||
# Ensure that the key prefix ends with a dot (.)
|
||||
if not prefix.endswith("."):
|
||||
prefix = prefix + "."
|
||||
|
||||
model_state_dict = {}
|
||||
for param_name, param in combined_ckpt.items():
|
||||
if param_name.startswith(prefix):
|
||||
model_state_dict[param_name.replace(prefix, "")] = param
|
||||
|
||||
if prefix == "model.diffusion_model.":
|
||||
# Some checkpoints store the text connector projection outside the diffusion model prefix.
|
||||
connector_key = "text_embedding_projection.aggregate_embed.weight"
|
||||
if connector_key in combined_ckpt and connector_key not in model_state_dict:
|
||||
model_state_dict[connector_key] = combined_ckpt[connector_key]
|
||||
|
||||
return model_state_dict
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--original_state_dict_repo_id",
|
||||
default="Lightricks/LTX-2",
|
||||
type=str,
|
||||
help="HF Hub repo id with LTX 2.0 checkpoint",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
help="Local checkpoint path for LTX 2.0. Will be used if `original_state_dict_repo_id` is not specified.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--version",
|
||||
type=str,
|
||||
default="2.0",
|
||||
choices=["test", "2.0"],
|
||||
help="Version of the LTX 2.0 model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--combined_filename",
|
||||
default="ltx-2-19b-dev.safetensors",
|
||||
type=str,
|
||||
help="Filename for combined checkpoint with all LTX 2.0 models (VAE, DiT, etc.)",
|
||||
)
|
||||
parser.add_argument("--vae_prefix", default="vae.", type=str)
|
||||
parser.add_argument("--audio_vae_prefix", default="audio_vae.", type=str)
|
||||
parser.add_argument("--dit_prefix", default="model.diffusion_model.", type=str)
|
||||
parser.add_argument("--vocoder_prefix", default="vocoder.", type=str)
|
||||
|
||||
parser.add_argument("--vae_filename", default=None, type=str, help="VAE filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--audio_vae_filename", default=None, type=str, help="Audio VAE filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument("--dit_filename", default=None, type=str, help="DiT filename; overrides combined ckpt if set")
|
||||
parser.add_argument(
|
||||
"--vocoder_filename", default=None, type=str, help="Vocoder filename; overrides combined ckpt if set"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--text_encoder_model_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 base text encoder model",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tokenizer_id",
|
||||
default="google/gemma-3-12b-it-qat-q4_0-unquantized",
|
||||
type=str,
|
||||
help="HF Hub id for the LTX 2.0 text tokenizer",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--latent_upsampler_filename",
|
||||
default="ltx-2-spatial-upscaler-x2-1.0.safetensors",
|
||||
type=str,
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
parser.add_argument("--connectors", action="store_true", help="Whether to convert the connector model")
|
||||
parser.add_argument("--vocoder", action="store_true", help="Whether to convert the vocoder model")
|
||||
parser.add_argument("--text_encoder", action="store_true", help="Whether to conver the text encoder")
|
||||
parser.add_argument("--latent_upsampler", action="store_true", help="Whether to convert the latent upsampler")
|
||||
parser.add_argument(
|
||||
"--full_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save the pipeline. This will attempt to convert all models (e.g. vae, dit, etc.)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--upsample_pipeline",
|
||||
action="store_true",
|
||||
help="Whether to save a latent upsampling pipeline",
|
||||
)
|
||||
|
||||
parser.add_argument("--vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--audio_vae_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--dit_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--vocoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
parser.add_argument("--text_encoder_dtype", type=str, default="bf16", choices=["fp32", "fp16", "bf16"])
|
||||
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
DTYPE_MAPPING = {
|
||||
"fp32": torch.float32,
|
||||
"fp16": torch.float16,
|
||||
"bf16": torch.bfloat16,
|
||||
}
|
||||
|
||||
VARIANT_MAPPING = {
|
||||
"fp32": None,
|
||||
"fp16": "fp16",
|
||||
"bf16": "bf16",
|
||||
}
|
||||
|
||||
|
||||
def main(args):
|
||||
vae_dtype = DTYPE_MAPPING[args.vae_dtype]
|
||||
audio_vae_dtype = DTYPE_MAPPING[args.audio_vae_dtype]
|
||||
dit_dtype = DTYPE_MAPPING[args.dit_dtype]
|
||||
vocoder_dtype = DTYPE_MAPPING[args.vocoder_dtype]
|
||||
text_encoder_dtype = DTYPE_MAPPING[args.text_encoder_dtype]
|
||||
|
||||
combined_ckpt = None
|
||||
load_combined_models = any(
|
||||
[
|
||||
args.vae,
|
||||
args.audio_vae,
|
||||
args.dit,
|
||||
args.vocoder,
|
||||
args.text_encoder,
|
||||
args.full_pipeline,
|
||||
args.upsample_pipeline,
|
||||
]
|
||||
)
|
||||
if args.combined_filename is not None and load_combined_models:
|
||||
combined_ckpt = load_original_checkpoint(args, filename=args.combined_filename)
|
||||
|
||||
if args.vae or args.full_pipeline or args.upsample_pipeline:
|
||||
if args.vae_filename is not None:
|
||||
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
|
||||
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
if args.audio_vae or args.full_pipeline:
|
||||
if args.audio_vae_filename is not None:
|
||||
original_audio_vae_ckpt = load_hub_or_local_checkpoint(filename=args.audio_vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_audio_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.audio_vae_prefix)
|
||||
audio_vae = convert_ltx2_audio_vae(original_audio_vae_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
audio_vae.to(audio_vae_dtype).save_pretrained(os.path.join(args.output_path, "audio_vae"))
|
||||
|
||||
if args.dit or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_dit_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_dit_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
transformer = convert_ltx2_transformer(original_dit_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
transformer.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "transformer"))
|
||||
|
||||
if args.connectors or args.full_pipeline:
|
||||
if args.dit_filename is not None:
|
||||
original_connectors_ckpt = load_hub_or_local_checkpoint(filename=args.dit_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_connectors_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.dit_prefix)
|
||||
connectors = convert_ltx2_connectors(original_connectors_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
connectors.to(dit_dtype).save_pretrained(os.path.join(args.output_path, "connectors"))
|
||||
|
||||
if args.vocoder or args.full_pipeline:
|
||||
if args.vocoder_filename is not None:
|
||||
original_vocoder_ckpt = load_hub_or_local_checkpoint(filename=args.vocoder_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vocoder_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vocoder_prefix)
|
||||
vocoder = convert_ltx2_vocoder(original_vocoder_ckpt, version=args.version)
|
||||
if not args.full_pipeline:
|
||||
vocoder.to(vocoder_dtype).save_pretrained(os.path.join(args.output_path, "vocoder"))
|
||||
|
||||
if args.text_encoder or args.full_pipeline:
|
||||
# text_encoder = AutoModel.from_pretrained(args.text_encoder_model_id)
|
||||
text_encoder = Gemma3ForConditionalGeneration.from_pretrained(args.text_encoder_model_id)
|
||||
if not args.full_pipeline:
|
||||
text_encoder.to(text_encoder_dtype).save_pretrained(os.path.join(args.output_path, "text_encoder"))
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_id)
|
||||
if not args.full_pipeline:
|
||||
tokenizer.save_pretrained(os.path.join(args.output_path, "tokenizer"))
|
||||
|
||||
if args.latent_upsampler or args.full_pipeline or args.upsample_pipeline:
|
||||
original_latent_upsampler_ckpt = load_hub_or_local_checkpoint(
|
||||
repo_id=args.original_state_dict_repo_id, filename=args.latent_upsampler_filename
|
||||
)
|
||||
latent_upsampler_config = get_ltx2_spatial_latent_upsampler_config(args.version)
|
||||
latent_upsampler = convert_ltx2_spatial_latent_upsampler(
|
||||
original_latent_upsampler_ckpt,
|
||||
latent_upsampler_config,
|
||||
dtype=vae_dtype,
|
||||
)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
latent_upsampler.save_pretrained(os.path.join(args.output_path, "latent_upsampler"))
|
||||
|
||||
if args.full_pipeline:
|
||||
scheduler = FlowMatchEulerDiscreteScheduler(
|
||||
use_dynamic_shifting=True,
|
||||
base_shift=0.95,
|
||||
max_shift=2.05,
|
||||
base_image_seq_len=1024,
|
||||
max_image_seq_len=4096,
|
||||
shift_terminal=0.1,
|
||||
)
|
||||
|
||||
pipe = LTX2Pipeline(
|
||||
scheduler=scheduler,
|
||||
vae=vae,
|
||||
audio_vae=audio_vae,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
connectors=connectors,
|
||||
transformer=transformer,
|
||||
vocoder=vocoder,
|
||||
)
|
||||
|
||||
pipe.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
|
||||
|
||||
if args.upsample_pipeline:
|
||||
pipe = LTX2LatentUpsamplePipeline(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
# Put latent upsampling pipeline in its own subdirectory so it doesn't mess with the full pipeline
|
||||
pipe.save_pretrained(
|
||||
os.path.join(args.output_path, "upsample_pipeline"), safe_serialization=True, max_shard_size="5GB"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
main(args)
|
||||
@@ -193,6 +193,8 @@ else:
|
||||
"AutoencoderKLHunyuanImageRefiner",
|
||||
"AutoencoderKLHunyuanVideo",
|
||||
"AutoencoderKLHunyuanVideo15",
|
||||
"AutoencoderKLLTX2Audio",
|
||||
"AutoencoderKLLTX2Video",
|
||||
"AutoencoderKLLTXVideo",
|
||||
"AutoencoderKLMagvit",
|
||||
"AutoencoderKLMochi",
|
||||
@@ -236,6 +238,7 @@ else:
|
||||
"Kandinsky5Transformer3DModel",
|
||||
"LatteTransformer3DModel",
|
||||
"LongCatImageTransformer2DModel",
|
||||
"LTX2VideoTransformer3DModel",
|
||||
"LTXVideoTransformer3DModel",
|
||||
"Lumina2Transformer2DModel",
|
||||
"LuminaNextDiT2DModel",
|
||||
@@ -353,6 +356,7 @@ else:
|
||||
"KDPM2AncestralDiscreteScheduler",
|
||||
"KDPM2DiscreteScheduler",
|
||||
"LCMScheduler",
|
||||
"LTXEulerAncestralRFScheduler",
|
||||
"PNDMScheduler",
|
||||
"RePaintScheduler",
|
||||
"SASolverScheduler",
|
||||
@@ -417,6 +421,8 @@ else:
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
@@ -537,7 +543,11 @@ else:
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ImageToVideoPipeline",
|
||||
"LTX2LatentUpsamplePipeline",
|
||||
"LTX2Pipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXPipeline",
|
||||
@@ -937,6 +947,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -980,6 +992,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
@@ -1088,6 +1101,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
KDPM2AncestralDiscreteScheduler,
|
||||
KDPM2DiscreteScheduler,
|
||||
LCMScheduler,
|
||||
LTXEulerAncestralRFScheduler,
|
||||
PNDMScheduler,
|
||||
RePaintScheduler,
|
||||
SASolverScheduler,
|
||||
@@ -1135,6 +1149,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
QwenImageLayeredModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
@@ -1251,7 +1267,11 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ImageToVideoPipeline,
|
||||
LTX2LatentUpsamplePipeline,
|
||||
LTX2Pipeline,
|
||||
LTXConditionPipeline,
|
||||
LTXI2VLongMultiPromptPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
|
||||
@@ -41,6 +41,8 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
|
||||
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
|
||||
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
|
||||
_import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
|
||||
_import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
|
||||
_import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
|
||||
@@ -104,6 +106,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_kandinsky"] = ["Kandinsky5Transformer3DModel"]
|
||||
_import_structure["transformers.transformer_longcat_image"] = ["LongCatImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_ltx2"] = ["LTX2VideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
@@ -153,6 +156,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLHunyuanImageRefiner,
|
||||
AutoencoderKLHunyuanVideo,
|
||||
AutoencoderKLHunyuanVideo15,
|
||||
AutoencoderKLLTX2Audio,
|
||||
AutoencoderKLLTX2Video,
|
||||
AutoencoderKLLTXVideo,
|
||||
AutoencoderKLMagvit,
|
||||
AutoencoderKLMochi,
|
||||
@@ -212,6 +217,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Kandinsky5Transformer3DModel,
|
||||
LatteTransformer3DModel,
|
||||
LongCatImageTransformer2DModel,
|
||||
LTX2VideoTransformer3DModel,
|
||||
LTXVideoTransformer3DModel,
|
||||
Lumina2Transformer2DModel,
|
||||
LuminaNextDiT2DModel,
|
||||
|
||||
@@ -1106,6 +1106,51 @@ def _sage_attention_backward_op(
|
||||
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
|
||||
|
||||
|
||||
def _npu_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
):
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# Not implemented yet.
|
||||
def _npu_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
raise NotImplementedError("Backward pass is not implemented for Npu Fusion Attention.")
|
||||
|
||||
|
||||
# ===== Context parallel =====
|
||||
|
||||
|
||||
@@ -1420,6 +1465,7 @@ def _flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -1427,6 +1473,9 @@ def _flash_attention(
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
if _parallel_config is None:
|
||||
out = flash_attn_func(
|
||||
q=query,
|
||||
@@ -1469,6 +1518,7 @@ def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -1476,6 +1526,9 @@ def _flash_attention_hub(
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
lse = None
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
@@ -1612,11 +1665,15 @@ def _flash_attention_3(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
out, lse = _wrapped_flash_attn_3(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -1636,6 +1693,7 @@ def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[float] = None,
|
||||
is_causal: bool = False,
|
||||
window_size: Tuple[int, int] = (-1, -1),
|
||||
@@ -1646,6 +1704,8 @@ def _flash_attention_3_hub(
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
@@ -1785,12 +1845,16 @@ def _aiter_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for aiter attention")
|
||||
|
||||
if not return_lse and torch.is_grad_enabled():
|
||||
# aiter requires return_lse=True by assertion when gradients are enabled.
|
||||
out, lse, *_ = aiter_flash_attn_func(
|
||||
@@ -2028,6 +2092,7 @@ def _native_flash_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
@@ -2035,6 +2100,9 @@ def _native_flash_attention(
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for aiter attention")
|
||||
|
||||
lse = None
|
||||
if _parallel_config is None and not return_lse:
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
@@ -2108,34 +2176,52 @@ def _native_math_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._NATIVE_NPU,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _native_npu_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
dropout_p: float = 0.0,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for NPU attention")
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(1), # num_heads
|
||||
input_layout="BNSD",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
out = out.transpose(1, 2).contiguous()
|
||||
if _parallel_config is None:
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
pre_tockens=65536,
|
||||
next_tockens=65536,
|
||||
keep_prob=1.0 - dropout_p,
|
||||
sync=False,
|
||||
inner_precise=0,
|
||||
)[0]
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
None,
|
||||
scale,
|
||||
None,
|
||||
return_lse,
|
||||
forward_op=_npu_attention_forward_op,
|
||||
backward_op=_npu_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
@@ -2148,10 +2234,13 @@ def _native_xla_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for XLA attention")
|
||||
if return_lse:
|
||||
raise ValueError("XLA attention backend does not support setting `return_lse=True`.")
|
||||
query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
|
||||
@@ -2175,11 +2264,14 @@ def _sage_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
lse = None
|
||||
if _parallel_config is None:
|
||||
out = sageattn(
|
||||
@@ -2223,11 +2315,14 @@ def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
lse = None
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
if _parallel_config is None:
|
||||
@@ -2309,11 +2404,14 @@ def _sage_qk_int8_pv_fp8_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2333,11 +2431,14 @@ def _sage_qk_int8_pv_fp8_cuda_sm90_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp8_cuda_sm90(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2357,11 +2458,14 @@ def _sage_qk_int8_pv_fp16_cuda_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_cuda(
|
||||
q=query,
|
||||
k=key,
|
||||
@@ -2381,11 +2485,14 @@ def _sage_qk_int8_pv_fp16_triton_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
is_causal: bool = False,
|
||||
scale: Optional[float] = None,
|
||||
return_lse: bool = False,
|
||||
_parallel_config: Optional["ParallelConfig"] = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for sage attention")
|
||||
return sageattn_qk_int8_pv_fp16_triton(
|
||||
q=query,
|
||||
k=key,
|
||||
|
||||
@@ -10,6 +10,8 @@ from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
|
||||
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
|
||||
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
|
||||
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
|
||||
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
|
||||
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio
|
||||
from .autoencoder_kl_magvit import AutoencoderKLMagvit
|
||||
from .autoencoder_kl_mochi import AutoencoderKLMochi
|
||||
from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
|
||||
1521
src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
Normal file
1521
src/diffusers/models/autoencoders/autoencoder_kl_ltx2.py
Normal file
File diff suppressed because it is too large
Load Diff
804
src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
Normal file
804
src/diffusers/models/autoencoders/autoencoder_kl_ltx2_audio.py
Normal file
@@ -0,0 +1,804 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ..modeling_outputs import AutoencoderKLOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
|
||||
|
||||
|
||||
LATENT_DOWNSAMPLE_FACTOR = 4
|
||||
|
||||
|
||||
class LTX2AudioCausalConv2d(nn.Module):
|
||||
"""
|
||||
A causal 2D convolution that pads asymmetrically along the causal axis.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: Union[int, Tuple[int, int]],
|
||||
stride: int = 1,
|
||||
dilation: Union[int, Tuple[int, int]] = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.causality_axis = causality_axis
|
||||
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
|
||||
dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
|
||||
|
||||
pad_h = (kernel_size[0] - 1) * dilation[0]
|
||||
pad_w = (kernel_size[1] - 1) * dilation[1]
|
||||
|
||||
if self.causality_axis == "none":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis in {"width", "width-compatibility"}:
|
||||
padding = (pad_w, 0, pad_h // 2, pad_h - pad_h // 2)
|
||||
elif self.causality_axis == "height":
|
||||
padding = (pad_w // 2, pad_w - pad_w // 2, pad_h, 0)
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {causality_axis}")
|
||||
|
||||
self.padding = padding
|
||||
self.conv = nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.pad(x, self.padding)
|
||||
return self.conv(x)
|
||||
|
||||
|
||||
class LTX2AudioPixelNorm(nn.Module):
|
||||
"""
|
||||
Per-pixel (per-location) RMS normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int = 1, eps: float = 1e-8) -> None:
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean_sq = torch.mean(x**2, dim=self.dim, keepdim=True)
|
||||
rms = torch.sqrt(mean_sq + self.eps)
|
||||
return x / rms
|
||||
|
||||
|
||||
class LTX2AudioAttnBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
norm_type: str = "group",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
h_ = self.norm(x)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
batch, channels, height, width = q.shape
|
||||
q = q.reshape(batch, channels, height * width).permute(0, 2, 1).contiguous()
|
||||
k = k.reshape(batch, channels, height * width).contiguous()
|
||||
attn = torch.bmm(q, k) * (int(channels) ** (-0.5))
|
||||
attn = torch.nn.functional.softmax(attn, dim=2)
|
||||
|
||||
v = v.reshape(batch, channels, height * width)
|
||||
attn = attn.permute(0, 2, 1).contiguous()
|
||||
h_ = torch.bmm(v, attn).reshape(batch, channels, height, width)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
return x + h_
|
||||
|
||||
|
||||
class LTX2AudioResnetBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
conv_shortcut: bool = False,
|
||||
dropout: float = 0.0,
|
||||
temb_channels: int = 512,
|
||||
norm_type: str = "group",
|
||||
causality_axis: str = "height",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.causality_axis is not None and self.causality_axis != "none" and norm_type == "group":
|
||||
raise ValueError("Causal ResnetBlock with GroupNorm is not supported.")
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
if norm_type == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm1 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
if causality_axis is not None:
|
||||
self.conv1 = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = nn.Linear(temb_channels, out_channels)
|
||||
if norm_type == "group":
|
||||
self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
|
||||
elif norm_type == "pixel":
|
||||
self.norm2 = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {norm_type}")
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
if causality_axis is not None:
|
||||
self.conv2 = LTX2AudioCausalConv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
if causality_axis is not None:
|
||||
self.conv_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
else:
|
||||
if causality_axis is not None:
|
||||
self.nin_shortcut = LTX2AudioCausalConv2d(
|
||||
in_channels, out_channels, kernel_size=1, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.nin_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
h = self.norm1(x)
|
||||
h = self.non_linearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(self.non_linearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = self.non_linearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
x = self.conv_shortcut(x) if self.use_conv_shortcut else self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class LTX2AudioDownsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.with_conv:
|
||||
# Padding tuple is in the order: (left, right, top, bottom).
|
||||
if self.causality_axis == "none":
|
||||
pad = (0, 1, 0, 1)
|
||||
elif self.causality_axis == "width":
|
||||
pad = (2, 0, 0, 1)
|
||||
elif self.causality_axis == "height":
|
||||
pad = (0, 1, 2, 0)
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pad = (1, 0, 0, 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid `causality_axis` {self.causality_axis}; supported values are `none`, `width`, `height`,"
|
||||
f" and `width-compatibility`."
|
||||
)
|
||||
|
||||
x = F.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
# with_conv=False implies that causality_axis is "none"
|
||||
x = F.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioUpsample(nn.Module):
|
||||
def __init__(self, in_channels: int, with_conv: bool, causality_axis: Optional[str] = "height") -> None:
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
self.causality_axis = causality_axis
|
||||
if self.with_conv:
|
||||
if causality_axis is not None:
|
||||
self.conv = LTX2AudioCausalConv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, causality_axis=causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
if self.causality_axis is None or self.causality_axis == "none":
|
||||
pass
|
||||
elif self.causality_axis == "height":
|
||||
x = x[:, :, 1:, :]
|
||||
elif self.causality_axis == "width":
|
||||
x = x[:, :, :, 1:]
|
||||
elif self.causality_axis == "width-compatibility":
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"Invalid causality_axis: {self.causality_axis}")
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LTX2AudioAudioPatchifier:
|
||||
"""
|
||||
Patchifier for spectrogram/audio latents.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int,
|
||||
sample_rate: int = 16000,
|
||||
hop_length: int = 160,
|
||||
audio_latent_downsample_factor: int = 4,
|
||||
is_causal: bool = True,
|
||||
):
|
||||
self.hop_length = hop_length
|
||||
self.sample_rate = sample_rate
|
||||
self.audio_latent_downsample_factor = audio_latent_downsample_factor
|
||||
self.is_causal = is_causal
|
||||
self._patch_size = (1, patch_size, patch_size)
|
||||
|
||||
def patchify(self, audio_latents: torch.Tensor) -> torch.Tensor:
|
||||
batch, channels, time, freq = audio_latents.shape
|
||||
return audio_latents.permute(0, 2, 1, 3).reshape(batch, time, channels * freq)
|
||||
|
||||
def unpatchify(self, audio_latents: torch.Tensor, channels: int, mel_bins: int) -> torch.Tensor:
|
||||
batch, time, _ = audio_latents.shape
|
||||
return audio_latents.view(batch, time, channels, mel_bins).permute(0, 2, 1, 3)
|
||||
|
||||
@property
|
||||
def patch_size(self) -> Tuple[int, int, int]:
|
||||
return self._patch_size
|
||||
|
||||
|
||||
class LTX2AudioEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels
|
||||
base_resolution = resolution
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
in_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(in_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.down = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != self.num_resolutions - 1:
|
||||
stage.downsample = LTX2AudioDownsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res = curr_res // 2
|
||||
|
||||
self.down.append(stage)
|
||||
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(block_in, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
final_block_channels = block_in
|
||||
z_channels = 2 * latent_channels if double_z else latent_channels
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
self.non_linearity = nn.SiLU()
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, z_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, z_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
# hidden_states expected shape: (batch_size, channels, time, num_mel_bins)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for level in range(self.num_resolutions):
|
||||
stage = self.down[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_states = block(hidden_states, temb=None)
|
||||
if stage.attn:
|
||||
hidden_states = stage.attn[block_idx](hidden_states)
|
||||
|
||||
if level != self.num_resolutions - 1 and hasattr(stage, "downsample"):
|
||||
hidden_states = stage.downsample(hidden_states)
|
||||
|
||||
hidden_states = self.mid.block_1(hidden_states, temb=None)
|
||||
hidden_states = self.mid.attn_1(hidden_states)
|
||||
hidden_states = self.mid.block_2(hidden_states, temb=None)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.non_linearity(hidden_states)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2AudioDecoder(nn.Module):
|
||||
"""
|
||||
Symmetric decoder that reconstructs audio spectrograms from latent features.
|
||||
|
||||
The decoder mirrors the encoder structure with configurable channel multipliers, attention resolutions, and causal
|
||||
convolutions.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 1,
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
norm_type: str = "group",
|
||||
causality_axis: Optional[str] = "width",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.sample_rate = sample_rate
|
||||
self.mel_hop_length = mel_hop_length
|
||||
self.is_causal = is_causal
|
||||
self.mel_bins = mel_bins
|
||||
self.patchifier = LTX2AudioAudioPatchifier(
|
||||
patch_size=1,
|
||||
audio_latent_downsample_factor=LATENT_DOWNSAMPLE_FACTOR,
|
||||
sample_rate=sample_rate,
|
||||
hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
)
|
||||
|
||||
self.base_channels = base_channels
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.out_ch = output_channels
|
||||
self.give_pre_end = False
|
||||
self.tanh_out = False
|
||||
self.norm_type = norm_type
|
||||
self.latent_channels = latent_channels
|
||||
self.channel_multipliers = ch_mult
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.causality_axis = causality_axis
|
||||
|
||||
base_block_channels = base_channels * self.channel_multipliers[-1]
|
||||
base_resolution = resolution // (2 ** (self.num_resolutions - 1))
|
||||
self.z_shape = (1, latent_channels, base_resolution, base_resolution)
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_in = LTX2AudioCausalConv2d(
|
||||
latent_channels, base_block_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_in = nn.Conv2d(latent_channels, base_block_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.non_linearity = nn.SiLU()
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
if mid_block_add_attention:
|
||||
self.mid.attn_1 = LTX2AudioAttnBlock(base_block_channels, norm_type=self.norm_type)
|
||||
else:
|
||||
self.mid.attn_1 = nn.Identity()
|
||||
self.mid.block_2 = LTX2AudioResnetBlock(
|
||||
in_channels=base_block_channels,
|
||||
out_channels=base_block_channels,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
|
||||
self.up = nn.ModuleList()
|
||||
block_in = base_block_channels
|
||||
curr_res = self.resolution // (2 ** (self.num_resolutions - 1))
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = nn.Module()
|
||||
stage.block = nn.ModuleList()
|
||||
stage.attn = nn.ModuleList()
|
||||
block_out = self.base_channels * self.channel_multipliers[level]
|
||||
|
||||
for _ in range(self.num_res_blocks + 1):
|
||||
stage.block.append(
|
||||
LTX2AudioResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout,
|
||||
norm_type=self.norm_type,
|
||||
causality_axis=self.causality_axis,
|
||||
)
|
||||
)
|
||||
block_in = block_out
|
||||
if self.attn_resolutions:
|
||||
if curr_res in self.attn_resolutions:
|
||||
stage.attn.append(LTX2AudioAttnBlock(block_in, norm_type=self.norm_type))
|
||||
|
||||
if level != 0:
|
||||
stage.upsample = LTX2AudioUpsample(block_in, True, causality_axis=self.causality_axis)
|
||||
curr_res *= 2
|
||||
|
||||
self.up.insert(0, stage)
|
||||
|
||||
final_block_channels = block_in
|
||||
|
||||
if self.norm_type == "group":
|
||||
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=final_block_channels, eps=1e-6, affine=True)
|
||||
elif self.norm_type == "pixel":
|
||||
self.norm_out = LTX2AudioPixelNorm(dim=1, eps=1e-6)
|
||||
else:
|
||||
raise ValueError(f"Invalid normalization type: {self.norm_type}")
|
||||
|
||||
if self.causality_axis is not None:
|
||||
self.conv_out = LTX2AudioCausalConv2d(
|
||||
final_block_channels, output_channels, kernel_size=3, stride=1, causality_axis=self.causality_axis
|
||||
)
|
||||
else:
|
||||
self.conv_out = nn.Conv2d(final_block_channels, output_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
_, _, frames, mel_bins = sample.shape
|
||||
|
||||
target_frames = frames * LATENT_DOWNSAMPLE_FACTOR
|
||||
|
||||
if self.causality_axis is not None:
|
||||
target_frames = max(target_frames - (LATENT_DOWNSAMPLE_FACTOR - 1), 1)
|
||||
|
||||
target_channels = self.out_ch
|
||||
target_mel_bins = self.mel_bins if self.mel_bins is not None else mel_bins
|
||||
|
||||
hidden_features = self.conv_in(sample)
|
||||
hidden_features = self.mid.block_1(hidden_features, temb=None)
|
||||
hidden_features = self.mid.attn_1(hidden_features)
|
||||
hidden_features = self.mid.block_2(hidden_features, temb=None)
|
||||
|
||||
for level in reversed(range(self.num_resolutions)):
|
||||
stage = self.up[level]
|
||||
for block_idx, block in enumerate(stage.block):
|
||||
hidden_features = block(hidden_features, temb=None)
|
||||
if stage.attn:
|
||||
hidden_features = stage.attn[block_idx](hidden_features)
|
||||
|
||||
if level != 0 and hasattr(stage, "upsample"):
|
||||
hidden_features = stage.upsample(hidden_features)
|
||||
|
||||
if self.give_pre_end:
|
||||
return hidden_features
|
||||
|
||||
hidden = self.norm_out(hidden_features)
|
||||
hidden = self.non_linearity(hidden)
|
||||
decoded_output = self.conv_out(hidden)
|
||||
decoded_output = torch.tanh(decoded_output) if self.tanh_out else decoded_output
|
||||
|
||||
_, _, current_time, current_freq = decoded_output.shape
|
||||
target_time = target_frames
|
||||
target_freq = target_mel_bins
|
||||
|
||||
decoded_output = decoded_output[
|
||||
:, :target_channels, : min(current_time, target_time), : min(current_freq, target_freq)
|
||||
]
|
||||
|
||||
time_padding_needed = target_time - decoded_output.shape[2]
|
||||
freq_padding_needed = target_freq - decoded_output.shape[3]
|
||||
|
||||
if time_padding_needed > 0 or freq_padding_needed > 0:
|
||||
padding = (
|
||||
0,
|
||||
max(freq_padding_needed, 0),
|
||||
0,
|
||||
max(time_padding_needed, 0),
|
||||
)
|
||||
decoded_output = F.pad(decoded_output, padding)
|
||||
|
||||
decoded_output = decoded_output[:, :target_channels, :target_time, :target_freq]
|
||||
|
||||
return decoded_output
|
||||
|
||||
|
||||
class AutoencoderKLLTX2Audio(ModelMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX2 audio VAE for encoding and decoding audio latent representations.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = False
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
base_channels: int = 128,
|
||||
output_channels: int = 2,
|
||||
ch_mult: Tuple[int, ...] = (1, 2, 4),
|
||||
num_res_blocks: int = 2,
|
||||
attn_resolutions: Optional[Tuple[int, ...]] = None,
|
||||
in_channels: int = 2,
|
||||
resolution: int = 256,
|
||||
latent_channels: int = 8,
|
||||
norm_type: str = "pixel",
|
||||
causality_axis: Optional[str] = "height",
|
||||
dropout: float = 0.0,
|
||||
mid_block_add_attention: bool = False,
|
||||
sample_rate: int = 16000,
|
||||
mel_hop_length: int = 160,
|
||||
is_causal: bool = True,
|
||||
mel_bins: Optional[int] = 64,
|
||||
double_z: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
supported_causality_axes = {"none", "width", "height", "width-compatibility"}
|
||||
if causality_axis not in supported_causality_axes:
|
||||
raise ValueError(f"{causality_axis=} is not valid. Supported values: {supported_causality_axes}")
|
||||
|
||||
attn_resolution_set = set(attn_resolutions) if attn_resolutions else attn_resolutions
|
||||
|
||||
self.encoder = LTX2AudioEncoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
double_z=double_z,
|
||||
)
|
||||
|
||||
self.decoder = LTX2AudioDecoder(
|
||||
base_channels=base_channels,
|
||||
output_channels=output_channels,
|
||||
ch_mult=ch_mult,
|
||||
num_res_blocks=num_res_blocks,
|
||||
attn_resolutions=attn_resolution_set,
|
||||
in_channels=in_channels,
|
||||
resolution=resolution,
|
||||
latent_channels=latent_channels,
|
||||
norm_type=norm_type,
|
||||
causality_axis=causality_axis,
|
||||
dropout=dropout,
|
||||
mid_block_add_attention=mid_block_add_attention,
|
||||
sample_rate=sample_rate,
|
||||
mel_hop_length=mel_hop_length,
|
||||
is_causal=is_causal,
|
||||
mel_bins=mel_bins,
|
||||
)
|
||||
|
||||
# Per-channel statistics for normalizing and denormalizing the latent representation. This statics is computed over
|
||||
# the entire dataset and stored in model's checkpoint under AudioVAE state_dict
|
||||
latents_std = torch.zeros((base_channels,))
|
||||
latents_mean = torch.ones((base_channels,))
|
||||
self.register_buffer("latents_mean", latents_mean, persistent=True)
|
||||
self.register_buffer("latents_std", latents_std, persistent=True)
|
||||
|
||||
# TODO: calculate programmatically instead of hardcoding
|
||||
self.temporal_compression_ratio = LATENT_DOWNSAMPLE_FACTOR # 4
|
||||
# TODO: confirm whether the mel compression ratio below is correct
|
||||
self.mel_compression_ratio = LATENT_DOWNSAMPLE_FACTOR
|
||||
self.use_slicing = False
|
||||
|
||||
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.encoder(x)
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(self, x: torch.Tensor, return_dict: bool = True):
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
|
||||
h = torch.cat(encoded_slices)
|
||||
else:
|
||||
h = self._encode(x)
|
||||
posterior = DiagonalGaussianDistribution(h)
|
||||
|
||||
if not return_dict:
|
||||
return (posterior,)
|
||||
return AutoencoderKLOutput(latent_dist=posterior)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
return self.decoder(z)
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
|
||||
decoded = torch.cat(decoded_slices)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
sample: torch.Tensor,
|
||||
sample_posterior: bool = False,
|
||||
return_dict: bool = True,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
) -> Union[DecoderOutput, torch.Tensor]:
|
||||
posterior = self.encode(sample).latent_dist
|
||||
if sample_posterior:
|
||||
z = posterior.sample(generator=generator)
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
if not return_dict:
|
||||
return (dec.sample,)
|
||||
return dec
|
||||
@@ -1,115 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate
|
||||
from .controlnets.controlnet import ( # noqa
|
||||
ControlNetConditioningEmbedding,
|
||||
ControlNetModel,
|
||||
ControlNetOutput,
|
||||
zero_module,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetOutput(ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetOutput` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class ControlNetModel(ControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 3,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 1280,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
encoder_hid_dim: Optional[int] = None,
|
||||
encoder_hid_dim_type: Optional[str] = None,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
addition_embed_type: Optional[str] = None,
|
||||
addition_time_embed_dim: Optional[int] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
projection_class_embeddings_input_dim: Optional[int] = None,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
addition_embed_type_num_heads: int = 64,
|
||||
):
|
||||
deprecation_message = "Importing `ControlNetModel` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
mid_block_type=mid_block_type,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
encoder_hid_dim=encoder_hid_dim,
|
||||
encoder_hid_dim_type=encoder_hid_dim_type,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
class_embed_type=class_embed_type,
|
||||
addition_embed_type=addition_embed_type,
|
||||
addition_time_embed_dim=addition_time_embed_dim,
|
||||
num_class_embeds=num_class_embeds,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
addition_embed_type_num_heads=addition_embed_type_num_heads,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetConditioningEmbedding(ControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `ControlNetConditioningEmbedding` from `diffusers.models.controlnet` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet import ControlNetConditioningEmbedding`, instead."
|
||||
deprecate("diffusers.models.controlnet.ControlNetConditioningEmbedding", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1,70 +0,0 @@
|
||||
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import List
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class FluxControlNetOutput(FluxControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class FluxControlNetModel(FluxControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 1,
|
||||
in_channels: int = 64,
|
||||
num_layers: int = 19,
|
||||
num_single_layers: int = 38,
|
||||
attention_head_dim: int = 128,
|
||||
num_attention_heads: int = 24,
|
||||
joint_attention_dim: int = 4096,
|
||||
pooled_projection_dim: int = 768,
|
||||
guidance_embeds: bool = False,
|
||||
axes_dims_rope: List[int] = [16, 56, 56],
|
||||
num_mode: int = None,
|
||||
conditioning_embedding_channels: int = None,
|
||||
):
|
||||
deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
num_single_layers=num_single_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
guidance_embeds=guidance_embeds,
|
||||
axes_dims_rope=axes_dims_rope,
|
||||
num_mode=num_mode,
|
||||
conditioning_embedding_channels=conditioning_embedding_channels,
|
||||
)
|
||||
|
||||
|
||||
class FluxMultiControlNetModel(FluxMultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1,68 +0,0 @@
|
||||
# Copyright 2025 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SD3ControlNetOutput(SD3ControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3ControlNetOutput` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SD3ControlNetModel(SD3ControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
sample_size: int = 128,
|
||||
patch_size: int = 2,
|
||||
in_channels: int = 16,
|
||||
num_layers: int = 18,
|
||||
attention_head_dim: int = 64,
|
||||
num_attention_heads: int = 18,
|
||||
joint_attention_dim: int = 4096,
|
||||
caption_projection_dim: int = 1152,
|
||||
pooled_projection_dim: int = 2048,
|
||||
out_channels: int = 16,
|
||||
pos_embed_max_size: int = 96,
|
||||
extra_conditioning_channels: int = 0,
|
||||
):
|
||||
deprecation_message = "Importing `SD3ControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3ControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3ControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
sample_size=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
num_layers=num_layers,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
joint_attention_dim=joint_attention_dim,
|
||||
caption_projection_dim=caption_projection_dim,
|
||||
pooled_projection_dim=pooled_projection_dim,
|
||||
out_channels=out_channels,
|
||||
pos_embed_max_size=pos_embed_max_size,
|
||||
extra_conditioning_channels=extra_conditioning_channels,
|
||||
)
|
||||
|
||||
|
||||
class SD3MultiControlNetModel(SD3MultiControlNetModel):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SD3MultiControlNetModel` from `diffusers.models.controlnet_sd3` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sd3 import SD3MultiControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sd3.SD3MultiControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -1,116 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
from ..utils import deprecate, logging
|
||||
from .controlnets.controlnet_sparsectrl import ( # noqa
|
||||
SparseControlNetConditioningEmbedding,
|
||||
SparseControlNetModel,
|
||||
SparseControlNetOutput,
|
||||
zero_module,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class SparseControlNetOutput(SparseControlNetOutput):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetOutput` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetOutput`, instead."
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetOutput", "0.34", deprecation_message)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetConditioningEmbedding(SparseControlNetConditioningEmbedding):
|
||||
def __init__(self, *args, **kwargs):
|
||||
deprecation_message = "Importing `SparseControlNetConditioningEmbedding` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetConditioningEmbedding`, instead."
|
||||
deprecate(
|
||||
"diffusers.models.controlnet_sparsectrl.SparseControlNetConditioningEmbedding", "0.34", deprecation_message
|
||||
)
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
class SparseControlNetModel(SparseControlNetModel):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 4,
|
||||
conditioning_channels: int = 4,
|
||||
flip_sin_to_cos: bool = True,
|
||||
freq_shift: int = 0,
|
||||
down_block_types: Tuple[str, ...] = (
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"CrossAttnDownBlockMotion",
|
||||
"DownBlockMotion",
|
||||
),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
|
||||
layers_per_block: int = 2,
|
||||
downsample_padding: int = 1,
|
||||
mid_block_scale_factor: float = 1,
|
||||
act_fn: str = "silu",
|
||||
norm_num_groups: Optional[int] = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
cross_attention_dim: int = 768,
|
||||
transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
transformer_layers_per_mid_block: Optional[Union[int, Tuple[int]]] = None,
|
||||
temporal_transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
|
||||
attention_head_dim: Union[int, Tuple[int, ...]] = 8,
|
||||
num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
|
||||
use_linear_projection: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
|
||||
global_pool_conditions: bool = False,
|
||||
controlnet_conditioning_channel_order: str = "rgb",
|
||||
motion_max_seq_length: int = 32,
|
||||
motion_num_attention_heads: int = 8,
|
||||
concat_conditioning_mask: bool = True,
|
||||
use_simplified_condition_embedding: bool = True,
|
||||
):
|
||||
deprecation_message = "Importing `SparseControlNetModel` from `diffusers.models.controlnet_sparsectrl` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_sparsectrl import SparseControlNetModel`, instead."
|
||||
deprecate("diffusers.models.controlnet_sparsectrl.SparseControlNetModel", "0.34", deprecation_message)
|
||||
super().__init__(
|
||||
in_channels=in_channels,
|
||||
conditioning_channels=conditioning_channels,
|
||||
flip_sin_to_cos=flip_sin_to_cos,
|
||||
freq_shift=freq_shift,
|
||||
down_block_types=down_block_types,
|
||||
only_cross_attention=only_cross_attention,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
downsample_padding=downsample_padding,
|
||||
mid_block_scale_factor=mid_block_scale_factor,
|
||||
act_fn=act_fn,
|
||||
norm_num_groups=norm_num_groups,
|
||||
norm_eps=norm_eps,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
transformer_layers_per_block=transformer_layers_per_block,
|
||||
transformer_layers_per_mid_block=transformer_layers_per_mid_block,
|
||||
temporal_transformer_layers_per_block=temporal_transformer_layers_per_block,
|
||||
attention_head_dim=attention_head_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
conditioning_embedding_out_channels=conditioning_embedding_out_channels,
|
||||
global_pool_conditions=global_pool_conditions,
|
||||
controlnet_conditioning_channel_order=controlnet_conditioning_channel_order,
|
||||
motion_max_seq_length=motion_max_seq_length,
|
||||
motion_num_attention_heads=motion_num_attention_heads,
|
||||
concat_conditioning_mask=concat_conditioning_mask,
|
||||
use_simplified_condition_embedding=use_simplified_condition_embedding,
|
||||
)
|
||||
@@ -47,6 +47,7 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.distributed_utils import is_torch_dist_rank_zero
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -429,8 +430,12 @@ def _load_shard_files_with_threadpool(
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
tqdm_kwargs = {"total": len(shard_files), "desc": "Loading checkpoint shards"}
|
||||
if not is_torch_dist_rank_zero():
|
||||
tqdm_kwargs["disable"] = True
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
|
||||
with logging.tqdm(**tqdm_kwargs) as pbar:
|
||||
futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
|
||||
for future in as_completed(futures):
|
||||
result = future.result()
|
||||
|
||||
@@ -59,11 +59,8 @@ from ..utils import (
|
||||
is_torch_version,
|
||||
logging,
|
||||
)
|
||||
from ..utils.hub_utils import (
|
||||
PushToHubMixin,
|
||||
load_or_create_model_card,
|
||||
populate_model_card,
|
||||
)
|
||||
from ..utils.distributed_utils import is_torch_dist_rank_zero
|
||||
from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import empty_device_cache
|
||||
from ._modeling_parallel import ContextParallelConfig, ContextParallelModelPlan, ParallelConfig
|
||||
from .model_loading_utils import (
|
||||
@@ -1672,7 +1669,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
else:
|
||||
shard_files = resolved_model_file
|
||||
if len(resolved_model_file) > 1:
|
||||
shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
|
||||
shard_tqdm_kwargs = {"desc": "Loading checkpoint shards"}
|
||||
if not is_torch_dist_rank_zero():
|
||||
shard_tqdm_kwargs["disable"] = True
|
||||
shard_files = logging.tqdm(resolved_model_file, **shard_tqdm_kwargs)
|
||||
|
||||
for shard_file in shard_files:
|
||||
offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
|
||||
|
||||
@@ -35,6 +35,7 @@ if is_torch_available():
|
||||
from .transformer_kandinsky import Kandinsky5Transformer3DModel
|
||||
from .transformer_longcat_image import LongCatImageTransformer2DModel
|
||||
from .transformer_ltx import LTXVideoTransformer3DModel
|
||||
from .transformer_ltx2 import LTX2VideoTransformer3DModel
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
@@ -717,11 +717,7 @@ class FluxTransformer2DModel(
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
if is_torch_npu_available():
|
||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs:
|
||||
ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds")
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import USE_PEFT_BACKEND, is_torch_npu_available, logging, scale_lora_layers, unscale_lora_layers
|
||||
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -835,14 +835,8 @@ class Flux2Transformer2DModel(
|
||||
if txt_ids.ndim == 3:
|
||||
txt_ids = txt_ids[0]
|
||||
|
||||
if is_torch_npu_available():
|
||||
freqs_cos_image, freqs_sin_image = self.pos_embed(img_ids.cpu())
|
||||
image_rotary_emb = (freqs_cos_image.npu(), freqs_sin_image.npu())
|
||||
freqs_cos_text, freqs_sin_text = self.pos_embed(txt_ids.cpu())
|
||||
text_rotary_emb = (freqs_cos_text.npu(), freqs_sin_text.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(img_ids)
|
||||
text_rotary_emb = self.pos_embed(txt_ids)
|
||||
image_rotary_emb = self.pos_embed(img_ids)
|
||||
text_rotary_emb = self.pos_embed(txt_ids)
|
||||
concat_rotary_emb = (
|
||||
torch.cat([text_rotary_emb[0], image_rotary_emb[0]], dim=0),
|
||||
torch.cat([text_rotary_emb[1], image_rotary_emb[1]], dim=0),
|
||||
|
||||
@@ -165,9 +165,8 @@ class Kandinsky5TimeEmbeddings(nn.Module):
|
||||
self.activation = nn.SiLU()
|
||||
self.out_layer = nn.Linear(time_dim, time_dim, bias=True)
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
||||
def forward(self, time):
|
||||
args = torch.outer(time, self.freqs.to(device=time.device))
|
||||
args = torch.outer(time.to(torch.float32), self.freqs.to(device=time.device))
|
||||
time_embed = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
time_embed = self.out_layer(self.activation(self.in_layer(time_embed)))
|
||||
return time_embed
|
||||
@@ -269,7 +268,6 @@ class Kandinsky5Modulation(nn.Module):
|
||||
self.out_layer.weight.data.zero_()
|
||||
self.out_layer.bias.data.zero_()
|
||||
|
||||
@torch.autocast(device_type="cuda", dtype=torch.float32)
|
||||
def forward(self, x):
|
||||
return self.out_layer(self.activation(x))
|
||||
|
||||
@@ -525,6 +523,7 @@ class Kandinsky5Transformer3DModel(
|
||||
"Kandinsky5TransformerEncoderBlock",
|
||||
"Kandinsky5TransformerDecoderBlock",
|
||||
]
|
||||
_keep_in_fp32_modules = ["time_embeddings", "modulation", "visual_modulation", "text_modulation"]
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
|
||||
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import is_torch_npu_available, logging
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -499,11 +499,7 @@ class LongCatImageTransformer2DModel(
|
||||
encoder_hidden_states = self.context_embedder(encoder_hidden_states)
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
if is_torch_npu_available():
|
||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing and self.use_checkpoint[index_block]:
|
||||
|
||||
1350
src/diffusers/models/transformers/transformer_ltx2.py
Normal file
1350
src/diffusers/models/transformers/transformer_ltx2.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -21,7 +21,7 @@ import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import is_torch_npu_available, logging
|
||||
from ...utils import logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from ..attention import AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
@@ -530,11 +530,7 @@ class OvisImageTransformer2DModel(
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
if is_torch_npu_available():
|
||||
freqs_cos, freqs_sin = self.pos_embed(ids.cpu())
|
||||
image_rotary_emb = (freqs_cos.npu(), freqs_sin.npu())
|
||||
else:
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
for index_block, block in enumerate(self.transformer_blocks):
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
|
||||
@@ -63,6 +63,8 @@ else:
|
||||
"QwenImageEditAutoBlocks",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
]
|
||||
_import_structure["z_image"] = [
|
||||
"ZImageAutoBlocks",
|
||||
@@ -96,6 +98,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
QwenImageLayeredModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
from .stable_diffusion_xl import StableDiffusionXLAutoBlocks, StableDiffusionXLModularPipeline
|
||||
|
||||
@@ -121,7 +121,7 @@ class FluxTextInputStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# Adapted from `QwenImageInputsDynamicStep`
|
||||
# Adapted from `QwenImageAdditionalInputsStep`
|
||||
class FluxInputsDynamicStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
|
||||
@@ -168,6 +168,14 @@ class MellonParam:
|
||||
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider")
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
return cls(name="videos", label="Videos", type="video", display="output")
|
||||
|
||||
@classmethod
|
||||
def vae(cls) -> "MellonParam":
|
||||
"""
|
||||
|
||||
@@ -62,6 +62,7 @@ MODULAR_PIPELINE_MAPPING = OrderedDict(
|
||||
("qwenimage", "QwenImageModularPipeline"),
|
||||
("qwenimage-edit", "QwenImageEditModularPipeline"),
|
||||
("qwenimage-edit-plus", "QwenImageEditPlusModularPipeline"),
|
||||
("qwenimage-layered", "QwenImageLayeredModularPipeline"),
|
||||
("z-image", "ZImageModularPipeline"),
|
||||
]
|
||||
)
|
||||
@@ -231,7 +232,7 @@ class BlockState:
|
||||
|
||||
class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
Base class for all Pipeline Blocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks,
|
||||
Base class for all Pipeline Blocks: ConditionalPipelineBlocks, AutoPipelineBlocks, SequentialPipelineBlocks,
|
||||
LoopSequentialPipelineBlocks
|
||||
|
||||
[`ModularPipelineBlocks`] provides method to load and save the definition of pipeline blocks.
|
||||
@@ -527,9 +528,10 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the inputs.
|
||||
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
|
||||
`select_block` method to define the logic for selecting the block.
|
||||
|
||||
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipeline blocks (such as loading or saving etc.)
|
||||
@@ -539,12 +541,13 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
Attributes:
|
||||
block_classes: List of block classes to be used
|
||||
block_names: List of prefixes for each block
|
||||
block_trigger_inputs: List of input names that trigger specific blocks, with None for default
|
||||
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
|
||||
"""
|
||||
|
||||
block_classes = []
|
||||
block_names = []
|
||||
block_trigger_inputs = []
|
||||
default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
@@ -554,26 +557,15 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
else:
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
if not (len(self.block_classes) == len(self.block_names)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
f"In {self.__class__.__name__}, the number of block_classes and block_names must be the same."
|
||||
)
|
||||
default_blocks = [t for t in self.block_trigger_inputs if t is None]
|
||||
# can only have 1 or 0 default block, and has to put in the last
|
||||
# the order of blocks matters here because the first block with matching trigger will be dispatched
|
||||
# e.g. blocks = [inpaint, img2img] and block_trigger_inputs = ["mask", "image"]
|
||||
# as long as mask is provided, it is inpaint; if only image is provided, it is img2img
|
||||
if len(default_blocks) > 1 or (len(default_blocks) == 1 and self.block_trigger_inputs[-1] is not None):
|
||||
if self.default_block_name is not None and self.default_block_name not in self.block_names:
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, exactly one None must be specified as the last element "
|
||||
"in block_trigger_inputs."
|
||||
f"In {self.__class__.__name__}, default_block_name '{self.default_block_name}' must be one of block_names: {self.block_names}"
|
||||
)
|
||||
|
||||
# Map trigger inputs to block objects
|
||||
self.trigger_to_block_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.values()))
|
||||
self.trigger_to_block_name_map = dict(zip(self.block_trigger_inputs, self.sub_blocks.keys()))
|
||||
self.block_to_trigger_map = dict(zip(self.sub_blocks.keys(), self.block_trigger_inputs))
|
||||
|
||||
@property
|
||||
def model_name(self):
|
||||
return next(iter(self.sub_blocks.values())).model_name
|
||||
@@ -602,8 +594,10 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def required_inputs(self) -> List[str]:
|
||||
if None not in self.block_trigger_inputs:
|
||||
# no default block means this conditional block can be skipped entirely
|
||||
if self.default_block_name is None:
|
||||
return []
|
||||
|
||||
first_block = next(iter(self.sub_blocks.values()))
|
||||
required_by_all = set(getattr(first_block, "required_inputs", set()))
|
||||
|
||||
@@ -614,7 +608,6 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return list(required_by_all)
|
||||
|
||||
# YiYi TODO: add test for this
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
@@ -639,36 +632,9 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
# Find default block first (if any)
|
||||
|
||||
block = self.trigger_to_block_map.get(None)
|
||||
for input_name in self.block_trigger_inputs:
|
||||
if input_name is not None and state.get(input_name) is not None:
|
||||
block = self.trigger_to_block_map[input_name]
|
||||
break
|
||||
|
||||
if block is None:
|
||||
logger.info(f"skipping auto block: {self.__class__.__name__}")
|
||||
return pipeline, state
|
||||
|
||||
try:
|
||||
logger.info(f"Running block: {block.__class__.__name__}, trigger: {input_name}")
|
||||
return block(pipeline, state)
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"\nError in block: {block.__class__.__name__}\n"
|
||||
f"Error details: {str(e)}\n"
|
||||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def _get_trigger_inputs(self):
|
||||
def _get_trigger_inputs(self) -> set:
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
|
||||
block_trigger_inputs values
|
||||
Returns a set of all unique trigger input values found in this block and nested blocks.
|
||||
"""
|
||||
|
||||
def fn_recursive_get_trigger(blocks):
|
||||
@@ -676,9 +642,8 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
if blocks is not None:
|
||||
for name, block in blocks.items():
|
||||
# Check if current block has trigger inputs(i.e. auto block)
|
||||
# Check if current block has block_trigger_inputs
|
||||
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||
# Add all non-None values from the trigger inputs list
|
||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||
|
||||
# If block has sub_blocks, recursively check them
|
||||
@@ -688,15 +653,57 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return trigger_values
|
||||
|
||||
trigger_inputs = set(self.block_trigger_inputs)
|
||||
trigger_inputs.update(fn_recursive_get_trigger(self.sub_blocks))
|
||||
# Start with this block's block_trigger_inputs
|
||||
all_triggers = {t for t in self.block_trigger_inputs if t is not None}
|
||||
# Add nested triggers
|
||||
all_triggers.update(fn_recursive_get_trigger(self.sub_blocks))
|
||||
|
||||
return trigger_inputs
|
||||
return all_triggers
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
"""All trigger inputs including from nested blocks."""
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
|
||||
for selecting the block.
|
||||
|
||||
Args:
|
||||
**kwargs: Trigger input names and their values from the state.
|
||||
|
||||
Returns:
|
||||
Optional[str]: The name of the block to run, or None to use default/skip.
|
||||
"""
|
||||
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement the `select_block` method.")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
trigger_kwargs = {name: state.get(name) for name in self.block_trigger_inputs if name is not None}
|
||||
block_name = self.select_block(**trigger_kwargs)
|
||||
|
||||
if block_name is None:
|
||||
block_name = self.default_block_name
|
||||
|
||||
if block_name is None:
|
||||
logger.info(f"skipping conditional block: {self.__class__.__name__}")
|
||||
return pipeline, state
|
||||
|
||||
block = self.sub_blocks[block_name]
|
||||
|
||||
try:
|
||||
logger.info(f"Running block: {block.__class__.__name__}")
|
||||
return block(pipeline, state)
|
||||
except Exception as e:
|
||||
error_msg = (
|
||||
f"\nError in block: {block.__class__.__name__}\n"
|
||||
f"Error details: {str(e)}\n"
|
||||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
@@ -708,7 +715,7 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
header += "\n"
|
||||
header += " " + "=" * 100 + "\n"
|
||||
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -729,31 +736,20 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs = getattr(self, "expected_configs", [])
|
||||
configs_str = format_configs(expected_configs, indent_level=2, add_empty_lines=False)
|
||||
|
||||
# Blocks section - moved to the end with simplified format
|
||||
# Blocks section
|
||||
blocks_str = " Sub-Blocks:\n"
|
||||
for i, (name, block) in enumerate(self.sub_blocks.items()):
|
||||
# Get trigger input for this block
|
||||
trigger = None
|
||||
if hasattr(self, "block_to_trigger_map"):
|
||||
trigger = self.block_to_trigger_map.get(name)
|
||||
# Format the trigger info
|
||||
if trigger is None:
|
||||
trigger_str = "[default]"
|
||||
elif isinstance(trigger, (list, tuple)):
|
||||
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
|
||||
else:
|
||||
trigger_str = f"[trigger: {trigger}]"
|
||||
# For AutoPipelineBlocks, add bullet points
|
||||
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
|
||||
if name == self.default_block_name:
|
||||
addtional_str = " [default]"
|
||||
else:
|
||||
# For SequentialPipelineBlocks, show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
addtional_str = ""
|
||||
blocks_str += f" • {name}{addtional_str} ({block.__class__.__name__})\n"
|
||||
|
||||
# Add block description
|
||||
desc_lines = block.description.split("\n")
|
||||
indented_desc = desc_lines[0]
|
||||
if len(desc_lines) > 1:
|
||||
indented_desc += "\n" + "\n".join(" " + line for line in desc_lines[1:])
|
||||
block_desc_lines = block.description.split("\n")
|
||||
indented_desc = block_desc_lines[0]
|
||||
if len(block_desc_lines) > 1:
|
||||
indented_desc += "\n" + "\n".join(" " + line for line in block_desc_lines[1:])
|
||||
blocks_str += f" Description: {indented_desc}\n\n"
|
||||
|
||||
# Build the representation with conditional sections
|
||||
@@ -784,6 +780,35 @@ class AutoPipelineBlocks(ModularPipelineBlocks):
|
||||
)
|
||||
|
||||
|
||||
class AutoPipelineBlocks(ConditionalPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_block_name(self) -> Optional[str]:
|
||||
"""Derive default_block_name from block_trigger_inputs (None entry)."""
|
||||
if None in self.block_trigger_inputs:
|
||||
idx = self.block_trigger_inputs.index(None)
|
||||
return self.block_names[idx]
|
||||
return None
|
||||
|
||||
def select_block(self, **kwargs) -> Optional[str]:
|
||||
"""Select block based on which trigger input is present (not None)."""
|
||||
for trigger_input, block_name in zip(self.block_trigger_inputs, self.block_names):
|
||||
if trigger_input is not None and kwargs.get(trigger_input) is not None:
|
||||
return block_name
|
||||
return None
|
||||
|
||||
|
||||
class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that combines multiple pipeline block classes into one. When called, it will call each block in
|
||||
@@ -885,7 +910,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
if isinstance(block, ConditionalPipelineBlocks) and block.default_block_name is None:
|
||||
# ConditionalPipelineBlocks without default can be skipped
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
@@ -948,8 +974,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
def _get_trigger_inputs(self):
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks. Returns: Set[str] containing all unique
|
||||
block_trigger_inputs values
|
||||
Returns a set of all unique trigger input values found in the blocks.
|
||||
"""
|
||||
|
||||
def fn_recursive_get_trigger(blocks):
|
||||
@@ -957,9 +982,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
if blocks is not None:
|
||||
for name, block in blocks.items():
|
||||
# Check if current block has trigger inputs(i.e. auto block)
|
||||
# Check if current block has block_trigger_inputs (ConditionalPipelineBlocks)
|
||||
if hasattr(block, "block_trigger_inputs") and block.block_trigger_inputs is not None:
|
||||
# Add all non-None values from the trigger inputs list
|
||||
trigger_values.update(t for t in block.block_trigger_inputs if t is not None)
|
||||
|
||||
# If block has sub_blocks, recursively check them
|
||||
@@ -975,82 +999,84 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def trigger_inputs(self):
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def _traverse_trigger_blocks(self, trigger_inputs):
|
||||
# Convert trigger_inputs to a set for easier manipulation
|
||||
active_triggers = set(trigger_inputs)
|
||||
def _traverse_trigger_blocks(self, active_inputs):
|
||||
"""
|
||||
Traverse blocks and select which ones would run given the active inputs.
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_triggers):
|
||||
Args:
|
||||
active_inputs: Dict of input names to values that are "present"
|
||||
|
||||
Returns:
|
||||
OrderedDict of block_name -> block that would execute
|
||||
"""
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_inputs):
|
||||
result_blocks = OrderedDict()
|
||||
|
||||
# sequential(include loopsequential) or PipelineBlock
|
||||
if not hasattr(block, "block_trigger_inputs"):
|
||||
if block.sub_blocks:
|
||||
# sequential or LoopSequentialPipelineBlocks (keep traversing)
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
|
||||
if isinstance(block, ConditionalPipelineBlocks):
|
||||
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
|
||||
selected_block_name = block.select_block(**trigger_kwargs)
|
||||
|
||||
if selected_block_name is None:
|
||||
selected_block_name = block.default_block_name
|
||||
|
||||
if selected_block_name is None:
|
||||
return result_blocks
|
||||
|
||||
selected_block = block.sub_blocks[selected_block_name]
|
||||
|
||||
if selected_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
|
||||
else:
|
||||
# PipelineBlock
|
||||
result_blocks[block_name] = block
|
||||
# Add this block's output names to active triggers if defined
|
||||
if hasattr(block, "outputs"):
|
||||
active_triggers.update(out.name for out in block.outputs)
|
||||
result_blocks[block_name] = selected_block
|
||||
if hasattr(selected_block, "outputs"):
|
||||
for out in selected_block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
# auto
|
||||
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
|
||||
if block.sub_blocks:
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
else:
|
||||
# Find first block_trigger_input that matches any value in our active_triggers
|
||||
this_block = None
|
||||
for trigger_input in block.block_trigger_inputs:
|
||||
if trigger_input is not None and trigger_input in active_triggers:
|
||||
this_block = block.trigger_to_block_map[trigger_input]
|
||||
break
|
||||
|
||||
# If no matches found, try to get the default (None) block
|
||||
if this_block is None and None in block.block_trigger_inputs:
|
||||
this_block = block.trigger_to_block_map[None]
|
||||
|
||||
if this_block is not None:
|
||||
# sequential/auto (keep traversing)
|
||||
if this_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
|
||||
else:
|
||||
# PipelineBlock
|
||||
result_blocks[block_name] = this_block
|
||||
# Add this block's output names to active triggers if defined
|
||||
# YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
|
||||
if hasattr(this_block, "outputs"):
|
||||
active_triggers.update(out.name for out in this_block.outputs)
|
||||
result_blocks[block_name] = block
|
||||
if hasattr(block, "outputs"):
|
||||
for out in block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
all_blocks = OrderedDict()
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
|
||||
all_blocks.update(blocks_to_update)
|
||||
return all_blocks
|
||||
|
||||
def get_execution_blocks(self, *trigger_inputs):
|
||||
trigger_inputs_all = self.trigger_inputs
|
||||
def get_execution_blocks(self, **kwargs):
|
||||
"""
|
||||
Get the blocks that would execute given the specified inputs.
|
||||
|
||||
if trigger_inputs is not None:
|
||||
if not isinstance(trigger_inputs, (list, tuple, set)):
|
||||
trigger_inputs = [trigger_inputs]
|
||||
invalid_inputs = [x for x in trigger_inputs if x not in trigger_inputs_all]
|
||||
if invalid_inputs:
|
||||
logger.warning(
|
||||
f"The following trigger inputs will be ignored as they are not supported: {invalid_inputs}"
|
||||
)
|
||||
trigger_inputs = [x for x in trigger_inputs if x in trigger_inputs_all]
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
Pass any inputs that would be non-None at runtime.
|
||||
|
||||
if trigger_inputs is None:
|
||||
if None in trigger_inputs_all:
|
||||
trigger_inputs = [None]
|
||||
else:
|
||||
trigger_inputs = [trigger_inputs_all[0]]
|
||||
blocks_triggered = self._traverse_trigger_blocks(trigger_inputs)
|
||||
Returns:
|
||||
SequentialPipelineBlocks containing only the blocks that would execute
|
||||
|
||||
Example:
|
||||
# Get blocks for inpainting workflow blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask,
|
||||
image=image)
|
||||
|
||||
# Get blocks for text2image workflow blocks = pipeline.get_execution_blocks(prompt="a cat")
|
||||
"""
|
||||
# Filter out None values
|
||||
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
|
||||
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
|
||||
|
||||
def __repr__(self):
|
||||
@@ -1067,7 +1093,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
# Get first trigger input as example
|
||||
example_input = next(t for t in self.trigger_inputs if t is not None)
|
||||
header += f" Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('{example_input}')`).\n"
|
||||
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -1091,22 +1117,8 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# Blocks section - moved to the end with simplified format
|
||||
blocks_str = " Sub-Blocks:\n"
|
||||
for i, (name, block) in enumerate(self.sub_blocks.items()):
|
||||
# Get trigger input for this block
|
||||
trigger = None
|
||||
if hasattr(self, "block_to_trigger_map"):
|
||||
trigger = self.block_to_trigger_map.get(name)
|
||||
# Format the trigger info
|
||||
if trigger is None:
|
||||
trigger_str = "[default]"
|
||||
elif isinstance(trigger, (list, tuple)):
|
||||
trigger_str = f"[trigger: {', '.join(str(t) for t in trigger)}]"
|
||||
else:
|
||||
trigger_str = f"[trigger: {trigger}]"
|
||||
# For AutoPipelineBlocks, add bullet points
|
||||
blocks_str += f" • {name} {trigger_str} ({block.__class__.__name__})\n"
|
||||
else:
|
||||
# For SequentialPipelineBlocks, show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
# show execution order
|
||||
blocks_str += f" [{i}] {name} ({block.__class__.__name__})\n"
|
||||
|
||||
# Add block description
|
||||
desc_lines = block.description.split("\n")
|
||||
@@ -1230,15 +1242,9 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
if inp.name not in outputs and inp not in inputs:
|
||||
inputs.append(inp)
|
||||
|
||||
# Only add outputs if the block cannot be skipped
|
||||
should_add_outputs = True
|
||||
if hasattr(block, "block_trigger_inputs") and None not in block.block_trigger_inputs:
|
||||
should_add_outputs = False
|
||||
|
||||
if should_add_outputs:
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
# Add this block's outputs
|
||||
block_intermediate_outputs = [out.name for out in block.intermediate_outputs]
|
||||
outputs.update(block_intermediate_outputs)
|
||||
|
||||
for input_param in inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
@@ -1295,6 +1301,14 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
sub_blocks[block_name] = block
|
||||
self.sub_blocks = sub_blocks
|
||||
|
||||
# Validate that sub_blocks are only leaf blocks
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if block.sub_blocks:
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, sub_blocks must be leaf blocks (no sub_blocks). "
|
||||
f"Block '{block_name}' ({block.__class__.__name__}) has sub_blocks."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_blocks_dict(cls, blocks_dict: Dict[str, Any]) -> "LoopSequentialPipelineBlocks":
|
||||
"""
|
||||
|
||||
@@ -1,661 +0,0 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..image_processor import PipelineImageInput
|
||||
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks
|
||||
from .modular_pipeline_utils import InputParam
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# YiYi Notes: this is actually for SDXL, put it here for now
|
||||
SDXL_INPUTS_SCHEMA = {
|
||||
"prompt": InputParam(
|
||||
"prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
|
||||
),
|
||||
"prompt_2": InputParam(
|
||||
"prompt_2",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
|
||||
),
|
||||
"negative_prompt": InputParam(
|
||||
"negative_prompt",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The prompt or prompts not to guide the image generation",
|
||||
),
|
||||
"negative_prompt_2": InputParam(
|
||||
"negative_prompt_2",
|
||||
type_hint=Union[str, List[str]],
|
||||
description="The negative prompt or prompts for text_encoder_2",
|
||||
),
|
||||
"cross_attention_kwargs": InputParam(
|
||||
"cross_attention_kwargs",
|
||||
type_hint=Optional[dict],
|
||||
description="Kwargs dictionary passed to the AttentionProcessor",
|
||||
),
|
||||
"clip_skip": InputParam(
|
||||
"clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
|
||||
),
|
||||
"image": InputParam(
|
||||
"image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="The image(s) to modify for img2img or inpainting",
|
||||
),
|
||||
"mask_image": InputParam(
|
||||
"mask_image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="Mask image for inpainting, white pixels will be repainted",
|
||||
),
|
||||
"generator": InputParam(
|
||||
"generator",
|
||||
type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
|
||||
description="Generator(s) for deterministic generation",
|
||||
),
|
||||
"height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
|
||||
"width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
|
||||
"num_images_per_prompt": InputParam(
|
||||
"num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
|
||||
),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
|
||||
),
|
||||
"timesteps": InputParam(
|
||||
"timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
|
||||
),
|
||||
"sigmas": InputParam(
|
||||
"sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
|
||||
),
|
||||
"denoising_end": InputParam(
|
||||
"denoising_end",
|
||||
type_hint=Optional[float],
|
||||
description="Fraction of denoising process to complete before termination",
|
||||
),
|
||||
# YiYi Notes: img2img defaults to 0.3, inpainting defaults to 0.9999
|
||||
"strength": InputParam(
|
||||
"strength", type_hint=float, default=0.3, description="How much to transform the reference image"
|
||||
),
|
||||
"denoising_start": InputParam(
|
||||
"denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
|
||||
),
|
||||
"padding_mask_crop": InputParam(
|
||||
"padding_mask_crop",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Size of margin in crop for image and mask",
|
||||
),
|
||||
"original_size": InputParam(
|
||||
"original_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Original size of the image for SDXL's micro-conditioning",
|
||||
),
|
||||
"target_size": InputParam(
|
||||
"target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
|
||||
),
|
||||
"negative_original_size": InputParam(
|
||||
"negative_original_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Negative conditioning based on image resolution",
|
||||
),
|
||||
"negative_target_size": InputParam(
|
||||
"negative_target_size",
|
||||
type_hint=Optional[Tuple[int, int]],
|
||||
description="Negative conditioning based on target resolution",
|
||||
),
|
||||
"crops_coords_top_left": InputParam(
|
||||
"crops_coords_top_left",
|
||||
type_hint=Tuple[int, int],
|
||||
default=(0, 0),
|
||||
description="Top-left coordinates for SDXL's micro-conditioning",
|
||||
),
|
||||
"negative_crops_coords_top_left": InputParam(
|
||||
"negative_crops_coords_top_left",
|
||||
type_hint=Tuple[int, int],
|
||||
default=(0, 0),
|
||||
description="Negative conditioning crop coordinates",
|
||||
),
|
||||
"aesthetic_score": InputParam(
|
||||
"aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
|
||||
),
|
||||
"negative_aesthetic_score": InputParam(
|
||||
"negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
|
||||
),
|
||||
"eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
|
||||
"output_type": InputParam(
|
||||
"output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
|
||||
),
|
||||
"ip_adapter_image": InputParam(
|
||||
"ip_adapter_image",
|
||||
type_hint=PipelineImageInput,
|
||||
required=True,
|
||||
description="Image(s) to be used as IP adapter",
|
||||
),
|
||||
"control_image": InputParam(
|
||||
"control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
|
||||
),
|
||||
"control_guidance_start": InputParam(
|
||||
"control_guidance_start",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=0.0,
|
||||
description="When ControlNet starts applying",
|
||||
),
|
||||
"control_guidance_end": InputParam(
|
||||
"control_guidance_end",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=1.0,
|
||||
description="When ControlNet stops applying",
|
||||
),
|
||||
"controlnet_conditioning_scale": InputParam(
|
||||
"controlnet_conditioning_scale",
|
||||
type_hint=Union[float, List[float]],
|
||||
default=1.0,
|
||||
description="Scale factor for ControlNet outputs",
|
||||
),
|
||||
"guess_mode": InputParam(
|
||||
"guess_mode",
|
||||
type_hint=bool,
|
||||
default=False,
|
||||
description="Enables ControlNet encoder to recognize input without prompts",
|
||||
),
|
||||
"control_mode": InputParam(
|
||||
"control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
|
||||
),
|
||||
}
|
||||
|
||||
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
|
||||
"prompt_embeds": InputParam(
|
||||
"prompt_embeds",
|
||||
type_hint=torch.Tensor,
|
||||
required=True,
|
||||
description="Text embeddings used to guide image generation",
|
||||
),
|
||||
"negative_prompt_embeds": InputParam(
|
||||
"negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
|
||||
),
|
||||
"pooled_prompt_embeds": InputParam(
|
||||
"pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
|
||||
),
|
||||
"negative_pooled_prompt_embeds": InputParam(
|
||||
"negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
|
||||
),
|
||||
"batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
|
||||
"dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
|
||||
"preprocess_kwargs": InputParam(
|
||||
"preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
|
||||
),
|
||||
"latents": InputParam(
|
||||
"latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
|
||||
),
|
||||
"timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
|
||||
"num_inference_steps": InputParam(
|
||||
"num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
|
||||
),
|
||||
"latent_timestep": InputParam(
|
||||
"latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
|
||||
),
|
||||
"image_latents": InputParam(
|
||||
"image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
|
||||
),
|
||||
"mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
|
||||
"masked_image_latents": InputParam(
|
||||
"masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
|
||||
),
|
||||
"add_time_ids": InputParam(
|
||||
"add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
|
||||
),
|
||||
"negative_add_time_ids": InputParam(
|
||||
"negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
|
||||
),
|
||||
"timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
|
||||
"noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
|
||||
"crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
|
||||
"ip_adapter_embeds": InputParam(
|
||||
"ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
|
||||
),
|
||||
"negative_ip_adapter_embeds": InputParam(
|
||||
"negative_ip_adapter_embeds",
|
||||
type_hint=List[torch.Tensor],
|
||||
description="Negative image embeddings for IP-Adapter",
|
||||
),
|
||||
"images": InputParam(
|
||||
"images",
|
||||
type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
|
||||
required=True,
|
||||
description="Generated images",
|
||||
),
|
||||
}
|
||||
|
||||
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
|
||||
|
||||
|
||||
DEFAULT_PARAM_MAPS = {
|
||||
"prompt": {
|
||||
"label": "Prompt",
|
||||
"type": "string",
|
||||
"default": "a bear sitting in a chair drinking a milkshake",
|
||||
"display": "textarea",
|
||||
},
|
||||
"negative_prompt": {
|
||||
"label": "Negative Prompt",
|
||||
"type": "string",
|
||||
"default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
|
||||
"display": "textarea",
|
||||
},
|
||||
"num_inference_steps": {
|
||||
"label": "Steps",
|
||||
"type": "int",
|
||||
"default": 25,
|
||||
"min": 1,
|
||||
"max": 1000,
|
||||
},
|
||||
"seed": {
|
||||
"label": "Seed",
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
"display": "random",
|
||||
},
|
||||
"width": {
|
||||
"label": "Width",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"height": {
|
||||
"label": "Height",
|
||||
"type": "int",
|
||||
"display": "text",
|
||||
"default": 1024,
|
||||
"min": 8,
|
||||
"max": 8192,
|
||||
"step": 8,
|
||||
"group": "dimensions",
|
||||
},
|
||||
"images": {
|
||||
"label": "Images",
|
||||
"type": "image",
|
||||
"display": "output",
|
||||
},
|
||||
"image": {
|
||||
"label": "Image",
|
||||
"type": "image",
|
||||
"display": "input",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_TYPE_MAPS = {
|
||||
"int": {
|
||||
"type": "int",
|
||||
"default": 0,
|
||||
"min": 0,
|
||||
},
|
||||
"float": {
|
||||
"type": "float",
|
||||
"default": 0.0,
|
||||
"min": 0.0,
|
||||
},
|
||||
"str": {
|
||||
"type": "string",
|
||||
"default": "",
|
||||
},
|
||||
"bool": {
|
||||
"type": "boolean",
|
||||
"default": False,
|
||||
},
|
||||
"image": {
|
||||
"type": "image",
|
||||
},
|
||||
}
|
||||
|
||||
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
|
||||
DEFAULT_CATEGORY = "Modular Diffusers"
|
||||
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
|
||||
DEFAULT_PARAMS_GROUPS_KEYS = {
|
||||
"text_encoders": ["text_encoder", "tokenizer"],
|
||||
"ip_adapter_embeds": ["ip_adapter_embeds"],
|
||||
"prompt_embeddings": ["prompt_embeds"],
|
||||
}
|
||||
|
||||
|
||||
def get_group_name(name, group_params_keys=DEFAULT_PARAMS_GROUPS_KEYS):
|
||||
"""
|
||||
Get the group name for a given parameter name, if not part of a group, return None e.g. "prompt_embeds" ->
|
||||
"text_embeds", "text_encoder" -> "text_encoders", "prompt" -> None
|
||||
"""
|
||||
if name is None:
|
||||
return None
|
||||
for group_name, group_keys in group_params_keys.items():
|
||||
for group_key in group_keys:
|
||||
if group_key in name:
|
||||
return group_name
|
||||
return None
|
||||
|
||||
|
||||
class ModularNode(ConfigMixin):
|
||||
"""
|
||||
A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
|
||||
around a ModularPipelineBlocks object.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
config_name = "node_config.json"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
blocks = ModularPipelineBlocks.from_pretrained(
|
||||
pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
|
||||
)
|
||||
return cls(blocks, **kwargs)
|
||||
|
||||
def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
|
||||
self.blocks = blocks
|
||||
|
||||
if label is None:
|
||||
label = self.blocks.__class__.__name__
|
||||
# blocks param name -> mellon param name
|
||||
self.name_mapping = {}
|
||||
|
||||
input_params = {}
|
||||
# pass or create a default param dict for each input
|
||||
# e.g. for prompt,
|
||||
# prompt = {
|
||||
# "name": "text_input", # the name of the input in node definition, could be different from the input name in diffusers
|
||||
# "label": "Prompt",
|
||||
# "type": "string",
|
||||
# "default": "a bear sitting in a chair drinking a milkshake",
|
||||
# "display": "textarea"}
|
||||
# if type is not specified, it'll be a "custom" param of its own type
|
||||
# e.g. you can pass ModularNode(scheduler = {name :"scheduler"})
|
||||
# it will get this spec in node definition {"scheduler": {"label": "Scheduler", "type": "scheduler", "display": "input"}}
|
||||
# name can be a dict, in that case, it is part of a "dict" input in mellon nodes, e.g. text_encoder= {name: {"text_encoders": "text_encoder"}}
|
||||
inputs = self.blocks.inputs + self.blocks.intermediate_inputs
|
||||
for inp in inputs:
|
||||
param = kwargs.pop(inp.name, None)
|
||||
if param:
|
||||
# user can pass a param dict for all inputs, e.g. ModularNode(prompt = {...})
|
||||
input_params[inp.name] = param
|
||||
mellon_name = param.pop("name", inp.name)
|
||||
if mellon_name != inp.name:
|
||||
self.name_mapping[inp.name] = mellon_name
|
||||
continue
|
||||
|
||||
if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
|
||||
continue
|
||||
|
||||
if inp.name in DEFAULT_PARAM_MAPS:
|
||||
# first check if it's in the default param map, if so, directly use that
|
||||
param = DEFAULT_PARAM_MAPS[inp.name].copy()
|
||||
elif get_group_name(inp.name):
|
||||
param = get_group_name(inp.name)
|
||||
if inp.name not in self.name_mapping:
|
||||
self.name_mapping[inp.name] = param
|
||||
else:
|
||||
# if not, check if it's in the SDXL input schema, if so,
|
||||
# 1. use the type hint to determine the type
|
||||
# 2. use the default param dict for the type e.g. if "steps" is a "int" type, {"steps": {"type": "int", "default": 0, "min": 0}}
|
||||
if inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).lower()
|
||||
else:
|
||||
inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
|
||||
type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
|
||||
for type_key, type_param in DEFAULT_TYPE_MAPS.items():
|
||||
if type_key in type_str:
|
||||
param = type_param.copy()
|
||||
param["label"] = inp.name
|
||||
param["display"] = "input"
|
||||
break
|
||||
else:
|
||||
param = inp.name
|
||||
# add the param dict to the inp_params dict
|
||||
input_params[inp.name] = param
|
||||
|
||||
component_params = {}
|
||||
for comp in self.blocks.expected_components:
|
||||
param = kwargs.pop(comp.name, None)
|
||||
if param:
|
||||
component_params[comp.name] = param
|
||||
mellon_name = param.pop("name", comp.name)
|
||||
if mellon_name != comp.name:
|
||||
self.name_mapping[comp.name] = mellon_name
|
||||
continue
|
||||
|
||||
to_exclude = False
|
||||
for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
|
||||
if exclude_key in comp.name:
|
||||
to_exclude = True
|
||||
break
|
||||
if to_exclude:
|
||||
continue
|
||||
|
||||
if get_group_name(comp.name):
|
||||
param = get_group_name(comp.name)
|
||||
if comp.name not in self.name_mapping:
|
||||
self.name_mapping[comp.name] = param
|
||||
elif comp.name in DEFAULT_MODEL_KEYS:
|
||||
param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
|
||||
else:
|
||||
param = comp.name
|
||||
# add the param dict to the model_params dict
|
||||
component_params[comp.name] = param
|
||||
|
||||
output_params = {}
|
||||
if isinstance(self.blocks, SequentialPipelineBlocks):
|
||||
last_block_name = list(self.blocks.sub_blocks.keys())[-1]
|
||||
outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
|
||||
else:
|
||||
outputs = self.blocks.intermediate_outputs
|
||||
|
||||
for out in outputs:
|
||||
param = kwargs.pop(out.name, None)
|
||||
if param:
|
||||
output_params[out.name] = param
|
||||
mellon_name = param.pop("name", out.name)
|
||||
if mellon_name != out.name:
|
||||
self.name_mapping[out.name] = mellon_name
|
||||
continue
|
||||
|
||||
if out.name in DEFAULT_PARAM_MAPS:
|
||||
param = DEFAULT_PARAM_MAPS[out.name].copy()
|
||||
param["display"] = "output"
|
||||
else:
|
||||
group_name = get_group_name(out.name)
|
||||
if group_name:
|
||||
param = group_name
|
||||
if out.name not in self.name_mapping:
|
||||
self.name_mapping[out.name] = param
|
||||
else:
|
||||
param = out.name
|
||||
# add the param dict to the outputs dict
|
||||
output_params[out.name] = param
|
||||
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unused kwargs: {kwargs}")
|
||||
|
||||
register_dict = {
|
||||
"category": category,
|
||||
"label": label,
|
||||
"input_params": input_params,
|
||||
"component_params": component_params,
|
||||
"output_params": output_params,
|
||||
"name_mapping": self.name_mapping,
|
||||
}
|
||||
self.register_to_config(**register_dict)
|
||||
|
||||
def setup(self, components_manager, collection=None):
|
||||
self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
|
||||
self._components_manager = components_manager
|
||||
|
||||
@property
|
||||
def mellon_config(self):
|
||||
return self._convert_to_mellon_config()
|
||||
|
||||
def _convert_to_mellon_config(self):
|
||||
node = {}
|
||||
node["label"] = self.config.label
|
||||
node["category"] = self.config.category
|
||||
|
||||
node_param = {}
|
||||
for inp_name, inp_param in self.config.input_params.items():
|
||||
if inp_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[inp_name]
|
||||
else:
|
||||
mellon_name = inp_name
|
||||
if isinstance(inp_param, str):
|
||||
param = {
|
||||
"label": inp_param,
|
||||
"type": inp_param,
|
||||
"display": "input",
|
||||
}
|
||||
else:
|
||||
param = inp_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")
|
||||
|
||||
for comp_name, comp_param in self.config.component_params.items():
|
||||
if comp_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[comp_name]
|
||||
else:
|
||||
mellon_name = comp_name
|
||||
if isinstance(comp_param, str):
|
||||
param = {
|
||||
"label": comp_param,
|
||||
"type": comp_param,
|
||||
"display": "input",
|
||||
}
|
||||
else:
|
||||
param = comp_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")
|
||||
|
||||
for out_name, out_param in self.config.output_params.items():
|
||||
if out_name in self.name_mapping:
|
||||
mellon_name = self.name_mapping[out_name]
|
||||
else:
|
||||
mellon_name = out_name
|
||||
if isinstance(out_param, str):
|
||||
param = {
|
||||
"label": out_param,
|
||||
"type": out_param,
|
||||
"display": "output",
|
||||
}
|
||||
else:
|
||||
param = out_param
|
||||
|
||||
if mellon_name not in node_param:
|
||||
node_param[mellon_name] = param
|
||||
else:
|
||||
logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
|
||||
node["params"] = node_param
|
||||
return node
|
||||
|
||||
def save_mellon_config(self, file_path):
|
||||
"""
|
||||
Save the Mellon configuration to a JSON file.
|
||||
|
||||
Args:
|
||||
file_path (str or Path): Path where the JSON file will be saved
|
||||
|
||||
Returns:
|
||||
Path: Path to the saved config file
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
os.makedirs(file_path.parent, exist_ok=True)
|
||||
|
||||
# Create a combined dictionary with module definition and name mapping
|
||||
config = {"module": self.mellon_config, "name_mapping": self.name_mapping}
|
||||
|
||||
# Save the config to file
|
||||
with open(file_path, "w", encoding="utf-8") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
logger.info(f"Mellon config and name mapping saved to {file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
def load_mellon_config(cls, file_path):
|
||||
"""
|
||||
Load a Mellon configuration from a JSON file.
|
||||
|
||||
Args:
|
||||
file_path (str or Path): Path to the JSON file containing Mellon config
|
||||
|
||||
Returns:
|
||||
dict: The loaded combined configuration containing 'module' and 'name_mapping'
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"Config file not found: {file_path}")
|
||||
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
|
||||
logger.info(f"Mellon config loaded from {file_path}")
|
||||
|
||||
return config
|
||||
|
||||
def process_inputs(self, **kwargs):
|
||||
params_components = {}
|
||||
for comp_name, comp_param in self.config.component_params.items():
|
||||
logger.debug(f"component: {comp_name}")
|
||||
mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
|
||||
if mellon_comp_name in kwargs:
|
||||
if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
|
||||
comp = kwargs[mellon_comp_name].pop(comp_name)
|
||||
else:
|
||||
comp = kwargs.pop(mellon_comp_name)
|
||||
if comp:
|
||||
params_components[comp_name] = self._components_manager.get_one(comp["model_id"])
|
||||
|
||||
params_run = {}
|
||||
for inp_name, inp_param in self.config.input_params.items():
|
||||
logger.debug(f"input: {inp_name}")
|
||||
mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
|
||||
if mellon_inp_name in kwargs:
|
||||
if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
|
||||
inp = kwargs[mellon_inp_name].pop(inp_name)
|
||||
else:
|
||||
inp = kwargs.pop(mellon_inp_name)
|
||||
if inp is not None:
|
||||
params_run[inp_name] = inp
|
||||
|
||||
return_output_names = list(self.config.output_params.keys())
|
||||
|
||||
return params_components, params_run, return_output_names
|
||||
|
||||
def execute(self, **kwargs):
|
||||
params_components, params_run, return_output_names = self.process_inputs(**kwargs)
|
||||
|
||||
self.pipeline.update_components(**params_components)
|
||||
output = self.pipeline(**params_run, output=return_output_names)
|
||||
return output
|
||||
@@ -21,27 +21,27 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["encoders"] = ["QwenImageTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
_import_structure["modular_blocks_qwenimage"] = [
|
||||
"AUTO_BLOCKS",
|
||||
"CONTROLNET_BLOCKS",
|
||||
"EDIT_AUTO_BLOCKS",
|
||||
"EDIT_BLOCKS",
|
||||
"EDIT_INPAINT_BLOCKS",
|
||||
"EDIT_PLUS_AUTO_BLOCKS",
|
||||
"EDIT_PLUS_BLOCKS",
|
||||
"IMAGE2IMAGE_BLOCKS",
|
||||
"INPAINT_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"QwenImageAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_edit"] = [
|
||||
"EDIT_AUTO_BLOCKS",
|
||||
"QwenImageEditAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_edit_plus"] = [
|
||||
"EDIT_PLUS_AUTO_BLOCKS",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_layered"] = [
|
||||
"LAYERED_AUTO_BLOCKS",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -51,28 +51,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .encoders import (
|
||||
QwenImageTextEncoderStep,
|
||||
)
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
from .modular_blocks_qwenimage import (
|
||||
AUTO_BLOCKS,
|
||||
CONTROLNET_BLOCKS,
|
||||
EDIT_AUTO_BLOCKS,
|
||||
EDIT_BLOCKS,
|
||||
EDIT_INPAINT_BLOCKS,
|
||||
EDIT_PLUS_AUTO_BLOCKS,
|
||||
EDIT_PLUS_BLOCKS,
|
||||
IMAGE2IMAGE_BLOCKS,
|
||||
INPAINT_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
QwenImageAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_edit import (
|
||||
EDIT_AUTO_BLOCKS,
|
||||
QwenImageEditAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_edit_plus import (
|
||||
EDIT_PLUS_AUTO_BLOCKS,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_layered import (
|
||||
LAYERED_AUTO_BLOCKS,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
)
|
||||
from .modular_pipeline import (
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
QwenImageLayeredModularPipeline,
|
||||
QwenImageModularPipeline,
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -23,7 +23,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from ...utils.torch_utils import randn_tensor, unwrap_module
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
|
||||
@@ -113,7 +113,9 @@ def get_timesteps(scheduler, num_inference_steps, strength):
|
||||
return timesteps, num_inference_steps - t_start
|
||||
|
||||
|
||||
# Prepare Latents steps
|
||||
# ====================
|
||||
# 1. PREPARE LATENTS
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
@@ -207,6 +209,98 @@ class QwenImagePrepareLatentsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredPrepareLatentsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Prepare initial random noise (B, layers+1, C, H, W) for the generation process"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents"),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
InputParam(name="layers", default=4),
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="generator"),
|
||||
InputParam(
|
||||
name="batch_size",
|
||||
required=True,
|
||||
type_hint=int,
|
||||
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
|
||||
),
|
||||
InputParam(
|
||||
name="dtype",
|
||||
required=True,
|
||||
type_hint=torch.dtype,
|
||||
description="The dtype of the model inputs, can be generated in input step.",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The initial latents to use for the denoising process",
|
||||
),
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def check_inputs(height, width, vae_scale_factor):
|
||||
if height is not None and height % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}")
|
||||
|
||||
if width is not None and width % (vae_scale_factor * 2) != 0:
|
||||
raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}")
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
self.check_inputs(
|
||||
height=block_state.height,
|
||||
width=block_state.width,
|
||||
vae_scale_factor=components.vae_scale_factor,
|
||||
)
|
||||
|
||||
device = components._execution_device
|
||||
batch_size = block_state.batch_size * block_state.num_images_per_prompt
|
||||
|
||||
# we can update the height and width here since it's used to generate the initial
|
||||
block_state.height = block_state.height or components.default_height
|
||||
block_state.width = block_state.width or components.default_width
|
||||
|
||||
# VAE applies 8x compression on images but we must also account for packing which requires
|
||||
# latent height and width to be divisible by 2.
|
||||
latent_height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
|
||||
latent_width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
|
||||
|
||||
shape = (batch_size, block_state.layers + 1, components.num_channels_latents, latent_height, latent_width)
|
||||
if isinstance(block_state.generator, list) and len(block_state.generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
if block_state.latents is None:
|
||||
block_state.latents = randn_tensor(
|
||||
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
|
||||
)
|
||||
block_state.latents = components.pachifier.pack_latents(block_state.latents)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImagePrepareLatentsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -351,7 +445,9 @@ class QwenImageCreateMaskLatentsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# Set Timesteps steps
|
||||
# ====================
|
||||
# 2. SET TIMESTEPS
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
@@ -420,6 +516,64 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Set timesteps step for QwenImage Layered with custom mu calculation based on image_latents."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("num_inference_steps", default=50, type_hint=int),
|
||||
InputParam("sigmas", type_hint=List[float]),
|
||||
InputParam("image_latents", required=True, type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="timesteps", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
# Layered-specific mu calculation
|
||||
base_seqlen = 256 * 256 / 16 / 16 # = 256
|
||||
mu = (block_state.image_latents.shape[1] / base_seqlen) ** 0.5
|
||||
|
||||
# Default sigmas if not provided
|
||||
sigmas = (
|
||||
np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps)
|
||||
if block_state.sigmas is None
|
||||
else block_state.sigmas
|
||||
)
|
||||
|
||||
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
|
||||
components.scheduler,
|
||||
block_state.num_inference_steps,
|
||||
device,
|
||||
sigmas=sigmas,
|
||||
mu=mu,
|
||||
)
|
||||
|
||||
components.scheduler.set_begin_index(0)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -493,7 +647,9 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# other inputs for denoiser
|
||||
# ====================
|
||||
# 3. OTHER INPUTS FOR DENOISER
|
||||
# ====================
|
||||
|
||||
## RoPE inputs for denoiser
|
||||
|
||||
@@ -522,6 +678,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
@@ -589,6 +746,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the images latents, used for RoPE calculation",
|
||||
),
|
||||
@@ -639,19 +797,64 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
|
||||
class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step that prepares the RoPE inputs for denoising process. This is used in QwenImage Edit Plus.\n"
|
||||
"Unlike Edit, Edit Plus handles lists of image_height/image_width for multiple reference images.\n"
|
||||
"Should be placed after prepare_latents step."
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="image_height", required=True, type_hint=List[int]),
|
||||
InputParam(name="image_width", required=True, type_hint=List[int]),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
description="The shapes of the image latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=List[int],
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
vae_scale_factor = components.vae_scale_factor
|
||||
|
||||
# Edit Plus: image_height and image_width are lists
|
||||
block_state.img_shapes = [
|
||||
[
|
||||
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
|
||||
*[
|
||||
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
|
||||
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
|
||||
(1, img_height // vae_scale_factor // 2, img_width // vae_scale_factor // 2)
|
||||
for img_height, img_width in zip(block_state.image_height, block_state.image_width)
|
||||
],
|
||||
]
|
||||
] * block_state.batch_size
|
||||
@@ -670,6 +873,87 @@ class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step that prepares the RoPE inputs for the denoising process. Should be place after prepare_latents step"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="layers", required=True),
|
||||
InputParam(name="height", required=True),
|
||||
InputParam(name="width", required=True),
|
||||
InputParam(name="prompt_embeds_mask"),
|
||||
InputParam(name="negative_prompt_embeds_mask"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="img_shapes",
|
||||
type_hint=List[List[Tuple[int, int, int]]],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The shapes of the image latents, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="txt_seq_lens",
|
||||
type_hint=List[int],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The sequence lengths of the prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="negative_txt_seq_lens",
|
||||
type_hint=List[int],
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The sequence lengths of the negative prompt embeds, used for RoPE calculation",
|
||||
),
|
||||
OutputParam(
|
||||
name="additional_t_cond",
|
||||
type_hint=torch.Tensor,
|
||||
kwargs_type="denoiser_input_fields",
|
||||
description="The additional t cond, used for RoPE calculation",
|
||||
),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
device = components._execution_device
|
||||
|
||||
# All shapes are the same for Layered
|
||||
shape = (
|
||||
1,
|
||||
block_state.height // components.vae_scale_factor // 2,
|
||||
block_state.width // components.vae_scale_factor // 2,
|
||||
)
|
||||
|
||||
# layers+1 output shapes + 1 condition shape (all same)
|
||||
block_state.img_shapes = [[shape] * (block_state.layers + 2)] * block_state.batch_size
|
||||
|
||||
# txt_seq_lens
|
||||
block_state.txt_seq_lens = (
|
||||
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
|
||||
)
|
||||
block_state.negative_txt_seq_lens = (
|
||||
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
|
||||
if block_state.negative_prompt_embeds_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
block_state.additional_t_cond = torch.tensor([0] * block_state.batch_size).to(device=device, dtype=torch.long)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
## ControlNet inputs for denoiser
|
||||
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -24,12 +24,13 @@ from ...models import AutoencoderKLQwenImage
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# after denoising loop (unpack latents)
|
||||
class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -71,6 +72,46 @@ class QwenImageAfterDenoiseStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredAfterDenoiseStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W) after denoising."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("height", required=True, type_hint=int),
|
||||
InputParam("width", required=True, type_hint=int),
|
||||
InputParam("layers", required=True, type_hint=int),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Unpack: (B, seq, C*4) -> (B, C, layers+1, H, W)
|
||||
block_state.latents = components.pachifier.unpack_latents(
|
||||
block_state.latents,
|
||||
block_state.height,
|
||||
block_state.width,
|
||||
block_state.layers,
|
||||
components.vae_scale_factor,
|
||||
)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
# decode step
|
||||
class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -135,6 +176,81 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Decode unpacked latents (B, C, layers+1, H, W) into layer images."
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKLQwenImage),
|
||||
ComponentSpec(
|
||||
"image_processor",
|
||||
VaeImageProcessor,
|
||||
config=FrozenDict({"vae_scale_factor": 16}),
|
||||
default_creation_method="from_config",
|
||||
),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("output_type", default="pil", type_hint=str),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="images", type_hint=List[List[PIL.Image.Image]]),
|
||||
]
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
latents = block_state.latents
|
||||
|
||||
# 1. VAE normalization
|
||||
latents = latents.to(components.vae.dtype)
|
||||
latents_mean = (
|
||||
torch.tensor(components.vae.config.latents_mean)
|
||||
.view(1, components.vae.config.z_dim, 1, 1, 1)
|
||||
.to(latents.device, latents.dtype)
|
||||
)
|
||||
latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
|
||||
1, components.vae.config.z_dim, 1, 1, 1
|
||||
).to(latents.device, latents.dtype)
|
||||
latents = latents / latents_std + latents_mean
|
||||
|
||||
# 2. Reshape for batch decoding: (B, C, layers+1, H, W) -> (B*layers, C, 1, H, W)
|
||||
b, c, f, h, w = latents.shape
|
||||
# 3. Remove first frame (composite), keep layers frames
|
||||
latents = latents[:, :, 1:]
|
||||
latents = latents.permute(0, 2, 1, 3, 4).reshape(-1, c, 1, h, w)
|
||||
|
||||
# 4. Decode: (B*layers, C, 1, H, W) -> (B*layers, C, H, W)
|
||||
image = components.vae.decode(latents, return_dict=False)[0]
|
||||
image = image.squeeze(2)
|
||||
|
||||
# 5. Postprocess - returns flat list of B*layers images
|
||||
image = components.image_processor.postprocess(image, output_type=block_state.output_type)
|
||||
|
||||
# 6. Chunk into list per batch item
|
||||
images = []
|
||||
for bidx in range(b):
|
||||
images.append(image[bidx * f : (bidx + 1) * f])
|
||||
|
||||
block_state.images = images
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
# postprocess the decoded images
|
||||
class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
@@ -28,7 +29,12 @@ from .modular_pipeline import QwenImageModularPipeline
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
# ====================
|
||||
# 1. LOOP STEPS (run at each denoising step)
|
||||
# ====================
|
||||
|
||||
|
||||
# loop step:before denoiser
|
||||
class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -60,7 +66,7 @@ class QwenImageLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
|
||||
|
||||
class QwenImageEditLoopBeforeDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
model_name = "qwenimage-edit"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -185,6 +191,7 @@ class QwenImageLoopBeforeDenoiserControlNet(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# loop step:denoiser
|
||||
class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -253,6 +260,13 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
}
|
||||
|
||||
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
|
||||
additional_cond_kwargs = {}
|
||||
for field_name, field_value in block_state.denoiser_input_fields.items():
|
||||
if field_name in transformer_args and field_name not in guider_inputs:
|
||||
additional_cond_kwargs[field_name] = field_value
|
||||
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
@@ -264,7 +278,6 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
@@ -284,7 +297,7 @@ class QwenImageLoopDenoiser(ModularPipelineBlocks):
|
||||
|
||||
|
||||
class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
model_name = "qwenimage-edit"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
@@ -351,6 +364,13 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
),
|
||||
}
|
||||
|
||||
transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
|
||||
additional_cond_kwargs = {}
|
||||
for field_name, field_value in block_state.denoiser_input_fields.items():
|
||||
if field_name in transformer_args and field_name not in guider_inputs:
|
||||
additional_cond_kwargs[field_name] = field_value
|
||||
block_state.additional_cond_kwargs.update(additional_cond_kwargs)
|
||||
|
||||
components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
|
||||
guider_state = components.guider.prepare_inputs(guider_inputs)
|
||||
|
||||
@@ -362,7 +382,6 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
guider_state_batch.noise_pred = components.transformer(
|
||||
hidden_states=block_state.latent_model_input,
|
||||
timestep=block_state.timestep / 1000,
|
||||
img_shapes=block_state.img_shapes,
|
||||
attention_kwargs=block_state.attention_kwargs,
|
||||
return_dict=False,
|
||||
**cond_kwargs,
|
||||
@@ -384,6 +403,7 @@ class QwenImageEditLoopDenoiser(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# loop step:after denoiser
|
||||
class QwenImageLoopAfterDenoiser(ModularPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -481,6 +501,9 @@ class QwenImageLoopAfterDenoiserInpaint(ModularPipelineBlocks):
|
||||
return components, block_state
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE LOOP WRAPPER: define the denoising loop logic
|
||||
# ====================
|
||||
class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
@@ -537,8 +560,15 @@ class QwenImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
# composing the denoising loops
|
||||
# ====================
|
||||
# 3. DENOISE STEPS: compose the denoising loop with loop wrapper + loop steps
|
||||
# ====================
|
||||
|
||||
|
||||
# Qwen Image (text2image, image2image)
|
||||
class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopDenoiser,
|
||||
@@ -559,8 +589,9 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the inpainting denoising loops
|
||||
# Qwen Image (inpainting)
|
||||
class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopDenoiser,
|
||||
@@ -583,8 +614,9 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the controlnet denoising loops
|
||||
# Qwen Image (text2image, image2image) with controlnet
|
||||
class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
@@ -607,8 +639,9 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the controlnet denoising loops
|
||||
# Qwen Image (inpainting) with controlnet
|
||||
class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageLoopBeforeDenoiser,
|
||||
QwenImageLoopBeforeDenoiserControlNet,
|
||||
@@ -639,8 +672,9 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# composing the denoising loops
|
||||
# Qwen Image Edit (image2image)
|
||||
class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
@@ -661,7 +695,9 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Edit (inpainting)
|
||||
class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
@@ -682,3 +718,26 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
" - `QwenImageLoopAfterDenoiserInpaint`\n"
|
||||
"This block supports inpainting tasks for QwenImage Edit."
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Layered (image2image)
|
||||
class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageEditLoopBeforeDenoiser,
|
||||
QwenImageEditLoopDenoiser,
|
||||
QwenImageLoopAfterDenoiser,
|
||||
]
|
||||
block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. \n"
|
||||
"Its loop logic is defined in `QwenImageDenoiseLoopWrapper.__call__` method \n"
|
||||
"At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
|
||||
" - `QwenImageEditLoopBeforeDenoiser`\n"
|
||||
" - `QwenImageEditLoopDenoiser`\n"
|
||||
" - `QwenImageLoopAfterDenoiser`\n"
|
||||
"This block supports QwenImage Layered."
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@ import torch
|
||||
from ...models import QwenImageMultiControlNetModel
|
||||
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
|
||||
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
|
||||
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
||||
from .modular_pipeline import QwenImageLayeredPachifier, QwenImageModularPipeline, QwenImagePachifier
|
||||
|
||||
|
||||
def repeat_tensor_to_batch_size(
|
||||
@@ -221,37 +221,16 @@ class QwenImageTextInputsStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
class QwenImageAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage: update height/width, expand batch, patchify."""
|
||||
|
||||
model_name = "qwenimage"
|
||||
|
||||
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
|
||||
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
|
||||
|
||||
This step handles multiple common tasks to prepare inputs for the denoising step:
|
||||
1. For encoded image latents, use it update height/width if None, patchifies, and expands batch size
|
||||
2. For additional_batch_inputs: Only expands batch dimensions to match final batch size
|
||||
|
||||
This is a dynamic block that allows you to configure which inputs to process.
|
||||
|
||||
Args:
|
||||
image_latent_inputs (List[str], optional): Names of image latent tensors to process.
|
||||
These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
|
||||
list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
|
||||
additional_batch_inputs (List[str], optional):
|
||||
Names of additional conditional input tensors to expand batch size. These tensors will only have their
|
||||
batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
|
||||
Defaults to []. Examples: ["processed_mask_image"]
|
||||
|
||||
Examples:
|
||||
# Configure to process image_latents (default behavior) QwenImageInputsDynamicStep()
|
||||
|
||||
# Configure to process multiple image latent inputs
|
||||
QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])
|
||||
|
||||
# Configure to process image latents and additional batch inputs QwenImageInputsDynamicStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
)
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
@@ -263,14 +242,12 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
# Functionality section
|
||||
summary_section = (
|
||||
"Input processing step that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
# Inputs info
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
@@ -279,11 +256,16 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
# Placement guidance
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
@@ -293,11 +275,9 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
# Add image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
# Add additional batch inputs
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
@@ -306,26 +286,28 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=int, description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=int, description="The width of the image latents"),
|
||||
]
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
OutputParam(
|
||||
name="image_height",
|
||||
type_hint=int,
|
||||
description="The image height calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(
|
||||
name="image_width",
|
||||
type_hint=int,
|
||||
description="The image width calculated from the image latents dimension",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
# Process image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents
|
||||
# 1. Calculate height/width from latents and update if not provided
|
||||
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
|
||||
block_state.height = block_state.height or height
|
||||
block_state.width = block_state.width or width
|
||||
@@ -335,7 +317,7 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
if not hasattr(block_state, "image_width"):
|
||||
block_state.image_width = width
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
# 2. Patchify
|
||||
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
@@ -354,7 +336,6 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
@@ -368,63 +349,270 @@ class QwenImageInputsDynamicStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
|
||||
class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage Edit Plus: handles list of latents with different sizes."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Input processing step for Edit Plus that:\n"
|
||||
" 1. For image latent inputs (list): Collects heights/widths, patchifies each, concatenates, expands batch\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size\n"
|
||||
" Height/width defaults to last image in the list."
|
||||
)
|
||||
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
InputParam(name="height"),
|
||||
InputParam(name="width"),
|
||||
]
|
||||
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
|
||||
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
|
||||
OutputParam(
|
||||
name="image_height",
|
||||
type_hint=List[int],
|
||||
description="The image heights calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(
|
||||
name="image_width",
|
||||
type_hint=List[int],
|
||||
description="The image widths calculated from the image latents dimension",
|
||||
),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs (height/width calculation, patchify, and batch expansion)
|
||||
# Process image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# Each image latent can have different size in QwenImage Edit Plus.
|
||||
is_list = isinstance(image_latent_tensor, list)
|
||||
if not is_list:
|
||||
image_latent_tensor = [image_latent_tensor]
|
||||
|
||||
image_heights = []
|
||||
image_widths = []
|
||||
packed_image_latent_tensors = []
|
||||
|
||||
for img_latent_tensor in image_latent_tensor:
|
||||
for i, img_latent_tensor in enumerate(image_latent_tensor):
|
||||
# 1. Calculate height/width from latents
|
||||
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
|
||||
image_heights.append(height)
|
||||
image_widths.append(width)
|
||||
|
||||
# 2. Patchify the image latent tensor
|
||||
# 2. Patchify
|
||||
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
img_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_name=f"{image_latent_input_name}[{i}]",
|
||||
input_tensor=img_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
packed_image_latent_tensors.append(img_latent_tensor)
|
||||
|
||||
# Concatenate all packed latents along dim=1
|
||||
packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
|
||||
|
||||
# Output lists of heights/widths
|
||||
block_state.image_height = image_heights
|
||||
block_state.image_width = image_widths
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
# Default height/width from last image
|
||||
block_state.height = block_state.height or image_heights[-1]
|
||||
block_state.width = block_state.width or image_widths[-1]
|
||||
|
||||
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, input_name, input_tensor)
|
||||
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
|
||||
# YiYi TODO: support define config default component from the ModularPipeline level.
|
||||
# it is same as QwenImageAdditionalInputsStep, but with layered pachifier.
|
||||
class QwenImageLayeredAdditionalInputsStep(ModularPipelineBlocks):
|
||||
"""Input step for QwenImage Layered: update height/width, expand batch, patchify with layered pachifier."""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_latent_inputs: List[str] = ["image_latents"],
|
||||
additional_batch_inputs: List[str] = [],
|
||||
):
|
||||
if not isinstance(image_latent_inputs, list):
|
||||
image_latent_inputs = [image_latent_inputs]
|
||||
if not isinstance(additional_batch_inputs, list):
|
||||
additional_batch_inputs = [additional_batch_inputs]
|
||||
|
||||
self._image_latent_inputs = image_latent_inputs
|
||||
self._additional_batch_inputs = additional_batch_inputs
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
summary_section = (
|
||||
"Input processing step for Layered that:\n"
|
||||
" 1. For image latent inputs: Updates height/width if None, patchifies with layered pachifier, and expands batch size\n"
|
||||
" 2. For additional batch inputs: Expands batch dimensions to match final batch size"
|
||||
)
|
||||
|
||||
inputs_info = ""
|
||||
if self._image_latent_inputs or self._additional_batch_inputs:
|
||||
inputs_info = "\n\nConfigured inputs:"
|
||||
if self._image_latent_inputs:
|
||||
inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
|
||||
if self._additional_batch_inputs:
|
||||
inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"
|
||||
|
||||
placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."
|
||||
|
||||
return summary_section + inputs_info + placement_section
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("pachifier", QwenImageLayeredPachifier, default_creation_method="from_config"),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[InputParam]:
|
||||
inputs = [
|
||||
InputParam(name="num_images_per_prompt", default=1),
|
||||
InputParam(name="batch_size", required=True),
|
||||
]
|
||||
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
inputs.append(InputParam(name=image_latent_input_name))
|
||||
|
||||
for input_name in self._additional_batch_inputs:
|
||||
inputs.append(InputParam(name=input_name))
|
||||
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
OutputParam(
|
||||
name="image_height",
|
||||
type_hint=int,
|
||||
description="The image height calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(
|
||||
name="image_width",
|
||||
type_hint=int,
|
||||
description="The image width calculated from the image latents dimension",
|
||||
),
|
||||
OutputParam(name="height", type_hint=int, description="The height of the image output"),
|
||||
OutputParam(name="width", type_hint=int, description="The width of the image output"),
|
||||
]
|
||||
|
||||
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Process image latent inputs
|
||||
for image_latent_input_name in self._image_latent_inputs:
|
||||
image_latent_tensor = getattr(block_state, image_latent_input_name)
|
||||
if image_latent_tensor is None:
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents and update if not provided
|
||||
# Layered latents are (B, layers, C, H, W)
|
||||
height = image_latent_tensor.shape[3] * components.vae_scale_factor
|
||||
width = image_latent_tensor.shape[4] * components.vae_scale_factor
|
||||
block_state.height = height
|
||||
block_state.width = width
|
||||
|
||||
if not hasattr(block_state, "image_height"):
|
||||
block_state.image_height = height
|
||||
if not hasattr(block_state, "image_width"):
|
||||
block_state.image_width = width
|
||||
|
||||
# 2. Patchify with layered pachifier
|
||||
image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
|
||||
|
||||
# 3. Expand batch size
|
||||
image_latent_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=image_latent_input_name,
|
||||
input_tensor=image_latent_tensor,
|
||||
num_images_per_prompt=block_state.num_images_per_prompt,
|
||||
batch_size=block_state.batch_size,
|
||||
)
|
||||
|
||||
setattr(block_state, image_latent_input_name, image_latent_tensor)
|
||||
|
||||
# Process additional batch inputs (only batch expansion)
|
||||
for input_name in self._additional_batch_inputs:
|
||||
input_tensor = getattr(block_state, input_name)
|
||||
if input_tensor is None:
|
||||
continue
|
||||
|
||||
# Only expand batch size
|
||||
input_tensor = repeat_tensor_to_batch_size(
|
||||
input_name=input_name,
|
||||
input_tensor=input_tensor,
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,469 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageControlNetBeforeDenoiserStep,
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageRoPEInputsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageControlNetDenoiseStep,
|
||||
QwenImageDenoiseStep,
|
||||
QwenImageInpaintControlNetDenoiseStep,
|
||||
QwenImageInpaintDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageControlNetVaeEncoderStep,
|
||||
QwenImageInpaintProcessImagesInputStep,
|
||||
QwenImageProcessImagesInputStep,
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageAdditionalInputsStep,
|
||||
QwenImageControlNetInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageInpaintProcessImagesInputStep(), QwenImageVaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for inpainting tasks. It:\n"
|
||||
" - Resizes the image to the target size, based on `height` and `width`.\n"
|
||||
" - Processes and updates `image` and `mask_image`.\n"
|
||||
" - Creates `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = [QwenImageProcessImagesInputStep(), QwenImageVaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# Auto VAE encoder
|
||||
class QwenImageAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintVaeEncoderStep, QwenImageImg2ImgVaeEncoderStep]
|
||||
block_names = ["inpaint", "img2img"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageInpaintVaeEncoderStep` (inpaint) is used when `mask_image` is provided.\n"
|
||||
+ " - `QwenImageImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.\n"
|
||||
+ " - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# optional controlnet vae encoder
|
||||
class QwenImageOptionalControlNetVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageControlNetVaeEncoderStep]
|
||||
block_names = ["controlnet"]
|
||||
block_trigger_inputs = ["control_image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block.\n"
|
||||
+ " - `QwenImageControlNetVaeEncoderStep` (controlnet) is used when `control_image` is provided.\n"
|
||||
+ " - if `control_image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageTextInputsStep(), QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"])]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
class QwenImageInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the inpainting denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents` and `processed_mask_image`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
# assemble prepare latents steps
|
||||
class QwenImageInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the pachified latents `mask` based on the processedmask image.\n"
|
||||
)
|
||||
|
||||
|
||||
# assemble denoising steps
|
||||
|
||||
|
||||
# Qwen Image (text2image)
|
||||
class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
|
||||
# Qwen Image (inpainting)
|
||||
class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageInpaintPrepareLatentsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageInpaintDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# Qwen Image (image2image)
|
||||
class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImagePrepareLatentsWithStrengthStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_img2img_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
# Qwen Image (text2image) with controlnet
|
||||
class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "step that denoise noise into image for text2image task. It includes the denoise loop, as well as prepare the inputs (timesteps, latents, rope inputs etc.)."
|
||||
|
||||
|
||||
# Qwen Image (inpainting) with controlnet
|
||||
class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageInpaintInputStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageInpaintPrepareLatentsStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageInpaintControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for inpaint task."
|
||||
|
||||
|
||||
# Qwen Image (image2image) with controlnet
|
||||
class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [
|
||||
QwenImageImg2ImgInputStep(),
|
||||
QwenImageControlNetInputsStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImagePrepareLatentsWithStrengthStep(),
|
||||
QwenImageRoPEInputsStep(),
|
||||
QwenImageControlNetBeforeDenoiserStep(),
|
||||
QwenImageControlNetDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"controlnet_input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_img2img_latents",
|
||||
"prepare_rope_inputs",
|
||||
"controlnet_before_denoise",
|
||||
"controlnet_denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for img2img task."
|
||||
|
||||
|
||||
# Auto denoise step for QwenImage
|
||||
class QwenImageAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
block_classes = [
|
||||
QwenImageCoreDenoiseStep,
|
||||
QwenImageInpaintCoreDenoiseStep,
|
||||
QwenImageImg2ImgCoreDenoiseStep,
|
||||
QwenImageControlNetCoreDenoiseStep,
|
||||
QwenImageControlNetInpaintCoreDenoiseStep,
|
||||
QwenImageControlNetImg2ImgCoreDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"text2image",
|
||||
"inpaint",
|
||||
"img2img",
|
||||
"controlnet_text2image",
|
||||
"controlnet_inpaint",
|
||||
"controlnet_img2img",
|
||||
]
|
||||
block_trigger_inputs = ["control_image_latents", "processed_mask_image", "image_latents"]
|
||||
default_block_name = "text2image"
|
||||
|
||||
def select_block(self, control_image_latents=None, processed_mask_image=None, image_latents=None):
|
||||
if control_image_latents is not None:
|
||||
if processed_mask_image is not None:
|
||||
return "controlnet_inpaint"
|
||||
elif image_latents is not None:
|
||||
return "controlnet_img2img"
|
||||
else:
|
||||
return "controlnet_text2image"
|
||||
else:
|
||||
if processed_mask_image is not None:
|
||||
return "inpaint"
|
||||
elif image_latents is not None:
|
||||
return "img2img"
|
||||
else:
|
||||
return "text2image"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `QwenImageCoreDenoiseStep` (text2image) for text2image tasks.\n"
|
||||
+ " - `QwenImageInpaintCoreDenoiseStep` (inpaint) for inpaint tasks.\n"
|
||||
+ " - `QwenImageImg2ImgCoreDenoiseStep` (img2img) for img2img tasks.\n"
|
||||
+ " - `QwenImageControlNetCoreDenoiseStep` (controlnet_text2image) for text2image tasks with controlnet.\n"
|
||||
+ " - `QwenImageControlNetInpaintCoreDenoiseStep` (controlnet_inpaint) for inpaint tasks with controlnet.\n"
|
||||
+ " - `QwenImageControlNetImg2ImgCoreDenoiseStep` (controlnet_img2img) for img2img tasks with controlnet.\n"
|
||||
+ "This step support text-to-image, image-to-image, inpainting, and controlnet tasks for QwenImage:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for inpainting, you need to provide `processed_mask_image` and `image_latents`\n"
|
||||
+ " - to run the controlnet workflow, you need to provide `control_image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
# standard decode step works for most tasks except for inpaint
|
||||
class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
# Inpaint decode step
|
||||
class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optional apply the mask overally to the original image."
|
||||
|
||||
|
||||
# Auto decode step for QwenImage
|
||||
class QwenImageAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageInpaintDecodeStep, QwenImageDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images. \n"
|
||||
" This is an auto pipeline block that works for inpaint/text2image/img2img tasks, for both QwenImage and QwenImage-Edit.\n"
|
||||
+ " - `QwenImageInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
+ " - `QwenImageDecodeStep` (text2image/img2img) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageTextEncoderStep()),
|
||||
("vae_encoder", QwenImageAutoVaeEncoderStep()),
|
||||
("controlnet_vae_encoder", QwenImageOptionalControlNetVaeEncoderStep()),
|
||||
("denoise", QwenImageAutoCoreDenoiseStep()),
|
||||
("decode", QwenImageAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
|
||||
+ "- for image-to-image generation, you need to provide `image`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
|
||||
@@ -0,0 +1,336 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, ConditionalPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageCreateMaskLatentsStep,
|
||||
QwenImageEditRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImagePrepareLatentsWithStrengthStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
QwenImageSetTimestepsWithStrengthStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageInpaintProcessImagesOutputStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageEditDenoiseStep,
|
||||
QwenImageEditInpaintDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditInpaintProcessImagesInputStep,
|
||||
QwenImageEditProcessImagesInputStep,
|
||||
QwenImageEditResizeStep,
|
||||
QwenImageEditTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageAdditionalInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""VL encoder that takes both image and text prompts."""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
QwenImageEditTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit VL encoder step that encode the image and text prompts together."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# Edit VAE encoder
|
||||
class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
QwenImageEditProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# Edit Inpaint VAE encoder
|
||||
class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditResizeStep(),
|
||||
QwenImageEditInpaintProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(input_name="processed_image", output_name="image_latents"),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step is used for processing image and mask inputs for QwenImage-Edit inpaint tasks. It:\n"
|
||||
" - resize the image for target area (1024 * 1024) while maintaining the aspect ratio.\n"
|
||||
" - process the resized image and mask image.\n"
|
||||
" - create image latents."
|
||||
)
|
||||
|
||||
|
||||
# Auto VAE encoder
|
||||
class QwenImageEditAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditInpaintVaeEncoderStep, QwenImageEditVaeEncoderStep]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["mask_image", "image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
"This is an auto pipeline block.\n"
|
||||
" - `QwenImageEditInpaintVaeEncoderStep` (edit_inpaint) is used when `mask_image` is provided.\n"
|
||||
" - `QwenImageEditVaeEncoderStep` (edit) is used when `image` is provided.\n"
|
||||
" - if `mask_image` or `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageEditInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditInpaintInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageAdditionalInputsStep(
|
||||
image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
|
||||
),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the edit inpaint denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# assemble prepare latents steps
|
||||
class QwenImageEditInpaintPrepareLatentsStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImagePrepareLatentsWithStrengthStep(), QwenImageCreateMaskLatentsStep()]
|
||||
block_names = ["add_noise_to_latents", "create_mask_latents"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"This step prepares the latents/image_latents and mask inputs for the edit inpainting denoising step. It:\n"
|
||||
" - Add noise to the image latents to create the latents input for the denoiser.\n"
|
||||
" - Create the patchified latents `mask` based on the processed mask image.\n"
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Edit (image2image) core denoise step
|
||||
class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageEditRoPEInputsStep(),
|
||||
QwenImageEditDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit (img2img) task."
|
||||
|
||||
|
||||
# Qwen Image Edit (inpainting) core denoise step
|
||||
class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsWithStrengthStep(),
|
||||
QwenImageEditInpaintPrepareLatentsStep(),
|
||||
QwenImageEditRoPEInputsStep(),
|
||||
QwenImageEditInpaintDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_inpaint_latents",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit edit inpaint task."
|
||||
|
||||
|
||||
# Auto core denoise step for QwenImage Edit
|
||||
class QwenImageEditAutoCoreDenoiseStep(ConditionalPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [
|
||||
QwenImageEditInpaintCoreDenoiseStep,
|
||||
QwenImageEditCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["edit_inpaint", "edit"]
|
||||
block_trigger_inputs = ["processed_mask_image", "image_latents"]
|
||||
default_block_name = "edit"
|
||||
|
||||
def select_block(self, processed_mask_image=None, image_latents=None) -> Optional[str]:
|
||||
if processed_mask_image is not None:
|
||||
return "edit_inpaint"
|
||||
elif image_latents is not None:
|
||||
return "edit"
|
||||
return None
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto core denoising step that selects the appropriate workflow based on inputs.\n"
|
||||
" - `QwenImageEditInpaintCoreDenoiseStep` when `processed_mask_image` is provided\n"
|
||||
" - `QwenImageEditCoreDenoiseStep` when `image_latents` is provided\n"
|
||||
"Supports edit (img2img) and edit inpainting tasks for QwenImage-Edit."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
# Decode step (standard)
|
||||
class QwenImageEditDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image."
|
||||
|
||||
|
||||
# Inpaint decode step
|
||||
class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageInpaintProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocess the generated image, optionally apply the mask overlay to the original image."
|
||||
|
||||
|
||||
# Auto decode step
|
||||
class QwenImageEditAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [QwenImageEditInpaintDecodeStep, QwenImageEditDecodeStep]
|
||||
block_names = ["inpaint_decode", "decode"]
|
||||
block_trigger_inputs = ["mask", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Decode step that decode the latents into images.\n"
|
||||
"This is an auto pipeline block.\n"
|
||||
" - `QwenImageEditInpaintDecodeStep` (inpaint) is used when `mask` is provided.\n"
|
||||
" - `QwenImageEditDecodeStep` (edit) is used when `mask` is not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
EDIT_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditAutoVaeEncoderStep()),
|
||||
("denoise", QwenImageEditAutoCoreDenoiseStep()),
|
||||
("decode", QwenImageEditAutoDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = EDIT_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) and edit inpaint tasks using QwenImage-Edit.\n"
|
||||
"- for edit (img2img) generation, you need to provide `image`\n"
|
||||
"- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`\n"
|
||||
)
|
||||
@@ -0,0 +1,181 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageEditPlusRoPEInputsStep,
|
||||
QwenImagePrepareLatentsStep,
|
||||
QwenImageSetTimestepsStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageAfterDenoiseStep,
|
||||
QwenImageDecoderStep,
|
||||
QwenImageProcessImagesOutputStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageEditDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditPlusProcessImagesInputStep,
|
||||
QwenImageEditPlusResizeStep,
|
||||
QwenImageEditPlusTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageEditPlusAdditionalInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
|
||||
"""VL encoder that takes both image and text prompts. Uses 384x384 target area."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeStep(target_area=384 * 384, output_name="resized_cond_image"),
|
||||
QwenImageEditPlusTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Edit Plus VL encoder step that encodes the image and text prompts together."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""VAE encoder that handles multiple images with different sizes. Uses 1024x1024 target area."""
|
||||
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusResizeStep(target_area=1024 * 1024, output_name="resized_image"),
|
||||
QwenImageEditPlusProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"VAE encoder step that encodes image inputs into latent representations.\n"
|
||||
"Each image is resized independently based on its own aspect ratio to 1024x1024 target area."
|
||||
)
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageEditPlusAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the Edit Plus denoising step. It:\n"
|
||||
" - Standardizes text embeddings batch size.\n"
|
||||
" - Processes list of image latents: patchifies, concatenates along dim=1, expands batch.\n"
|
||||
" - Outputs lists of image_height/image_width for RoPE calculation.\n"
|
||||
" - Defaults height/width from last image in the list."
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Edit Plus (image2image) core denoise step
|
||||
class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [
|
||||
QwenImageEditPlusInputStep(),
|
||||
QwenImagePrepareLatentsStep(),
|
||||
QwenImageSetTimestepsStep(),
|
||||
QwenImageEditPlusRoPEInputsStep(),
|
||||
QwenImageEditDenoiseStep(),
|
||||
QwenImageAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Edit Plus edit (img2img) task."
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. DECODE
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = [QwenImageDecoderStep(), QwenImageProcessImagesOutputStep()]
|
||||
block_names = ["decode", "postprocess"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decodes the latents to images and postprocesses the generated image."
|
||||
|
||||
|
||||
# ====================
|
||||
# 5. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
EDIT_PLUS_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageEditPlusVLEncoderStep()),
|
||||
("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
|
||||
("denoise", QwenImageEditPlusCoreDenoiseStep()),
|
||||
("decode", QwenImageEditPlusDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-edit-plus"
|
||||
block_classes = EDIT_PLUS_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_PLUS_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for edit (img2img) tasks using QwenImage-Edit Plus.\n"
|
||||
"- `image` is required input (can be single image or list of images).\n"
|
||||
"- Each image is resized independently based on its own aspect ratio.\n"
|
||||
"- VL encoder uses 384x384 target area, VAE encoder uses 1024x1024 target area."
|
||||
)
|
||||
@@ -0,0 +1,159 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
QwenImageLayeredPrepareLatentsStep,
|
||||
QwenImageLayeredRoPEInputsStep,
|
||||
QwenImageLayeredSetTimestepsStep,
|
||||
)
|
||||
from .decoders import (
|
||||
QwenImageLayeredAfterDenoiseStep,
|
||||
QwenImageLayeredDecoderStep,
|
||||
)
|
||||
from .denoise import (
|
||||
QwenImageLayeredDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
QwenImageEditProcessImagesInputStep,
|
||||
QwenImageLayeredGetImagePromptStep,
|
||||
QwenImageLayeredPermuteLatentsStep,
|
||||
QwenImageLayeredResizeStep,
|
||||
QwenImageTextEncoderStep,
|
||||
QwenImageVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
QwenImageLayeredAdditionalInputsStep,
|
||||
QwenImageTextInputsStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. TEXT ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
|
||||
"""Text encoder that takes text prompt, will generate a prompt based on image if not provided."""
|
||||
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredResizeStep(),
|
||||
QwenImageLayeredGetImagePromptStep(),
|
||||
QwenImageTextEncoderStep(),
|
||||
]
|
||||
block_names = ["resize", "get_image_prompt", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "QwenImage-Layered Text encoder step that encode the text prompt, will generate a prompt based on image if not provided."
|
||||
|
||||
|
||||
# ====================
|
||||
# 2. VAE ENCODER
|
||||
# ====================
|
||||
|
||||
|
||||
# Edit VAE encoder
|
||||
class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredResizeStep(),
|
||||
QwenImageEditProcessImagesInputStep(),
|
||||
QwenImageVaeEncoderStep(),
|
||||
QwenImageLayeredPermuteLatentsStep(),
|
||||
]
|
||||
block_names = ["resize", "preprocess", "encode", "permute"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that encode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# ====================
|
||||
# 3. DENOISE (input -> prepare_latents -> set_timesteps -> prepare_rope_inputs -> denoise -> after_denoise)
|
||||
# ====================
|
||||
|
||||
|
||||
# assemble input steps
|
||||
class QwenImageLayeredInputStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageTextInputsStep(),
|
||||
QwenImageLayeredAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the layered denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs.\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# Qwen Image Layered (image2image) core denoise step
|
||||
class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = [
|
||||
QwenImageLayeredInputStep(),
|
||||
QwenImageLayeredPrepareLatentsStep(),
|
||||
QwenImageLayeredSetTimestepsStep(),
|
||||
QwenImageLayeredRoPEInputsStep(),
|
||||
QwenImageLayeredDenoiseStep(),
|
||||
QwenImageLayeredAfterDenoiseStep(),
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"prepare_rope_inputs",
|
||||
"denoise",
|
||||
"after_denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoising workflow for QwenImage-Layered img2img task."
|
||||
|
||||
|
||||
# ====================
|
||||
# 4. AUTO BLOCKS & PRESETS
|
||||
# ====================
|
||||
|
||||
LAYERED_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", QwenImageLayeredTextEncoderStep()),
|
||||
("vae_encoder", QwenImageLayeredVaeEncoderStep()),
|
||||
("denoise", QwenImageLayeredCoreDenoiseStep()),
|
||||
("decode", QwenImageLayeredDecoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "qwenimage-layered"
|
||||
block_classes = LAYERED_AUTO_BLOCKS.values()
|
||||
block_names = LAYERED_AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto Modular pipeline for layered denoising tasks using QwenImage-Layered."
|
||||
@@ -90,6 +90,88 @@ class QwenImagePachifier(ConfigMixin):
|
||||
return latents
|
||||
|
||||
|
||||
class QwenImageLayeredPachifier(ConfigMixin):
|
||||
"""
|
||||
A class to pack and unpack latents for QwenImage Layered.
|
||||
|
||||
Unlike QwenImagePachifier, this handles 5D latents with shape (B, layers+1, C, H, W).
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, patch_size: int = 2):
|
||||
super().__init__()
|
||||
|
||||
def pack_latents(self, latents):
|
||||
"""
|
||||
Pack latents from (B, layers, C, H, W) to (B, layers * H/2 * W/2, C*4).
|
||||
"""
|
||||
|
||||
if latents.ndim != 5:
|
||||
raise ValueError(f"Latents must have 5 dimensions (B, layers, C, H, W), but got {latents.ndim}")
|
||||
|
||||
batch_size, layers, num_channels_latents, latent_height, latent_width = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
if latent_height % patch_size != 0 or latent_width % patch_size != 0:
|
||||
raise ValueError(
|
||||
f"Latent height and width must be divisible by {patch_size}, but got {latent_height} and {latent_width}"
|
||||
)
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
layers,
|
||||
num_channels_latents,
|
||||
latent_height // patch_size,
|
||||
patch_size,
|
||||
latent_width // patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 1, 3, 5, 2, 4, 6)
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
layers * (latent_height // patch_size) * (latent_width // patch_size),
|
||||
num_channels_latents * patch_size * patch_size,
|
||||
)
|
||||
return latents
|
||||
|
||||
def unpack_latents(self, latents, height, width, layers, vae_scale_factor=8):
|
||||
"""
|
||||
Unpack latents from (B, seq, C*4) to (B, C, layers+1, H, W).
|
||||
"""
|
||||
|
||||
if latents.ndim != 3:
|
||||
raise ValueError(f"Latents must have 3 dimensions, but got {latents.ndim}")
|
||||
|
||||
batch_size, _, channels = latents.shape
|
||||
patch_size = self.config.patch_size
|
||||
|
||||
height = patch_size * (int(height) // (vae_scale_factor * patch_size))
|
||||
width = patch_size * (int(width) // (vae_scale_factor * patch_size))
|
||||
|
||||
latents = latents.view(
|
||||
batch_size,
|
||||
layers + 1,
|
||||
height // patch_size,
|
||||
width // patch_size,
|
||||
channels // (patch_size * patch_size),
|
||||
patch_size,
|
||||
patch_size,
|
||||
)
|
||||
latents = latents.permute(0, 1, 4, 2, 5, 3, 6)
|
||||
latents = latents.reshape(
|
||||
batch_size,
|
||||
layers + 1,
|
||||
channels // (patch_size * patch_size),
|
||||
height,
|
||||
width,
|
||||
)
|
||||
latents = latents.permute(0, 2, 1, 3, 4) # (b, c, f, h, w)
|
||||
|
||||
return latents
|
||||
|
||||
|
||||
class QwenImageModularPipeline(ModularPipeline, QwenImageLoraLoaderMixin):
|
||||
"""
|
||||
A ModularPipeline for QwenImage.
|
||||
@@ -203,3 +285,13 @@ class QwenImageEditPlusModularPipeline(QwenImageEditModularPipeline):
|
||||
"""
|
||||
|
||||
default_blocks_name = "QwenImageEditPlusAutoBlocks"
|
||||
|
||||
|
||||
class QwenImageLayeredModularPipeline(QwenImageModularPipeline):
|
||||
"""
|
||||
A ModularPipeline for QwenImage-Layered.
|
||||
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
"""
|
||||
|
||||
default_blocks_name = "QwenImageLayeredAutoBlocks"
|
||||
|
||||
121
src/diffusers/modular_pipelines/qwenimage/prompt_templates.py
Normal file
121
src/diffusers/modular_pipelines/qwenimage/prompt_templates.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Prompt templates for QwenImage pipelines.
|
||||
|
||||
This module centralizes all prompt templates used across different QwenImage pipeline variants:
|
||||
- QwenImage (base): Text-only encoding for text-to-image generation
|
||||
- QwenImage Edit: VL encoding with single image for image editing
|
||||
- QwenImage Edit Plus: VL encoding with multiple images for multi-reference editing
|
||||
- QwenImage Layered: Auto-captioning for image decomposition
|
||||
"""
|
||||
|
||||
# ============================================
|
||||
# QwenImage Base (text-only encoding)
|
||||
# ============================================
|
||||
# Used for text-to-image generation where only text prompt is encoded
|
||||
|
||||
QWENIMAGE_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\n"
|
||||
"Describe the image by detailing the color, shape, size, texture, quantity, text, "
|
||||
"spatial relationships of the objects and background:<|im_end|>\n"
|
||||
"<|im_start|>user\n{}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
QWENIMAGE_PROMPT_TEMPLATE_START_IDX = 34
|
||||
|
||||
|
||||
# ============================================
|
||||
# QwenImage Edit (VL encoding with single image)
|
||||
# ============================================
|
||||
# Used for single-image editing where both image and text are encoded together
|
||||
|
||||
QWENIMAGE_EDIT_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\n"
|
||||
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
|
||||
"then explain how the user's text instruction should alter or modify the image. "
|
||||
"Generate a new image that meets the user's requirements while maintaining consistency "
|
||||
"with the original input where appropriate.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
"<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
QWENIMAGE_EDIT_PROMPT_TEMPLATE_START_IDX = 64
|
||||
|
||||
|
||||
# ============================================
|
||||
# QwenImage Edit Plus (VL encoding with multiple images)
|
||||
# ============================================
|
||||
# Used for multi-reference editing where multiple images and text are encoded together
|
||||
# The img_template is used to format each image in the prompt
|
||||
|
||||
QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE = (
|
||||
"<|im_start|>system\n"
|
||||
"Describe the key features of the input image (color, shape, size, texture, objects, background), "
|
||||
"then explain how the user's text instruction should alter or modify the image. "
|
||||
"Generate a new image that meets the user's requirements while maintaining consistency "
|
||||
"with the original input where appropriate.<|im_end|>\n"
|
||||
"<|im_start|>user\n{}<|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
QWENIMAGE_EDIT_PLUS_IMG_TEMPLATE = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
|
||||
QWENIMAGE_EDIT_PLUS_PROMPT_TEMPLATE_START_IDX = 64
|
||||
|
||||
|
||||
# ============================================
|
||||
# QwenImage Layered (auto-captioning)
|
||||
# ============================================
|
||||
# Used for image decomposition where the VL model generates a caption from the input image
|
||||
# if no prompt is provided. These prompts instruct the model to describe the image in detail.
|
||||
|
||||
QWENIMAGE_LAYERED_CAPTION_PROMPT_EN = (
|
||||
"<|im_start|>system\n"
|
||||
"You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
"# Image Annotator\n"
|
||||
"You are a professional image annotator. Please write an image caption based on the input image:\n"
|
||||
"1. Write the caption using natural, descriptive language without structured formats or rich text.\n"
|
||||
"2. Enrich caption details by including:\n"
|
||||
" - Object attributes, such as quantity, color, shape, size, material, state, position, actions, and so on\n"
|
||||
" - Vision Relations between objects, such as spatial relations, functional relations, possessive relations, "
|
||||
"attachment relations, action relations, comparative relations, causal relations, and so on\n"
|
||||
" - Environmental details, such as weather, lighting, colors, textures, atmosphere, and so on\n"
|
||||
" - Identify the text clearly visible in the image, without translation or explanation, "
|
||||
"and highlight it in the caption with quotation marks\n"
|
||||
"3. Maintain authenticity and accuracy:\n"
|
||||
" - Avoid generalizations\n"
|
||||
" - Describe all visible information in the image, while do not add information not explicitly shown in the image\n"
|
||||
"<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
|
||||
QWENIMAGE_LAYERED_CAPTION_PROMPT_CN = (
|
||||
"<|im_start|>system\n"
|
||||
"You are a helpful assistant.<|im_end|>\n"
|
||||
"<|im_start|>user\n"
|
||||
"# 图像标注器\n"
|
||||
"你是一个专业的图像标注器。请基于输入图像,撰写图注:\n"
|
||||
"1. 使用自然、描述性的语言撰写图注,不要使用结构化形式或富文本形式。\n"
|
||||
"2. 通过加入以下内容,丰富图注细节:\n"
|
||||
" - 对象的属性:如数量、颜色、形状、大小、位置、材质、状态、动作等\n"
|
||||
" - 对象间的视觉关系:如空间关系、功能关系、动作关系、从属关系、比较关系、因果关系等\n"
|
||||
" - 环境细节:例如天气、光照、颜色、纹理、气氛等\n"
|
||||
" - 文字内容:识别图像中清晰可见的文字,不做翻译和解释,用引号在图注中强调\n"
|
||||
"3. 保持真实性与准确性:\n"
|
||||
" - 不要使用笼统的描述\n"
|
||||
" - 描述图像中所有可见的信息,但不要加入没有在图像中出现的内容\n"
|
||||
"<|vision_start|><|image_pad|><|vision_end|><|im_end|>\n"
|
||||
"<|im_start|>assistant\n"
|
||||
)
|
||||
@@ -288,7 +288,9 @@ else:
|
||||
"LTXImageToVideoPipeline",
|
||||
"LTXConditionPipeline",
|
||||
"LTXLatentUpsamplePipeline",
|
||||
"LTXI2VLongMultiPromptPipeline",
|
||||
]
|
||||
_import_structure["ltx2"] = ["LTX2Pipeline", "LTX2ImageToVideoPipeline", "LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["lumina"] = ["LuminaPipeline", "LuminaText2ImgPipeline"]
|
||||
_import_structure["lumina2"] = ["Lumina2Pipeline", "Lumina2Text2ImgPipeline"]
|
||||
_import_structure["lucy"] = ["LucyEditPipeline"]
|
||||
@@ -729,7 +731,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
|
||||
from .ltx import LTXConditionPipeline, LTXImageToVideoPipeline, LTXLatentUpsamplePipeline, LTXPipeline
|
||||
from .ltx import (
|
||||
LTXConditionPipeline,
|
||||
LTXI2VLongMultiPromptPipeline,
|
||||
LTXImageToVideoPipeline,
|
||||
LTXLatentUpsamplePipeline,
|
||||
LTXPipeline,
|
||||
)
|
||||
from .ltx2 import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
|
||||
from .lucy import LucyEditPipeline
|
||||
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
|
||||
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
|
||||
|
||||
@@ -99,6 +99,7 @@ from .qwenimage import (
|
||||
QwenImageEditPlusPipeline,
|
||||
QwenImageImg2ImgPipeline,
|
||||
QwenImageInpaintPipeline,
|
||||
QwenImageLayeredPipeline,
|
||||
QwenImagePipeline,
|
||||
)
|
||||
from .sana import SanaPipeline
|
||||
@@ -202,6 +203,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("qwenimage", QwenImageImg2ImgPipeline),
|
||||
("qwenimage-edit", QwenImageEditPipeline),
|
||||
("qwenimage-edit-plus", QwenImageEditPlusPipeline),
|
||||
("qwenimage-layered", QwenImageLayeredPipeline),
|
||||
("z-image", ZImageImg2ImgPipeline),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ else:
|
||||
_import_structure["modeling_latent_upsampler"] = ["LTXLatentUpsamplerModel"]
|
||||
_import_structure["pipeline_ltx"] = ["LTXPipeline"]
|
||||
_import_structure["pipeline_ltx_condition"] = ["LTXConditionPipeline"]
|
||||
_import_structure["pipeline_ltx_i2v_long_multi_prompt"] = ["LTXI2VLongMultiPromptPipeline"]
|
||||
_import_structure["pipeline_ltx_image2video"] = ["LTXImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx_latent_upsample"] = ["LTXLatentUpsamplePipeline"]
|
||||
|
||||
@@ -39,6 +40,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .modeling_latent_upsampler import LTXLatentUpsamplerModel
|
||||
from .pipeline_ltx import LTXPipeline
|
||||
from .pipeline_ltx_condition import LTXConditionPipeline
|
||||
from .pipeline_ltx_i2v_long_multi_prompt import LTXI2VLongMultiPromptPipeline
|
||||
from .pipeline_ltx_image2video import LTXImageToVideoPipeline
|
||||
from .pipeline_ltx_latent_upsample import LTXLatentUpsamplePipeline
|
||||
|
||||
|
||||
1408
src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py
Normal file
1408
src/diffusers/pipelines/ltx/pipeline_ltx_i2v_long_multi_prompt.py
Normal file
File diff suppressed because it is too large
Load Diff
58
src/diffusers/pipelines/ltx2/__init__.py
Normal file
58
src/diffusers/pipelines/ltx2/__init__.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["connectors"] = ["LTX2TextConnectors"]
|
||||
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
|
||||
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
|
||||
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
|
||||
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
|
||||
_import_structure["vocoder"] = ["LTX2Vocoder"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .connectors import LTX2TextConnectors
|
||||
from .latent_upsampler import LTX2LatentUpsamplerModel
|
||||
from .pipeline_ltx2 import LTX2Pipeline
|
||||
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
|
||||
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
|
||||
from .vocoder import LTX2Vocoder
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
325
src/diffusers/pipelines/ltx2/connectors.py
Normal file
325
src/diffusers/pipelines/ltx2/connectors.py
Normal file
@@ -0,0 +1,325 @@
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.attention import FeedForward
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...models.transformers.transformer_ltx2 import LTX2Attention, LTX2AudioVideoAttnProcessor
|
||||
|
||||
|
||||
class LTX2RotaryPosEmbed1d(nn.Module):
|
||||
"""
|
||||
1D rotary positional embeddings (RoPE) for the LTX 2.0 text encoder connectors.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
base_seq_len: int = 4096,
|
||||
theta: float = 10000.0,
|
||||
double_precision: bool = True,
|
||||
rope_type: str = "interleaved",
|
||||
num_attention_heads: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
if rope_type not in ["interleaved", "split"]:
|
||||
raise ValueError(f"{rope_type=} not supported. Choose between 'interleaved' and 'split'.")
|
||||
|
||||
self.dim = dim
|
||||
self.base_seq_len = base_seq_len
|
||||
self.theta = theta
|
||||
self.double_precision = double_precision
|
||||
self.rope_type = rope_type
|
||||
self.num_attention_heads = num_attention_heads
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch_size: int,
|
||||
pos: int,
|
||||
device: Union[str, torch.device],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# 1. Get 1D position ids
|
||||
grid_1d = torch.arange(pos, dtype=torch.float32, device=device)
|
||||
# Get fractional indices relative to self.base_seq_len
|
||||
grid_1d = grid_1d / self.base_seq_len
|
||||
grid = grid_1d.unsqueeze(0).repeat(batch_size, 1) # [batch_size, seq_len]
|
||||
|
||||
# 2. Calculate 1D RoPE frequencies
|
||||
num_rope_elems = 2 # 1 (because 1D) * 2 (for cos, sin) = 2
|
||||
freqs_dtype = torch.float64 if self.double_precision else torch.float32
|
||||
pow_indices = torch.pow(
|
||||
self.theta,
|
||||
torch.linspace(start=0.0, end=1.0, steps=self.dim // num_rope_elems, dtype=freqs_dtype, device=device),
|
||||
)
|
||||
freqs = (pow_indices * torch.pi / 2.0).to(dtype=torch.float32)
|
||||
|
||||
# 3. Matrix-vector outer product between pos ids of shape (batch_size, seq_len) and freqs vector of shape
|
||||
# (self.dim // 2,).
|
||||
freqs = (grid.unsqueeze(-1) * 2 - 1) * freqs # [B, seq_len, self.dim // 2]
|
||||
|
||||
# 4. Get real, interleaved (cos, sin) frequencies, padded to self.dim
|
||||
if self.rope_type == "interleaved":
|
||||
cos_freqs = freqs.cos().repeat_interleave(2, dim=-1)
|
||||
sin_freqs = freqs.sin().repeat_interleave(2, dim=-1)
|
||||
|
||||
if self.dim % num_rope_elems != 0:
|
||||
cos_padding = torch.ones_like(cos_freqs[:, :, : self.dim % num_rope_elems])
|
||||
sin_padding = torch.zeros_like(sin_freqs[:, :, : self.dim % num_rope_elems])
|
||||
cos_freqs = torch.cat([cos_padding, cos_freqs], dim=-1)
|
||||
sin_freqs = torch.cat([sin_padding, sin_freqs], dim=-1)
|
||||
|
||||
elif self.rope_type == "split":
|
||||
expected_freqs = self.dim // 2
|
||||
current_freqs = freqs.shape[-1]
|
||||
pad_size = expected_freqs - current_freqs
|
||||
cos_freq = freqs.cos()
|
||||
sin_freq = freqs.sin()
|
||||
|
||||
if pad_size != 0:
|
||||
cos_padding = torch.ones_like(cos_freq[:, :, :pad_size])
|
||||
sin_padding = torch.zeros_like(sin_freq[:, :, :pad_size])
|
||||
|
||||
cos_freq = torch.concatenate([cos_padding, cos_freq], axis=-1)
|
||||
sin_freq = torch.concatenate([sin_padding, sin_freq], axis=-1)
|
||||
|
||||
# Reshape freqs to be compatible with multi-head attention
|
||||
b = cos_freq.shape[0]
|
||||
t = cos_freq.shape[1]
|
||||
|
||||
cos_freq = cos_freq.reshape(b, t, self.num_attention_heads, -1)
|
||||
sin_freq = sin_freq.reshape(b, t, self.num_attention_heads, -1)
|
||||
|
||||
cos_freqs = torch.swapaxes(cos_freq, 1, 2) # (B,H,T,D//2)
|
||||
sin_freqs = torch.swapaxes(sin_freq, 1, 2) # (B,H,T,D//2)
|
||||
|
||||
return cos_freqs, sin_freqs
|
||||
|
||||
|
||||
class LTX2TransformerBlock1d(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
activation_fn: str = "gelu-approximate",
|
||||
eps: float = 1e-6,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.attn1 = LTX2Attention(
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
kv_heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
processor=LTX2AudioVideoAttnProcessor(),
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
self.norm2 = torch.nn.RMSNorm(dim, eps=eps, elementwise_affine=False)
|
||||
self.ff = FeedForward(dim, activation_fn=activation_fn)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
rotary_emb: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
attn_hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, query_rotary_emb=rotary_emb)
|
||||
hidden_states = hidden_states + attn_hidden_states
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
ff_hidden_states = self.ff(norm_hidden_states)
|
||||
hidden_states = hidden_states + ff_hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class LTX2ConnectorTransformer1d(nn.Module):
|
||||
"""
|
||||
A 1D sequence transformer for modalities such as text.
|
||||
|
||||
In LTX 2.0, this is used to process the text encoder hidden states for each of the video and audio streams.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 30,
|
||||
attention_head_dim: int = 128,
|
||||
num_layers: int = 2,
|
||||
num_learnable_registers: int | None = 128,
|
||||
rope_base_seq_len: int = 4096,
|
||||
rope_theta: float = 10000.0,
|
||||
rope_double_precision: bool = True,
|
||||
eps: float = 1e-6,
|
||||
causal_temporal_positioning: bool = False,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.inner_dim = num_attention_heads * attention_head_dim
|
||||
self.causal_temporal_positioning = causal_temporal_positioning
|
||||
|
||||
self.num_learnable_registers = num_learnable_registers
|
||||
self.learnable_registers = None
|
||||
if num_learnable_registers is not None:
|
||||
init_registers = torch.rand(num_learnable_registers, self.inner_dim) * 2.0 - 1.0
|
||||
self.learnable_registers = torch.nn.Parameter(init_registers)
|
||||
|
||||
self.rope = LTX2RotaryPosEmbed1d(
|
||||
self.inner_dim,
|
||||
base_seq_len=rope_base_seq_len,
|
||||
theta=rope_theta,
|
||||
double_precision=rope_double_precision,
|
||||
rope_type=rope_type,
|
||||
num_attention_heads=num_attention_heads,
|
||||
)
|
||||
|
||||
self.transformer_blocks = torch.nn.ModuleList(
|
||||
[
|
||||
LTX2TransformerBlock1d(
|
||||
dim=self.inner_dim,
|
||||
num_attention_heads=num_attention_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = torch.nn.RMSNorm(self.inner_dim, eps=eps, elementwise_affine=False)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attn_mask_binarize_threshold: float = -9000.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
# hidden_states shape: [batch_size, seq_len, hidden_dim]
|
||||
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
|
||||
batch_size, seq_len, _ = hidden_states.shape
|
||||
|
||||
# 1. Replace padding with learned registers, if using
|
||||
if self.learnable_registers is not None:
|
||||
if seq_len % self.num_learnable_registers != 0:
|
||||
raise ValueError(
|
||||
f"The `hidden_states` sequence length {hidden_states.shape[1]} should be divisible by the number"
|
||||
f" of learnable registers {self.num_learnable_registers}"
|
||||
)
|
||||
|
||||
num_register_repeats = seq_len // self.num_learnable_registers
|
||||
registers = torch.tile(self.learnable_registers, (num_register_repeats, 1)) # [seq_len, inner_dim]
|
||||
|
||||
binary_attn_mask = (attention_mask >= attn_mask_binarize_threshold).int()
|
||||
if binary_attn_mask.ndim == 4:
|
||||
binary_attn_mask = binary_attn_mask.squeeze(1).squeeze(1) # [B, 1, 1, L] --> [B, L]
|
||||
|
||||
hidden_states_non_padded = [hidden_states[i, binary_attn_mask[i].bool(), :] for i in range(batch_size)]
|
||||
valid_seq_lens = [x.shape[0] for x in hidden_states_non_padded]
|
||||
pad_lengths = [seq_len - valid_seq_len for valid_seq_len in valid_seq_lens]
|
||||
padded_hidden_states = [
|
||||
F.pad(x, pad=(0, 0, 0, p), value=0) for x, p in zip(hidden_states_non_padded, pad_lengths)
|
||||
]
|
||||
padded_hidden_states = torch.cat([x.unsqueeze(0) for x in padded_hidden_states], dim=0) # [B, L, D]
|
||||
|
||||
flipped_mask = torch.flip(binary_attn_mask, dims=[1]).unsqueeze(-1) # [B, L, 1]
|
||||
hidden_states = flipped_mask * padded_hidden_states + (1 - flipped_mask) * registers
|
||||
|
||||
# Overwrite attention_mask with an all-zeros mask if using registers.
|
||||
attention_mask = torch.zeros_like(attention_mask)
|
||||
|
||||
# 2. Calculate 1D RoPE positional embeddings
|
||||
rotary_emb = self.rope(batch_size, seq_len, device=hidden_states.device)
|
||||
|
||||
# 3. Run 1D transformer blocks
|
||||
for block in self.transformer_blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
hidden_states = self._gradient_checkpointing_func(block, hidden_states, attention_mask, rotary_emb)
|
||||
else:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask, rotary_emb=rotary_emb)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
|
||||
return hidden_states, attention_mask
|
||||
|
||||
|
||||
class LTX2TextConnectors(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Text connector stack used by LTX 2.0 to process the packed text encoder hidden states for both the video and audio
|
||||
streams.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
caption_channels: int,
|
||||
text_proj_in_factor: int,
|
||||
video_connector_num_attention_heads: int,
|
||||
video_connector_attention_head_dim: int,
|
||||
video_connector_num_layers: int,
|
||||
video_connector_num_learnable_registers: int | None,
|
||||
audio_connector_num_attention_heads: int,
|
||||
audio_connector_attention_head_dim: int,
|
||||
audio_connector_num_layers: int,
|
||||
audio_connector_num_learnable_registers: int | None,
|
||||
connector_rope_base_seq_len: int,
|
||||
rope_theta: float,
|
||||
rope_double_precision: bool,
|
||||
causal_temporal_positioning: bool,
|
||||
rope_type: str = "interleaved",
|
||||
):
|
||||
super().__init__()
|
||||
self.text_proj_in = nn.Linear(caption_channels * text_proj_in_factor, caption_channels, bias=False)
|
||||
self.video_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=video_connector_num_attention_heads,
|
||||
attention_head_dim=video_connector_attention_head_dim,
|
||||
num_layers=video_connector_num_layers,
|
||||
num_learnable_registers=video_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
self.audio_connector = LTX2ConnectorTransformer1d(
|
||||
num_attention_heads=audio_connector_num_attention_heads,
|
||||
attention_head_dim=audio_connector_attention_head_dim,
|
||||
num_layers=audio_connector_num_layers,
|
||||
num_learnable_registers=audio_connector_num_learnable_registers,
|
||||
rope_base_seq_len=connector_rope_base_seq_len,
|
||||
rope_theta=rope_theta,
|
||||
rope_double_precision=rope_double_precision,
|
||||
causal_temporal_positioning=causal_temporal_positioning,
|
||||
rope_type=rope_type,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, text_encoder_hidden_states: torch.Tensor, attention_mask: torch.Tensor, additive_mask: bool = False
|
||||
):
|
||||
# Convert to additive attention mask, if necessary
|
||||
if not additive_mask:
|
||||
text_dtype = text_encoder_hidden_states.dtype
|
||||
attention_mask = (attention_mask - 1).reshape(attention_mask.shape[0], 1, -1, attention_mask.shape[-1])
|
||||
attention_mask = attention_mask.to(text_dtype) * torch.finfo(text_dtype).max
|
||||
|
||||
text_encoder_hidden_states = self.text_proj_in(text_encoder_hidden_states)
|
||||
|
||||
video_text_embedding, new_attn_mask = self.video_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
attn_mask = (new_attn_mask < 1e-6).to(torch.int64)
|
||||
attn_mask = attn_mask.reshape(video_text_embedding.shape[0], video_text_embedding.shape[1], 1)
|
||||
video_text_embedding = video_text_embedding * attn_mask
|
||||
new_attn_mask = attn_mask.squeeze(-1)
|
||||
|
||||
audio_text_embedding, _ = self.audio_connector(text_encoder_hidden_states, attention_mask)
|
||||
|
||||
return video_text_embedding, audio_text_embedding, new_attn_mask
|
||||
134
src/diffusers/pipelines/ltx2/export_utils.py
Normal file
134
src/diffusers/pipelines/ltx2/export_utils.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright 2025 The Lightricks team and The HuggingFace Team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fractions import Fraction
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ...utils import is_av_available
|
||||
|
||||
|
||||
_CAN_USE_AV = is_av_available()
|
||||
if _CAN_USE_AV:
|
||||
import av
|
||||
else:
|
||||
raise ImportError(
|
||||
"PyAV is required to use LTX 2.0 video export utilities. You can install it with `pip install av`"
|
||||
)
|
||||
|
||||
|
||||
def _prepare_audio_stream(container: av.container.Container, audio_sample_rate: int) -> av.audio.AudioStream:
|
||||
"""
|
||||
Prepare the audio stream for writing.
|
||||
"""
|
||||
audio_stream = container.add_stream("aac", rate=audio_sample_rate)
|
||||
audio_stream.codec_context.sample_rate = audio_sample_rate
|
||||
audio_stream.codec_context.layout = "stereo"
|
||||
audio_stream.codec_context.time_base = Fraction(1, audio_sample_rate)
|
||||
return audio_stream
|
||||
|
||||
|
||||
def _resample_audio(
|
||||
container: av.container.Container, audio_stream: av.audio.AudioStream, frame_in: av.AudioFrame
|
||||
) -> None:
|
||||
cc = audio_stream.codec_context
|
||||
|
||||
# Use the encoder's format/layout/rate as the *target*
|
||||
target_format = cc.format or "fltp" # AAC → usually fltp
|
||||
target_layout = cc.layout or "stereo"
|
||||
target_rate = cc.sample_rate or frame_in.sample_rate
|
||||
|
||||
audio_resampler = av.audio.resampler.AudioResampler(
|
||||
format=target_format,
|
||||
layout=target_layout,
|
||||
rate=target_rate,
|
||||
)
|
||||
|
||||
audio_next_pts = 0
|
||||
for rframe in audio_resampler.resample(frame_in):
|
||||
if rframe.pts is None:
|
||||
rframe.pts = audio_next_pts
|
||||
audio_next_pts += rframe.samples
|
||||
rframe.sample_rate = frame_in.sample_rate
|
||||
container.mux(audio_stream.encode(rframe))
|
||||
|
||||
# flush audio encoder
|
||||
for packet in audio_stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
|
||||
def _write_audio(
|
||||
container: av.container.Container,
|
||||
audio_stream: av.audio.AudioStream,
|
||||
samples: torch.Tensor,
|
||||
audio_sample_rate: int,
|
||||
) -> None:
|
||||
if samples.ndim == 1:
|
||||
samples = samples[:, None]
|
||||
|
||||
if samples.shape[1] != 2 and samples.shape[0] == 2:
|
||||
samples = samples.T
|
||||
|
||||
if samples.shape[1] != 2:
|
||||
raise ValueError(f"Expected samples with 2 channels; got shape {samples.shape}.")
|
||||
|
||||
# Convert to int16 packed for ingestion; resampler converts to encoder fmt.
|
||||
if samples.dtype != torch.int16:
|
||||
samples = torch.clip(samples, -1.0, 1.0)
|
||||
samples = (samples * 32767.0).to(torch.int16)
|
||||
|
||||
frame_in = av.AudioFrame.from_ndarray(
|
||||
samples.contiguous().reshape(1, -1).cpu().numpy(),
|
||||
format="s16",
|
||||
layout="stereo",
|
||||
)
|
||||
frame_in.sample_rate = audio_sample_rate
|
||||
|
||||
_resample_audio(container, audio_stream, frame_in)
|
||||
|
||||
|
||||
def encode_video(
|
||||
video: torch.Tensor, fps: int, audio: Optional[torch.Tensor], audio_sample_rate: Optional[int], output_path: str
|
||||
) -> None:
|
||||
video_np = video.cpu().numpy()
|
||||
|
||||
_, height, width, _ = video_np.shape
|
||||
|
||||
container = av.open(output_path, mode="w")
|
||||
stream = container.add_stream("libx264", rate=int(fps))
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
|
||||
if audio is not None:
|
||||
if audio_sample_rate is None:
|
||||
raise ValueError("audio_sample_rate is required when audio is provided")
|
||||
|
||||
audio_stream = _prepare_audio_stream(container, audio_sample_rate)
|
||||
|
||||
for frame_array in video_np:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
|
||||
# Flush encoder
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
if audio is not None:
|
||||
_write_audio(container, audio_stream, audio, audio_sample_rate)
|
||||
|
||||
container.close()
|
||||
285
src/diffusers/pipelines/ltx2/latent_upsampler.py
Normal file
285
src/diffusers/pipelines/ltx2/latent_upsampler.py
Normal file
@@ -0,0 +1,285 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
RATIONAL_RESAMPLER_SCALE_MAPPING = {
|
||||
0.75: (3, 4),
|
||||
1.5: (3, 2),
|
||||
2.0: (2, 1),
|
||||
4.0: (4, 1),
|
||||
}
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.ResBlock
|
||||
class ResBlock(torch.nn.Module):
|
||||
def __init__(self, channels: int, mid_channels: Optional[int] = None, dims: int = 3):
|
||||
super().__init__()
|
||||
if mid_channels is None:
|
||||
mid_channels = channels
|
||||
|
||||
Conv = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.conv1 = Conv(channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.norm1 = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.conv2 = Conv(mid_channels, channels, kernel_size=3, padding=1)
|
||||
self.norm2 = torch.nn.GroupNorm(32, channels)
|
||||
self.activation = torch.nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
residual = hidden_states
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.activation(hidden_states + residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.ltx.modeling_latent_upsampler.PixelShuffleND
|
||||
class PixelShuffleND(torch.nn.Module):
|
||||
def __init__(self, dims, upscale_factors=(2, 2, 2)):
|
||||
super().__init__()
|
||||
|
||||
self.dims = dims
|
||||
self.upscale_factors = upscale_factors
|
||||
|
||||
if dims not in [1, 2, 3]:
|
||||
raise ValueError("dims must be 1, 2, or 3")
|
||||
|
||||
def forward(self, x):
|
||||
if self.dims == 3:
|
||||
# spatiotemporal: b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:3]))
|
||||
.permute(0, 1, 5, 2, 6, 3, 7, 4)
|
||||
.flatten(6, 7)
|
||||
.flatten(4, 5)
|
||||
.flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 2:
|
||||
# spatial: b (c p1 p2) h w -> b c (h p1) (w p2)
|
||||
return (
|
||||
x.unflatten(1, (-1, *self.upscale_factors[:2])).permute(0, 1, 4, 2, 5, 3).flatten(4, 5).flatten(2, 3)
|
||||
)
|
||||
elif self.dims == 1:
|
||||
# temporal: b (c p1) f h w -> b c (f p1) h w
|
||||
return x.unflatten(1, (-1, *self.upscale_factors[:1])).permute(0, 1, 3, 2, 4, 5).flatten(2, 3)
|
||||
|
||||
|
||||
class BlurDownsample(torch.nn.Module):
|
||||
"""
|
||||
Anti-aliased spatial downsampling by integer stride using a fixed separable binomial kernel. Applies only on H,W.
|
||||
Works for dims=2 or dims=3 (per-frame).
|
||||
"""
|
||||
|
||||
def __init__(self, dims: int, stride: int, kernel_size: int = 5) -> None:
|
||||
super().__init__()
|
||||
|
||||
if dims not in (2, 3):
|
||||
raise ValueError(f"`dims` must be either 2 or 3 but is {dims}")
|
||||
if kernel_size < 3 or kernel_size % 2 != 1:
|
||||
raise ValueError(f"`kernel_size` must be an odd number >= 3 but is {kernel_size}")
|
||||
|
||||
self.dims = dims
|
||||
self.stride = stride
|
||||
self.kernel_size = kernel_size
|
||||
|
||||
# 5x5 separable binomial kernel using binomial coefficients [1, 4, 6, 4, 1] from
|
||||
# the 4th row of Pascal's triangle. This kernel is used for anti-aliasing and
|
||||
# provides a smooth approximation of a Gaussian filter (often called a "binomial filter").
|
||||
# The 2D kernel is constructed as the outer product and normalized.
|
||||
k = torch.tensor([math.comb(kernel_size - 1, k) for k in range(kernel_size)])
|
||||
k2d = k[:, None] @ k[None, :]
|
||||
k2d = (k2d / k2d.sum()).float() # shape (kernel_size, kernel_size)
|
||||
self.register_buffer("kernel", k2d[None, None, :, :]) # (1, 1, kernel_size, kernel_size)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
if self.stride == 1:
|
||||
return x
|
||||
|
||||
if self.dims == 2:
|
||||
c = x.shape[1]
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
else:
|
||||
# dims == 3: apply per-frame on H,W
|
||||
b, c, f, _, _ = x.shape
|
||||
x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
|
||||
|
||||
weight = self.kernel.expand(c, 1, self.kernel_size, self.kernel_size) # depthwise
|
||||
x = F.conv2d(x, weight=weight, bias=None, stride=self.stride, padding=self.kernel_size // 2, groups=c)
|
||||
|
||||
h2, w2 = x.shape[-2:]
|
||||
x = x.unflatten(0, (b, f)).reshape(b, -1, f, h2, w2) # [B * F, C, H, W] --> [B, C, F, H, W]
|
||||
return x
|
||||
|
||||
|
||||
class SpatialRationalResampler(torch.nn.Module):
|
||||
"""
|
||||
Scales by the spatial size of the input by a rational number `scale`. For example, `scale = 0.75` will downsample
|
||||
by a factor of 3 / 4, while `scale = 1.5` will upsample by a factor of 3 / 2. This works by first upsampling the
|
||||
input by the (integer) numerator of `scale`, and then performing a blur + stride anti-aliased downsample by the
|
||||
(integer) denominator.
|
||||
"""
|
||||
|
||||
def __init__(self, mid_channels: int = 1024, scale: float = 2.0):
|
||||
super().__init__()
|
||||
self.scale = float(scale)
|
||||
num_denom = RATIONAL_RESAMPLER_SCALE_MAPPING.get(scale, None)
|
||||
if num_denom is None:
|
||||
raise ValueError(
|
||||
f"The supplied `scale` {scale} is not supported; supported scales are {list(RATIONAL_RESAMPLER_SCALE_MAPPING.keys())}"
|
||||
)
|
||||
self.num, self.den = num_denom
|
||||
|
||||
self.conv = torch.nn.Conv2d(mid_channels, (self.num**2) * mid_channels, kernel_size=3, padding=1)
|
||||
self.pixel_shuffle = PixelShuffleND(2, upscale_factors=(self.num, self.num))
|
||||
self.blur_down = BlurDownsample(dims=2, stride=self.den)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Expected x shape: [B * F, C, H, W]
|
||||
# b, _, f, h, w = x.shape
|
||||
# x = x.transpose(1, 2).flatten(0, 1) # [B, C, F, H, W] --> [B * F, C, H, W]
|
||||
x = self.conv(x)
|
||||
x = self.pixel_shuffle(x)
|
||||
x = self.blur_down(x)
|
||||
# x = x.unflatten(0, (b, f)).reshape(b, -1, f, h, w) # [B * F, C, H, W] --> [B, C, F, H, W]
|
||||
return x
|
||||
|
||||
|
||||
class LTX2LatentUpsamplerModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Model to spatially upsample VAE latents.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, defaults to `128`):
|
||||
Number of channels in the input latent
|
||||
mid_channels (`int`, defaults to `512`):
|
||||
Number of channels in the middle layers
|
||||
num_blocks_per_stage (`int`, defaults to `4`):
|
||||
Number of ResBlocks to use in each stage (pre/post upsampling)
|
||||
dims (`int`, defaults to `3`):
|
||||
Number of dimensions for convolutions (2 or 3)
|
||||
spatial_upsample (`bool`, defaults to `True`):
|
||||
Whether to spatially upsample the latent
|
||||
temporal_upsample (`bool`, defaults to `False`):
|
||||
Whether to temporally upsample the latent
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
mid_channels: int = 1024,
|
||||
num_blocks_per_stage: int = 4,
|
||||
dims: int = 3,
|
||||
spatial_upsample: bool = True,
|
||||
temporal_upsample: bool = False,
|
||||
rational_spatial_scale: Optional[float] = 2.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.mid_channels = mid_channels
|
||||
self.num_blocks_per_stage = num_blocks_per_stage
|
||||
self.dims = dims
|
||||
self.spatial_upsample = spatial_upsample
|
||||
self.temporal_upsample = temporal_upsample
|
||||
|
||||
ConvNd = torch.nn.Conv2d if dims == 2 else torch.nn.Conv3d
|
||||
|
||||
self.initial_conv = ConvNd(in_channels, mid_channels, kernel_size=3, padding=1)
|
||||
self.initial_norm = torch.nn.GroupNorm(32, mid_channels)
|
||||
self.initial_activation = torch.nn.SiLU()
|
||||
|
||||
self.res_blocks = torch.nn.ModuleList([ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)])
|
||||
|
||||
if spatial_upsample and temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 8 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(3),
|
||||
)
|
||||
elif spatial_upsample:
|
||||
if rational_spatial_scale is not None:
|
||||
self.upsampler = SpatialRationalResampler(mid_channels=mid_channels, scale=rational_spatial_scale)
|
||||
else:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(mid_channels, 4 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(2),
|
||||
)
|
||||
elif temporal_upsample:
|
||||
self.upsampler = torch.nn.Sequential(
|
||||
torch.nn.Conv3d(mid_channels, 2 * mid_channels, kernel_size=3, padding=1),
|
||||
PixelShuffleND(1),
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either spatial_upsample or temporal_upsample must be True")
|
||||
|
||||
self.post_upsample_res_blocks = torch.nn.ModuleList(
|
||||
[ResBlock(mid_channels, dims=dims) for _ in range(num_blocks_per_stage)]
|
||||
)
|
||||
|
||||
self.final_conv = ConvNd(mid_channels, in_channels, kernel_size=3, padding=1)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, num_channels, num_frames, height, width = hidden_states.shape
|
||||
|
||||
if self.dims == 2:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
else:
|
||||
hidden_states = self.initial_conv(hidden_states)
|
||||
hidden_states = self.initial_norm(hidden_states)
|
||||
hidden_states = self.initial_activation(hidden_states)
|
||||
|
||||
for block in self.res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
if self.temporal_upsample:
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states[:, :, 1:, :, :]
|
||||
else:
|
||||
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
|
||||
hidden_states = self.upsampler(hidden_states)
|
||||
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).permute(0, 2, 1, 3, 4)
|
||||
|
||||
for block in self.post_upsample_res_blocks:
|
||||
hidden_states = block(hidden_states)
|
||||
|
||||
hidden_states = self.final_conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
1141
src/diffusers/pipelines/ltx2/pipeline_ltx2.py
Normal file
1141
src/diffusers/pipelines/ltx2/pipeline_ltx2.py
Normal file
File diff suppressed because it is too large
Load Diff
1238
src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
Normal file
1238
src/diffusers/pipelines/ltx2/pipeline_ltx2_image2video.py
Normal file
File diff suppressed because it is too large
Load Diff
442
src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
Normal file
442
src/diffusers/pipelines/ltx2/pipeline_ltx2_latent_upsample.py
Normal file
@@ -0,0 +1,442 @@
|
||||
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ...image_processor import PipelineImageInput
|
||||
from ...models import AutoencoderKLLTX2Video
|
||||
from ...utils import get_logger, replace_example_docstring
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
from ...video_processor import VideoProcessor
|
||||
from ..ltx.pipeline_output import LTXPipelineOutput
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from .latent_upsampler import LTX2LatentUpsamplerModel
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline
|
||||
>>> from diffusers.pipelines.ltx2.export_utils import encode_video
|
||||
>>> from diffusers.pipelines.ltx2.latent_upsampler import LTX2LatentUpsamplerModel
|
||||
>>> from diffusers.utils import load_image
|
||||
|
||||
>>> pipe = LTX2ImageToVideoPipeline.from_pretrained("Lightricks/LTX-2", torch_dtype=torch.bfloat16)
|
||||
>>> pipe.enable_model_cpu_offload()
|
||||
|
||||
>>> image = load_image(
|
||||
... "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png"
|
||||
... )
|
||||
>>> prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background."
|
||||
>>> negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted"
|
||||
|
||||
>>> frame_rate = 24.0
|
||||
>>> video, audio = pipe(
|
||||
... image=image,
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... width=768,
|
||||
... height=512,
|
||||
... num_frames=121,
|
||||
... frame_rate=frame_rate,
|
||||
... num_inference_steps=40,
|
||||
... guidance_scale=4.0,
|
||||
... output_type="pil",
|
||||
... return_dict=False,
|
||||
... )
|
||||
|
||||
>>> latent_upsampler = LTX2LatentUpsamplerModel.from_pretrained(
|
||||
... "Lightricks/LTX-2", subfolder="latent_upsampler", torch_dtype=torch.bfloat16
|
||||
... )
|
||||
>>> upsample_pipe = LTX2LatentUpsamplePipeline(vae=pipe.vae, latent_upsampler=latent_upsampler)
|
||||
>>> upsample_pipe.vae.enable_tiling()
|
||||
>>> upsample_pipe.to(device="cuda", dtype=torch.bfloat16)
|
||||
|
||||
>>> video = upsample_pipe(
|
||||
... video=video,
|
||||
... width=768,
|
||||
... height=512,
|
||||
... output_type="np",
|
||||
... return_dict=False,
|
||||
... )[0]
|
||||
>>> video = (video * 255).round().astype("uint8")
|
||||
>>> video = torch.from_numpy(video)
|
||||
|
||||
>>> encode_video(
|
||||
... video[0],
|
||||
... fps=frame_rate,
|
||||
... audio=audio[0].float().cpu(),
|
||||
... audio_sample_rate=pipe.vocoder.config.output_sampling_rate, # should be 24000
|
||||
... output_path="video.mp4",
|
||||
... )
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
|
||||
def retrieve_latents(
|
||||
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
|
||||
):
|
||||
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
|
||||
return encoder_output.latent_dist.sample(generator)
|
||||
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
|
||||
return encoder_output.latent_dist.mode()
|
||||
elif hasattr(encoder_output, "latents"):
|
||||
return encoder_output.latents
|
||||
else:
|
||||
raise AttributeError("Could not access latents of provided encoder_output")
|
||||
|
||||
|
||||
class LTX2LatentUpsamplePipeline(DiffusionPipeline):
|
||||
model_cpu_offload_seq = "vae->latent_upsampler"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vae: AutoencoderKLLTX2Video,
|
||||
latent_upsampler: LTX2LatentUpsamplerModel,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(vae=vae, latent_upsampler=latent_upsampler)
|
||||
|
||||
self.vae_spatial_compression_ratio = (
|
||||
self.vae.spatial_compression_ratio if getattr(self, "vae", None) is not None else 32
|
||||
)
|
||||
self.vae_temporal_compression_ratio = (
|
||||
self.vae.temporal_compression_ratio if getattr(self, "vae", None) is not None else 8
|
||||
)
|
||||
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
video: Optional[torch.Tensor] = None,
|
||||
batch_size: int = 1,
|
||||
num_frames: int = 121,
|
||||
height: int = 512,
|
||||
width: int = 768,
|
||||
spatial_patch_size: int = 1,
|
||||
temporal_patch_size: int = 1,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
if latents is not None:
|
||||
if latents.ndim == 3:
|
||||
# Convert token seq [B, S, D] to latent video [B, C, F, H, W]
|
||||
latent_num_frames = (num_frames - 1) // self.vae_temporal_compression_ratio + 1
|
||||
latent_height = height // self.vae_spatial_compression_ratio
|
||||
latent_width = width // self.vae_spatial_compression_ratio
|
||||
latents = self._unpack_latents(
|
||||
latents, latent_num_frames, latent_height, latent_width, spatial_patch_size, temporal_patch_size
|
||||
)
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
|
||||
video = video.to(device=device, dtype=self.vae.dtype)
|
||||
if isinstance(generator, list):
|
||||
if len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
||||
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
|
||||
)
|
||||
|
||||
init_latents = [
|
||||
retrieve_latents(self.vae.encode(video[i].unsqueeze(0)), generator[i]) for i in range(batch_size)
|
||||
]
|
||||
else:
|
||||
init_latents = [retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator) for vid in video]
|
||||
|
||||
init_latents = torch.cat(init_latents, dim=0).to(dtype)
|
||||
# NOTE: latent upsampler operates on the unnormalized latents, so don't normalize here
|
||||
# init_latents = self._normalize_latents(init_latents, self.vae.latents_mean, self.vae.latents_std)
|
||||
return init_latents
|
||||
|
||||
def adain_filter_latent(self, latents: torch.Tensor, reference_latents: torch.Tensor, factor: float = 1.0):
|
||||
"""
|
||||
Applies Adaptive Instance Normalization (AdaIN) to a latent tensor based on statistics from a reference latent
|
||||
tensor.
|
||||
|
||||
Args:
|
||||
latent (`torch.Tensor`):
|
||||
Input latents to normalize
|
||||
reference_latents (`torch.Tensor`):
|
||||
The reference latents providing style statistics.
|
||||
factor (`float`):
|
||||
Blending factor between original and transformed latent. Range: -10.0 to 10.0, Default: 1.0
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The transformed latent tensor
|
||||
"""
|
||||
result = latents.clone()
|
||||
|
||||
for i in range(latents.size(0)):
|
||||
for c in range(latents.size(1)):
|
||||
r_sd, r_mean = torch.std_mean(reference_latents[i, c], dim=None) # index by original dim order
|
||||
i_sd, i_mean = torch.std_mean(result[i, c], dim=None)
|
||||
|
||||
result[i, c] = ((result[i, c] - i_mean) / i_sd) * r_sd + r_mean
|
||||
|
||||
result = torch.lerp(latents, result, factor)
|
||||
return result
|
||||
|
||||
def tone_map_latents(self, latents: torch.Tensor, compression: float) -> torch.Tensor:
|
||||
"""
|
||||
Applies a non-linear tone-mapping function to latent values to reduce their dynamic range in a perceptually
|
||||
smooth way using a sigmoid-based compression.
|
||||
|
||||
This is useful for regularizing high-variance latents or for conditioning outputs during generation, especially
|
||||
when controlling dynamic behavior with a `compression` factor.
|
||||
|
||||
Args:
|
||||
latents : torch.Tensor
|
||||
Input latent tensor with arbitrary shape. Expected to be roughly in [-1, 1] or [0, 1] range.
|
||||
compression : float
|
||||
Compression strength in the range [0, 1].
|
||||
- 0.0: No tone-mapping (identity transform)
|
||||
- 1.0: Full compression effect
|
||||
|
||||
Returns:
|
||||
torch.Tensor
|
||||
The tone-mapped latent tensor of the same shape as input.
|
||||
"""
|
||||
# Remap [0-1] to [0-0.75] and apply sigmoid compression in one shot
|
||||
scale_factor = compression * 0.75
|
||||
abs_latents = torch.abs(latents)
|
||||
|
||||
# Sigmoid compression: sigmoid shifts large values toward 0.2, small values stay ~1.0
|
||||
# When scale_factor=0, sigmoid term vanishes, when scale_factor=0.75, full effect
|
||||
sigmoid_term = torch.sigmoid(4.0 * scale_factor * (abs_latents - 1.0))
|
||||
scales = 1.0 - 0.8 * scale_factor * sigmoid_term
|
||||
|
||||
filtered = latents * scales
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2_image2video.LTX2ImageToVideoPipeline._normalize_latents
|
||||
def _normalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Normalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = (latents - latents_mean) * scaling_factor / latents_std
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._denormalize_latents
|
||||
def _denormalize_latents(
|
||||
latents: torch.Tensor, latents_mean: torch.Tensor, latents_std: torch.Tensor, scaling_factor: float = 1.0
|
||||
) -> torch.Tensor:
|
||||
# Denormalize latents across the channel dimension [B, C, F, H, W]
|
||||
latents_mean = latents_mean.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents_std = latents_std.view(1, -1, 1, 1, 1).to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std / scaling_factor + latents_mean
|
||||
return latents
|
||||
|
||||
@staticmethod
|
||||
# Copied from diffusers.pipelines.ltx2.pipeline_ltx2.LTX2Pipeline._unpack_latents
|
||||
def _unpack_latents(
|
||||
latents: torch.Tensor, num_frames: int, height: int, width: int, patch_size: int = 1, patch_size_t: int = 1
|
||||
) -> torch.Tensor:
|
||||
# Packed latents of shape [B, S, D] (S is the effective video sequence length, D is the effective feature dimensions)
|
||||
# are unpacked and reshaped into a video tensor of shape [B, C, F, H, W]. This is the inverse operation of
|
||||
# what happens in the `_pack_latents` method.
|
||||
batch_size = latents.size(0)
|
||||
latents = latents.reshape(batch_size, num_frames, height, width, -1, patch_size_t, patch_size, patch_size)
|
||||
latents = latents.permute(0, 4, 1, 5, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
return latents
|
||||
|
||||
def check_inputs(self, video, height, width, latents, tone_map_compression_ratio):
|
||||
if height % self.vae_spatial_compression_ratio != 0 or width % self.vae_spatial_compression_ratio != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
|
||||
if video is not None and latents is not None:
|
||||
raise ValueError("Only one of `video` or `latents` can be provided.")
|
||||
if video is None and latents is None:
|
||||
raise ValueError("One of `video` or `latents` has to be provided.")
|
||||
|
||||
if not (0 <= tone_map_compression_ratio <= 1):
|
||||
raise ValueError("`tone_map_compression_ratio` must be in the range [0, 1]")
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
video: Optional[List[PipelineImageInput]] = None,
|
||||
height: int = 512,
|
||||
width: int = 768,
|
||||
num_frames: int = 121,
|
||||
spatial_patch_size: int = 1,
|
||||
temporal_patch_size: int = 1,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
latents_normalized: bool = False,
|
||||
decode_timestep: Union[float, List[float]] = 0.0,
|
||||
decode_noise_scale: Optional[Union[float, List[float]]] = None,
|
||||
adain_factor: float = 0.0,
|
||||
tone_map_compression_ratio: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
video (`List[PipelineImageInput]`, *optional*)
|
||||
The video to be upsampled (such as a LTX 2.0 first stage output). If not supplied, `latents` should be
|
||||
supplied.
|
||||
height (`int`, *optional*, defaults to `512`):
|
||||
The height in pixels of the input video (not the generated video, which will have a larger resolution).
|
||||
width (`int`, *optional*, defaults to `768`):
|
||||
The width in pixels of the input video (not the generated video, which will have a larger resolution).
|
||||
num_frames (`int`, *optional*, defaults to `121`):
|
||||
The number of frames in the input video.
|
||||
spatial_patch_size (`int`, *optional*, defaults to `1`):
|
||||
The spatial patch size of the video latents. Used when `latents` is supplied if unpacking is necessary.
|
||||
temporal_patch_size (`int`, *optional*, defaults to `1`):
|
||||
The temporal patch size of the video latents. Used when `latents` is supplied if unpacking is
|
||||
necessary.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated video latents. This can be supplied in place of the `video` argument. Can either be a
|
||||
patch sequence of shape `(batch_size, seq_len, hidden_dim)` or a video latent of shape `(batch_size,
|
||||
latent_channels, latent_frames, latent_height, latent_width)`.
|
||||
latents_normalized (`bool`, *optional*, defaults to `False`)
|
||||
If `latents` are supplied, whether the `latents` are normalized using the VAE latent mean and std. If
|
||||
`True`, the `latents` will be denormalized before being supplied to the latent upsampler.
|
||||
decode_timestep (`float`, defaults to `0.0`):
|
||||
The timestep at which generated video is decoded.
|
||||
decode_noise_scale (`float`, defaults to `None`):
|
||||
The interpolation factor between random noise and denoised latents at the decode timestep.
|
||||
adain_factor (`float`, *optional*, defaults to `0.0`):
|
||||
Adaptive Instance Normalization (AdaIN) blending factor between the upsampled and original latents.
|
||||
Should be in [-10.0, 10.0]; supplying 0.0 (the default) means that AdaIN is not performed.
|
||||
tone_map_compression_ratio (`float`, *optional*, defaults to `0.0`):
|
||||
The compression strength for tone mapping, which will reduce the dynamic range of the latent values.
|
||||
This is useful for regularizing high-variance latents or for conditioning outputs during generation.
|
||||
Should be in [0, 1], where 0.0 (the default) means tone mapping is not applied and 1.0 corresponds to
|
||||
the full compression effect.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generate image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.ltx.LTXPipelineOutput`] instead of a plain tuple.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.ltx.LTXPipelineOutput`] or `tuple`:
|
||||
If `return_dict` is `True`, [`~pipelines.ltx.LTXPipelineOutput`] is returned, otherwise a `tuple` is
|
||||
returned where the first element is the upsampled video.
|
||||
"""
|
||||
|
||||
self.check_inputs(
|
||||
video=video,
|
||||
height=height,
|
||||
width=width,
|
||||
latents=latents,
|
||||
tone_map_compression_ratio=tone_map_compression_ratio,
|
||||
)
|
||||
|
||||
if video is not None:
|
||||
# Batched video input is not yet tested/supported. TODO: take a look later
|
||||
batch_size = 1
|
||||
else:
|
||||
batch_size = latents.shape[0]
|
||||
device = self._execution_device
|
||||
|
||||
if video is not None:
|
||||
num_frames = len(video)
|
||||
if num_frames % self.vae_temporal_compression_ratio != 1:
|
||||
num_frames = (
|
||||
num_frames // self.vae_temporal_compression_ratio * self.vae_temporal_compression_ratio + 1
|
||||
)
|
||||
video = video[:num_frames]
|
||||
logger.warning(
|
||||
f"Video length expected to be of the form `k * {self.vae_temporal_compression_ratio} + 1` but is {len(video)}. Truncating to {num_frames} frames."
|
||||
)
|
||||
video = self.video_processor.preprocess_video(video, height=height, width=width)
|
||||
video = video.to(device=device, dtype=torch.float32)
|
||||
|
||||
latents_supplied = latents is not None
|
||||
latents = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size,
|
||||
num_frames=num_frames,
|
||||
height=height,
|
||||
width=width,
|
||||
spatial_patch_size=spatial_patch_size,
|
||||
temporal_patch_size=temporal_patch_size,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
|
||||
if latents_supplied and latents_normalized:
|
||||
latents = self._denormalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
latents = latents.to(self.latent_upsampler.dtype)
|
||||
latents_upsampled = self.latent_upsampler(latents)
|
||||
|
||||
if adain_factor > 0.0:
|
||||
latents = self.adain_filter_latent(latents_upsampled, latents, adain_factor)
|
||||
else:
|
||||
latents = latents_upsampled
|
||||
|
||||
if tone_map_compression_ratio > 0.0:
|
||||
latents = self.tone_map_latents(latents, tone_map_compression_ratio)
|
||||
|
||||
if output_type == "latent":
|
||||
latents = self._normalize_latents(
|
||||
latents, self.vae.latents_mean, self.vae.latents_std, self.vae.config.scaling_factor
|
||||
)
|
||||
video = latents
|
||||
else:
|
||||
if not self.vae.config.timestep_conditioning:
|
||||
timestep = None
|
||||
else:
|
||||
noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=latents.dtype)
|
||||
if not isinstance(decode_timestep, list):
|
||||
decode_timestep = [decode_timestep] * batch_size
|
||||
if decode_noise_scale is None:
|
||||
decode_noise_scale = decode_timestep
|
||||
elif not isinstance(decode_noise_scale, list):
|
||||
decode_noise_scale = [decode_noise_scale] * batch_size
|
||||
|
||||
timestep = torch.tensor(decode_timestep, device=device, dtype=latents.dtype)
|
||||
decode_noise_scale = torch.tensor(decode_noise_scale, device=device, dtype=latents.dtype)[
|
||||
:, None, None, None, None
|
||||
]
|
||||
latents = (1 - decode_noise_scale) * latents + decode_noise_scale * noise
|
||||
|
||||
video = self.vae.decode(latents, timestep, return_dict=False)[0]
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (video,)
|
||||
|
||||
return LTXPipelineOutput(frames=video)
|
||||
23
src/diffusers/pipelines/ltx2/pipeline_output.py
Normal file
23
src/diffusers/pipelines/ltx2/pipeline_output.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class LTX2PipelineOutput(BaseOutput):
|
||||
r"""
|
||||
Output class for LTX pipelines.
|
||||
|
||||
Args:
|
||||
frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
|
||||
List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
|
||||
denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
|
||||
`(batch_size, num_frames, channels, height, width)`.
|
||||
audio (`torch.Tensor`, `np.ndarray`):
|
||||
TODO
|
||||
"""
|
||||
|
||||
frames: torch.Tensor
|
||||
audio: torch.Tensor
|
||||
159
src/diffusers/pipelines/ltx2/vocoder.py
Normal file
159
src/diffusers/pipelines/ltx2/vocoder.py
Normal file
@@ -0,0 +1,159 @@
|
||||
import math
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
kernel_size: int = 3,
|
||||
stride: int = 1,
|
||||
dilations: Tuple[int, ...] = (1, 3, 5),
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
padding_mode: str = "same",
|
||||
):
|
||||
super().__init__()
|
||||
self.dilations = dilations
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=dilation, padding=padding_mode)
|
||||
for dilation in dilations
|
||||
]
|
||||
)
|
||||
|
||||
self.convs2 = nn.ModuleList(
|
||||
[
|
||||
nn.Conv1d(channels, channels, kernel_size, stride=stride, dilation=1, padding=padding_mode)
|
||||
for _ in range(len(dilations))
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for conv1, conv2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, negative_slope=self.negative_slope)
|
||||
xt = conv1(xt)
|
||||
xt = F.leaky_relu(xt, negative_slope=self.negative_slope)
|
||||
xt = conv2(xt)
|
||||
x = x + xt
|
||||
return x
|
||||
|
||||
|
||||
class LTX2Vocoder(ModelMixin, ConfigMixin):
|
||||
r"""
|
||||
LTX 2.0 vocoder for converting generated mel spectrograms back to audio waveforms.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 128,
|
||||
hidden_channels: int = 1024,
|
||||
out_channels: int = 2,
|
||||
upsample_kernel_sizes: List[int] = [16, 15, 8, 4, 4],
|
||||
upsample_factors: List[int] = [6, 5, 2, 2, 2],
|
||||
resnet_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resnet_dilations: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
leaky_relu_negative_slope: float = 0.1,
|
||||
output_sampling_rate: int = 24000,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_upsample_layers = len(upsample_kernel_sizes)
|
||||
self.resnets_per_upsample = len(resnet_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.total_upsample_factor = math.prod(upsample_factors)
|
||||
self.negative_slope = leaky_relu_negative_slope
|
||||
|
||||
if self.num_upsample_layers != len(upsample_factors):
|
||||
raise ValueError(
|
||||
f"`upsample_kernel_sizes` and `upsample_factors` should be lists of the same length but are length"
|
||||
f" {self.num_upsample_layers} and {len(upsample_factors)}, respectively."
|
||||
)
|
||||
|
||||
if self.resnets_per_upsample != len(resnet_dilations):
|
||||
raise ValueError(
|
||||
f"`resnet_kernel_sizes` and `resnet_dilations` should be lists of the same length but are length"
|
||||
f" {len(self.resnets_per_upsample)} and {len(resnet_dilations)}, respectively."
|
||||
)
|
||||
|
||||
self.conv_in = nn.Conv1d(in_channels, hidden_channels, kernel_size=7, stride=1, padding=3)
|
||||
|
||||
self.upsamplers = nn.ModuleList()
|
||||
self.resnets = nn.ModuleList()
|
||||
input_channels = hidden_channels
|
||||
for i, (stride, kernel_size) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
|
||||
output_channels = input_channels // 2
|
||||
self.upsamplers.append(
|
||||
nn.ConvTranspose1d(
|
||||
input_channels, # hidden_channels // (2 ** i)
|
||||
output_channels, # hidden_channels // (2 ** (i + 1))
|
||||
kernel_size,
|
||||
stride=stride,
|
||||
padding=(kernel_size - stride) // 2,
|
||||
)
|
||||
)
|
||||
|
||||
for kernel_size, dilations in zip(resnet_kernel_sizes, resnet_dilations):
|
||||
self.resnets.append(
|
||||
ResBlock(
|
||||
output_channels,
|
||||
kernel_size,
|
||||
dilations=dilations,
|
||||
leaky_relu_negative_slope=leaky_relu_negative_slope,
|
||||
)
|
||||
)
|
||||
input_channels = output_channels
|
||||
|
||||
self.conv_out = nn.Conv1d(output_channels, out_channels, 7, stride=1, padding=3)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, time_last: bool = False) -> torch.Tensor:
|
||||
r"""
|
||||
Forward pass of the vocoder.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Input Mel spectrogram tensor of shape `(batch_size, num_channels, time, num_mel_bins)` if `time_last`
|
||||
is `False` (the default) or shape `(batch_size, num_channels, num_mel_bins, time)` if `time_last` is
|
||||
`True`.
|
||||
time_last (`bool`, *optional*, defaults to `False`):
|
||||
Whether the last dimension of the input is the time/frame dimension or the Mel bins dimension.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Audio waveform tensor of shape (batch_size, out_channels, audio_length)
|
||||
"""
|
||||
|
||||
# Ensure that the time/frame dimension is last
|
||||
if not time_last:
|
||||
hidden_states = hidden_states.transpose(2, 3)
|
||||
# Combine channels and frequency (mel bins) dimensions
|
||||
hidden_states = hidden_states.flatten(1, 2)
|
||||
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
for i in range(self.num_upsample_layers):
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=self.negative_slope)
|
||||
hidden_states = self.upsamplers[i](hidden_states)
|
||||
|
||||
# Run all resnets in parallel on hidden_states
|
||||
start = i * self.resnets_per_upsample
|
||||
end = (i + 1) * self.resnets_per_upsample
|
||||
resnet_outputs = torch.stack([self.resnets[j](hidden_states) for j in range(start, end)], dim=0)
|
||||
|
||||
hidden_states = torch.mean(resnet_outputs, dim=0)
|
||||
|
||||
# NOTE: unlike the first leaky ReLU, this leaky ReLU is set to use the default F.leaky_relu negative slope of
|
||||
# 0.01 (whereas the others usually use a slope of 0.1). Not sure if this is intended
|
||||
hidden_states = F.leaky_relu(hidden_states, negative_slope=0.01)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
hidden_states = torch.tanh(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
@@ -801,12 +801,6 @@ def load_sub_model(
|
||||
# add kwargs to loading method
|
||||
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
||||
loading_kwargs = {}
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
loading_kwargs["provider_options"] = provider_options
|
||||
|
||||
is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)
|
||||
|
||||
@@ -821,6 +815,17 @@ def load_sub_model(
|
||||
and transformers_version >= version.parse("4.20.0")
|
||||
)
|
||||
|
||||
# For transformers models >= 4.56.0, use 'dtype' instead of 'torch_dtype' to avoid deprecation warnings
|
||||
if issubclass(class_obj, torch.nn.Module):
|
||||
if is_transformers_model and transformers_version >= version.parse("4.56.0"):
|
||||
loading_kwargs["dtype"] = torch_dtype
|
||||
else:
|
||||
loading_kwargs["torch_dtype"] = torch_dtype
|
||||
if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
|
||||
loading_kwargs["provider"] = provider
|
||||
loading_kwargs["sess_options"] = sess_options
|
||||
loading_kwargs["provider_options"] = provider_options
|
||||
|
||||
# When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
|
||||
# To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
|
||||
# This makes sure that the weights won't be initialized which significantly speeds up loading.
|
||||
|
||||
@@ -67,6 +67,7 @@ from ..utils import (
|
||||
logging,
|
||||
numpy_to_pil,
|
||||
)
|
||||
from ..utils.distributed_utils import is_torch_dist_rank_zero
|
||||
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import empty_device_cache, get_device, is_compiled_module
|
||||
|
||||
@@ -982,7 +983,11 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# 7. Load each module in the pipeline
|
||||
current_device_map = None
|
||||
_maybe_warn_for_wrong_component_in_quant_config(init_dict, quantization_config)
|
||||
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
||||
logging_tqdm_kwargs = {"desc": "Loading pipeline components..."}
|
||||
if not is_torch_dist_rank_zero():
|
||||
logging_tqdm_kwargs["disable"] = True
|
||||
|
||||
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), **logging_tqdm_kwargs):
|
||||
# 7.1 device_map shenanigans
|
||||
if final_device_map is not None:
|
||||
if isinstance(final_device_map, dict) and len(final_device_map) > 0:
|
||||
@@ -1908,10 +1913,14 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
f"`self._progress_bar_config` should be of type `dict`, but is {type(self._progress_bar_config)}."
|
||||
)
|
||||
|
||||
progress_bar_config = dict(self._progress_bar_config)
|
||||
if "disable" not in progress_bar_config:
|
||||
progress_bar_config["disable"] = not is_torch_dist_rank_zero()
|
||||
|
||||
if iterable is not None:
|
||||
return tqdm(iterable, **self._progress_bar_config)
|
||||
return tqdm(iterable, **progress_bar_config)
|
||||
elif total is not None:
|
||||
return tqdm(total=total, **self._progress_bar_config)
|
||||
return tqdm(total=total, **progress_bar_config)
|
||||
else:
|
||||
raise ValueError("Either `total` or `iterable` has to be defined.")
|
||||
|
||||
|
||||
@@ -545,22 +545,24 @@ class SkyReelsV2Pipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixin):
|
||||
latent_model_input = latents.to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
@@ -887,25 +887,28 @@ class SkyReelsV2DiffusionForcingPipeline(DiffusionPipeline, SkyReelsV2LoraLoader
|
||||
)
|
||||
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
update_mask_i = step_update_mask[i]
|
||||
|
||||
@@ -966,25 +966,28 @@ class SkyReelsV2DiffusionForcingImageToVideoPipeline(DiffusionPipeline, SkyReels
|
||||
)
|
||||
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
update_mask_i = step_update_mask[i]
|
||||
|
||||
@@ -974,25 +974,28 @@ class SkyReelsV2DiffusionForcingVideoToVideoPipeline(DiffusionPipeline, SkyReels
|
||||
)
|
||||
timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
enable_diffusion_forcing=True,
|
||||
fps=fps_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
update_mask_i = step_update_mask[i]
|
||||
|
||||
@@ -678,24 +678,26 @@ class SkyReelsV2ImageToVideoPipeline(DiffusionPipeline, SkyReelsV2LoraLoaderMixi
|
||||
latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
|
||||
timestep = t.expand(latents.shape[0])
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond = self.transformer(
|
||||
with self.transformer.cache_context("cond"):
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
with self.transformer.cache_context("uncond"):
|
||||
noise_uncond = self.transformer(
|
||||
hidden_states=latent_model_input,
|
||||
timestep=timestep,
|
||||
encoder_hidden_states=negative_prompt_embeds,
|
||||
encoder_hidden_states_image=image_embeds,
|
||||
attention_kwargs=attention_kwargs,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
|
||||
@@ -457,7 +457,7 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
- Shorthands: `float8wo`, `float8wo_e5m2`, `float8wo_e4m3`, `float8dq`, `float8dq_e4m3`,
|
||||
`float8_e4m3_tensor`, `float8_e4m3_row`,
|
||||
|
||||
- **Floating point X-bit quantization:**
|
||||
- **Floating point X-bit quantization:** (in torchao <= 0.14.1, not supported in torchao >= 0.15.0)
|
||||
- Full function names: `fpx_weight_only`
|
||||
- Shorthands: `fpX_eAwB`, where `X` is the number of bits (between `1` to `7`), `A` is the number
|
||||
of exponent bits and `B` is the number of mantissa bits. The constraint of `X == A + B + 1` must
|
||||
@@ -531,12 +531,18 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
||||
|
||||
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
|
||||
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
|
||||
if is_floating_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
is_floatx_quant_type = self.quant_type.startswith("fp")
|
||||
is_float_quant_type = self.quant_type.startswith("float") or is_floatx_quant_type
|
||||
if is_float_quant_type and not self._is_xpu_or_cuda_capability_atleast_8_9():
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
|
||||
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
|
||||
)
|
||||
elif is_floatx_quant_type and not is_torchao_version("<=", "0.14.1"):
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is only supported in torchao <= 0.14.1. "
|
||||
f"Please downgrade to torchao <= 0.14.1 to use this quantization type."
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
|
||||
@@ -622,7 +628,6 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
float8_dynamic_activation_float8_weight,
|
||||
float8_static_activation_float8_weight,
|
||||
float8_weight_only,
|
||||
fpx_weight_only,
|
||||
int4_weight_only,
|
||||
int8_dynamic_activation_int4_weight,
|
||||
int8_dynamic_activation_int8_weight,
|
||||
@@ -630,6 +635,8 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
uintx_weight_only,
|
||||
)
|
||||
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
from torchao.quantization import fpx_weight_only
|
||||
# TODO(aryan): Add a note on how to use PerAxis and PerGroup observers
|
||||
from torchao.quantization.observer import PerRow, PerTensor
|
||||
|
||||
@@ -650,18 +657,21 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
return types
|
||||
|
||||
def generate_fpx_quantization_types(bits: int):
|
||||
types = {}
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
types = {}
|
||||
|
||||
for ebits in range(1, bits):
|
||||
mbits = bits - ebits - 1
|
||||
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
|
||||
for ebits in range(1, bits):
|
||||
mbits = bits - ebits - 1
|
||||
types[f"fp{bits}_e{ebits}m{mbits}"] = partial(fpx_weight_only, ebits=ebits, mbits=mbits)
|
||||
|
||||
non_sign_bits = bits - 1
|
||||
default_ebits = (non_sign_bits + 1) // 2
|
||||
default_mbits = non_sign_bits - default_ebits
|
||||
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
|
||||
non_sign_bits = bits - 1
|
||||
default_ebits = (non_sign_bits + 1) // 2
|
||||
default_mbits = non_sign_bits - default_ebits
|
||||
types[f"fp{bits}"] = partial(fpx_weight_only, ebits=default_ebits, mbits=default_mbits)
|
||||
|
||||
return types
|
||||
return types
|
||||
else:
|
||||
raise ValueError("Floating point X-bit quantization is not supported in torchao >= 0.15.0")
|
||||
|
||||
INT4_QUANTIZATION_TYPES = {
|
||||
# int4 weight + bfloat16/float16 activation
|
||||
@@ -710,15 +720,15 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
**generate_float8dq_types(torch.float8_e4m3fn),
|
||||
# float8 weight + float8 activation (static)
|
||||
"float8_static_activation_float8_weight": float8_static_activation_float8_weight,
|
||||
# For fpx, only x <= 8 is supported by default. Other dtypes can be explored by users directly
|
||||
# fpx weight + bfloat16/float16 activation
|
||||
**generate_fpx_quantization_types(3),
|
||||
**generate_fpx_quantization_types(4),
|
||||
**generate_fpx_quantization_types(5),
|
||||
**generate_fpx_quantization_types(6),
|
||||
**generate_fpx_quantization_types(7),
|
||||
}
|
||||
|
||||
if is_torchao_version("<=", "0.14.1"):
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(3))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(4))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(5))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(6))
|
||||
FLOATX_QUANTIZATION_TYPES.update(generate_fpx_quantization_types(7))
|
||||
|
||||
UINTX_QUANTIZATION_DTYPES = {
|
||||
"uintx_weight_only": uintx_weight_only,
|
||||
"uint1wo": partial(uintx_weight_only, dtype=torch.uint1),
|
||||
|
||||
@@ -66,6 +66,7 @@ else:
|
||||
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
|
||||
_import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
|
||||
_import_structure["scheduling_lcm"] = ["LCMScheduler"]
|
||||
_import_structure["scheduling_ltx_euler_ancestral_rf"] = ["LTXEulerAncestralRFScheduler"]
|
||||
_import_structure["scheduling_pndm"] = ["PNDMScheduler"]
|
||||
_import_structure["scheduling_repaint"] = ["RePaintScheduler"]
|
||||
_import_structure["scheduling_sasolver"] = ["SASolverScheduler"]
|
||||
@@ -168,6 +169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
|
||||
from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
|
||||
from .scheduling_lcm import LCMScheduler
|
||||
from .scheduling_ltx_euler_ancestral_rf import LTXEulerAncestralRFScheduler
|
||||
from .scheduling_pndm import PNDMScheduler
|
||||
from .scheduling_repaint import RePaintScheduler
|
||||
from .scheduling_sasolver import SASolverScheduler
|
||||
|
||||
@@ -71,6 +71,22 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
|
||||
|
||||
|
||||
class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
A scheduler for the consistency decoder used in Stable Diffusion pipelines.
|
||||
|
||||
This scheduler implements a two-step denoising process using consistency models for decoding latent representations
|
||||
into images.
|
||||
|
||||
This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
|
||||
methods the library implements for all schedulers such as loading and saving.
|
||||
|
||||
Args:
|
||||
num_train_timesteps (`int`, *optional*, defaults to `1024`):
|
||||
The number of diffusion steps to train the model.
|
||||
sigma_data (`float`, *optional*, defaults to `0.5`):
|
||||
The standard deviation of the data distribution. Used for computing the skip and output scaling factors.
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
@@ -78,7 +94,7 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
num_train_timesteps: int = 1024,
|
||||
sigma_data: float = 0.5,
|
||||
):
|
||||
) -> None:
|
||||
betas = betas_for_alpha_bar(num_train_timesteps)
|
||||
|
||||
alphas = 1.0 - betas
|
||||
@@ -98,8 +114,18 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
):
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`, *optional*):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model. Currently, only
|
||||
`2` inference steps are supported.
|
||||
device (`str` or `torch.device`, *optional*):
|
||||
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
|
||||
"""
|
||||
if num_inference_steps != 2:
|
||||
raise ValueError("Currently more than 2 inference steps are not supported.")
|
||||
|
||||
@@ -111,7 +137,15 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.c_in = self.c_in.to(device)
|
||||
|
||||
@property
|
||||
def init_noise_sigma(self):
|
||||
def init_noise_sigma(self) -> torch.Tensor:
|
||||
"""
|
||||
Return the standard deviation of the initial noise distribution.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
The initial noise sigma value from the precomputed `sqrt_one_minus_alphas_cumprod` at the first
|
||||
timestep.
|
||||
"""
|
||||
return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
|
||||
|
||||
def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
|
||||
@@ -146,20 +180,20 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
model_output (`torch.Tensor`):
|
||||
The direct output from the learned diffusion model.
|
||||
timestep (`float`):
|
||||
timestep (`float` or `torch.Tensor`):
|
||||
The current timestep in the diffusion chain.
|
||||
sample (`torch.Tensor`):
|
||||
A current instance of a sample created by the diffusion process.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A random number generator.
|
||||
A random number generator for reproducibility.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a
|
||||
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`.
|
||||
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] or `tuple`:
|
||||
If return_dict is `True`,
|
||||
[`~schedulers.scheduling_consistency_models.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
|
||||
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] or `tuple`:
|
||||
If `return_dict` is `True`,
|
||||
[`~schedulers.scheduling_consistency_decoder.ConsistencyDecoderSchedulerOutput`] is returned, otherwise
|
||||
a tuple is returned where the first element is the sample tensor.
|
||||
"""
|
||||
x_0 = self.c_out[timestep] * model_output + self.c_skip[timestep] * sample
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user