mirror of
https://github.com/huggingface/diffusers.git
synced 2025-12-12 23:44:30 +08:00
Compare commits
211 Commits
kontext-re
...
test-clean
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
de7cdf6287 | ||
|
|
73c5fe8bb1 | ||
|
|
595581d6ba | ||
|
|
d27b65411e | ||
|
|
cb9dca5523 | ||
|
|
79166dcb47 | ||
|
|
01240fecb0 | ||
|
|
ce338d4e4a | ||
|
|
f95c320467 | ||
|
|
59abd9514b | ||
|
|
5f3ebef0d7 | ||
|
|
e6ffde2936 | ||
|
|
04171c7345 | ||
|
|
be5e10ae61 | ||
|
|
a2da0004ee | ||
|
|
863c7df543 | ||
|
|
bc55b631fd | ||
|
|
e0083b29d5 | ||
|
|
6521f599b2 | ||
|
|
0fcce2acd8 | ||
|
|
15d50f16f2 | ||
|
|
ceeb3c1da3 | ||
|
|
2c30287958 | ||
|
|
0fcdd699cf | ||
|
|
5af003a9e1 | ||
|
|
179d6d958b | ||
|
|
229c4b355c | ||
|
|
0a4819a755 | ||
|
|
7cea9a3bb0 | ||
|
|
23de59e21a | ||
|
|
4f8b6f5a15 | ||
|
|
63e94cbc61 | ||
|
|
2c66fb3a85 | ||
|
|
284f827d6c | ||
|
|
b750c69859 | ||
|
|
13c51bb038 | ||
|
|
425a715e35 | ||
|
|
2527917528 | ||
|
|
e6639fef70 | ||
|
|
8c938fb410 | ||
|
|
f864a9a352 | ||
|
|
d6fa3298fa | ||
|
|
6f1d6694df | ||
|
|
0e95aa853e | ||
|
|
5ef74fd5f6 | ||
|
|
64a9210315 | ||
|
|
d31b8cea3e | ||
|
|
62e847db5f | ||
|
|
470458623e | ||
|
|
a79c3af6bb | ||
|
|
3f3f0c16a6 | ||
|
|
f3e1310469 | ||
|
|
87f83d3dd9 | ||
|
|
3e46c86a93 | ||
|
|
8cb5b084b5 | ||
|
|
13fe248152 | ||
|
|
2e2024152c | ||
|
|
1987c07899 | ||
|
|
4543d216ec | ||
|
|
b5db8aaa6f | ||
|
|
98ea5c9e86 | ||
|
|
f27fbceba1 | ||
|
|
4b12a60c93 | ||
|
|
abf28d55fb | ||
|
|
f064b3bf73 | ||
|
|
db4b54cfab | ||
|
|
0138e176ac | ||
|
|
3b079ec3fa | ||
|
|
bc34fa8386 | ||
|
|
bbd9340781 | ||
|
|
363737ec4b | ||
|
|
c5849ba9d5 | ||
|
|
f09b1ccfae | ||
|
|
285f877620 | ||
|
|
c75b88f86f | ||
|
|
b43e703fae | ||
|
|
9fae3828a7 | ||
|
|
3a3441cb45 | ||
|
|
fdd2bedae9 | ||
|
|
fedaa00bd5 | ||
|
|
8c680bc0b4 | ||
|
|
92b6b43805 | ||
|
|
49ea4d1bf5 | ||
|
|
58dbe0c29e | ||
|
|
9aaec5b9bc | ||
|
|
93760b1888 | ||
|
|
75540f42ee | ||
|
|
b543bcc661 | ||
|
|
885a596696 | ||
|
|
655512e2cf | ||
|
|
05e7a854d0 | ||
|
|
76ec3d1fee | ||
|
|
cdaf84a708 | ||
|
|
e8e44a510c | ||
|
|
f63d62e091 | ||
|
|
21543de571 | ||
|
|
7608d2eb9e | ||
|
|
449f299c63 | ||
|
|
84f4b27dfa | ||
|
|
9abac85f77 | ||
|
|
61772f0994 | ||
|
|
b92cda25e2 | ||
|
|
7492e331b4 | ||
|
|
ab6d63407a | ||
|
|
da4242d467 | ||
|
|
129d658da7 | ||
|
|
75e62385f5 | ||
|
|
a33206d22b | ||
|
|
a82e211f89 | ||
|
|
f3453f05ff | ||
|
|
c437ae72c6 | ||
|
|
9530245e17 | ||
|
|
74b908b7e2 | ||
|
|
7d2a633e02 | ||
|
|
cb328d3ff9 | ||
|
|
8c038f0e62 | ||
|
|
5917d7039f | ||
|
|
c0327e493e | ||
|
|
174628edf4 | ||
|
|
1c9f0a83c9 | ||
|
|
cdaaa40d31 | ||
|
|
ffbaa890ba | ||
|
|
e49413d87d | ||
|
|
48e4ff5c05 | ||
|
|
7c78fb1aad | ||
|
|
bb4044362e | ||
|
|
1ae591e817 | ||
|
|
42c06e90f4 | ||
|
|
085ade03be | ||
|
|
78d2454c7c | ||
|
|
19545fd3e1 | ||
|
|
d12531ddf7 | ||
|
|
4751d456f2 | ||
|
|
083479c365 | ||
|
|
04c16d0a56 | ||
|
|
9e58856b7a | ||
|
|
45392cce11 | ||
|
|
8913d59bf3 | ||
|
|
5a8c1b5f19 | ||
|
|
7ad01a6350 | ||
|
|
a8e853b791 | ||
|
|
6a509ba862 | ||
|
|
96795afc72 | ||
|
|
12650e1393 | ||
|
|
addaad013c | ||
|
|
485f8d1758 | ||
|
|
cff0fd6260 | ||
|
|
8ddb20bfb8 | ||
|
|
e5089d702b | ||
|
|
2c3e4eafa8 | ||
|
|
c7020df2cf | ||
|
|
4bed3e306e | ||
|
|
00a3bc9d6c | ||
|
|
ccb35acd81 | ||
|
|
00cae4e857 | ||
|
|
b3fb4188f5 | ||
|
|
71df1581f7 | ||
|
|
d046cf7d35 | ||
|
|
68a5185c86 | ||
|
|
6e2fe26bfd | ||
|
|
77b5fa59c5 | ||
|
|
a226920b52 | ||
|
|
7007f72409 | ||
|
|
a6804de4a2 | ||
|
|
7f897a9fc4 | ||
|
|
0966663d2a | ||
|
|
fb78f4f12d | ||
|
|
2220af6940 | ||
|
|
7a34832d52 | ||
|
|
e973de64f9 | ||
|
|
db94ca882d | ||
|
|
6985906a2e | ||
|
|
54f410db6c | ||
|
|
c12a05b9c1 | ||
|
|
2e0f5c86cc | ||
|
|
1d63306295 | ||
|
|
6c93626f6f | ||
|
|
72c5bf07c8 | ||
|
|
ed59f90f15 | ||
|
|
a09ca7f27e | ||
|
|
8c02572e16 | ||
|
|
27dde51de8 | ||
|
|
10d4a775f1 | ||
|
|
72d9a81d99 | ||
|
|
4fa85c7963 | ||
|
|
806e8e66fb | ||
|
|
0b90051db8 | ||
|
|
b305c779b2 | ||
|
|
2b3cd2d39c | ||
|
|
bc3d1c9ee6 | ||
|
|
e50d614636 | ||
|
|
a8df0f1ffb | ||
|
|
ace53e2d2f | ||
|
|
ffc2992fc2 | ||
|
|
c70a285c2c | ||
|
|
8b811feece | ||
|
|
37e8dc7a59 | ||
|
|
024a9f5de3 | ||
|
|
005195c23e | ||
|
|
6742f160df | ||
|
|
540d303250 | ||
|
|
f1b3036ca1 | ||
|
|
46ec1743a2 | ||
|
|
70272b1108 | ||
|
|
2b6dcbfa1d | ||
|
|
af9572d759 | ||
|
|
ddea157979 | ||
|
|
ad3f9a26c0 | ||
|
|
e8d0980f9f | ||
|
|
52a7f1cb97 | ||
|
|
33f85fadf6 |
41
.github/workflows/benchmark.yml
vendored
41
.github/workflows/benchmark.yml
vendored
@@ -11,17 +11,18 @@ env:
|
||||
HF_HOME: /mnt/cache
|
||||
OMP_NUM_THREADS: 8
|
||||
MKL_NUM_THREADS: 8
|
||||
BASE_PATH: benchmark_outputs
|
||||
|
||||
jobs:
|
||||
torch_pipelines_cuda_benchmark_tests:
|
||||
torch_models_cuda_benchmark_tests:
|
||||
env:
|
||||
SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL_BENCHMARK }}
|
||||
name: Torch Core Pipelines CUDA Benchmarking Tests
|
||||
name: Torch Core Models CUDA Benchmarking Tests
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 1
|
||||
runs-on:
|
||||
group: aws-g6-4xlarge-plus
|
||||
group: aws-g6e-4xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-pytorch-cuda
|
||||
options: --shm-size "16gb" --ipc host --gpus 0
|
||||
@@ -35,27 +36,47 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
apt update
|
||||
apt install -y libpq-dev postgresql-client
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
python -m uv pip install pandas peft
|
||||
python -m uv pip uninstall transformers && python -m uv pip install transformers==4.48.0
|
||||
python -m uv pip install -r benchmarks/requirements.txt
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
- name: Diffusers Benchmarking
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
|
||||
BASE_PATH: benchmark_outputs
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
export TOTAL_GPU_MEMORY=$(python -c "import torch; print(torch.cuda.get_device_properties(0).total_memory / (1024**3))")
|
||||
cd benchmarks && mkdir ${BASE_PATH} && python run_all.py && python push_results.py
|
||||
cd benchmarks && python run_all.py
|
||||
|
||||
- name: Push results to the Hub
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_BOT_TOKEN }}
|
||||
run: |
|
||||
cd benchmarks && python push_results.py
|
||||
mkdir $BASE_PATH && cp *.csv $BASE_PATH
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: benchmark_test_reports
|
||||
path: benchmarks/benchmark_outputs
|
||||
path: benchmarks/${{ env.BASE_PATH }}
|
||||
|
||||
# TODO: enable this once the connection problem has been resolved.
|
||||
- name: Update benchmarking results to DB
|
||||
env:
|
||||
PGDATABASE: metrics
|
||||
PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
|
||||
PGUSER: transformers_benchmarks
|
||||
PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
|
||||
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
||||
run: |
|
||||
git config --global --add safe.directory /__w/diffusers/diffusers
|
||||
commit_id=$GITHUB_SHA
|
||||
commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
|
||||
cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
|
||||
|
||||
- name: Report success status
|
||||
if: ${{ success() }}
|
||||
|
||||
4
.github/workflows/build_docker_images.yml
vendored
4
.github/workflows/build_docker_images.yml
vendored
@@ -75,10 +75,6 @@ jobs:
|
||||
- diffusers-pytorch-cuda
|
||||
- diffusers-pytorch-xformers-cuda
|
||||
- diffusers-pytorch-minimum-cuda
|
||||
- diffusers-flax-cpu
|
||||
- diffusers-flax-tpu
|
||||
- diffusers-onnxruntime-cpu
|
||||
- diffusers-onnxruntime-cuda
|
||||
- diffusers-doc-builder
|
||||
|
||||
steps:
|
||||
|
||||
106
.github/workflows/nightly_tests.yml
vendored
106
.github/workflows/nightly_tests.yml
vendored
@@ -248,7 +248,7 @@ jobs:
|
||||
BIG_GPU_MEMORY: 40
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-m "big_gpu_with_torch_cuda" \
|
||||
-m "big_accelerator" \
|
||||
--make-reports=tests_big_gpu_torch_cuda \
|
||||
--report-log=tests_big_gpu_torch_cuda.log \
|
||||
tests/
|
||||
@@ -321,55 +321,6 @@ jobs:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_nightly_onnx_tests:
|
||||
name: Nightly ONNXRuntime CUDA tests on Ubuntu
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --gpus 0 --shm-size "16gb" --ipc host
|
||||
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: NVIDIA-SMI
|
||||
run: nvidia-smi
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run Nightly ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
--report-log=tests_onnx_cuda.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: tests_onnx_cuda_reports
|
||||
path: reports
|
||||
|
||||
run_nightly_quantization_tests:
|
||||
name: Torch quantization nightly tests
|
||||
strategy:
|
||||
@@ -485,57 +436,6 @@ jobs:
|
||||
name: torch_cuda_pipeline_level_quant_reports
|
||||
path: reports
|
||||
|
||||
run_flax_tpu_tests:
|
||||
name: Nightly Flax TPU Tests
|
||||
runs-on:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
if: github.event_name == 'schedule'
|
||||
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
python -m uv pip install pytest-reportlog
|
||||
|
||||
- name: Environment
|
||||
run: python utils/print_env.py
|
||||
|
||||
- name: Run nightly Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
--report-log=tests_flax_tpu.log \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
generate_consolidated_report:
|
||||
name: Generate Consolidated Test Report
|
||||
needs: [
|
||||
@@ -545,9 +445,9 @@ jobs:
|
||||
run_big_gpu_torch_tests,
|
||||
run_nightly_quantization_tests,
|
||||
run_nightly_pipeline_level_quantization_tests,
|
||||
run_nightly_onnx_tests,
|
||||
# run_nightly_onnx_tests,
|
||||
torch_minimum_version_cuda_tests,
|
||||
run_flax_tpu_tests
|
||||
# run_flax_tpu_tests
|
||||
]
|
||||
if: always()
|
||||
runs-on:
|
||||
|
||||
14
.github/workflows/pr_tests.yml
vendored
14
.github/workflows/pr_tests.yml
vendored
@@ -87,11 +87,6 @@ jobs:
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu_models_schedulers
|
||||
- name: Fast Flax CPU tests
|
||||
framework: flax
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: PyTorch Example CPU tests
|
||||
framework: pytorch_examples
|
||||
runner: aws-general-8-plus
|
||||
@@ -147,15 +142,6 @@ jobs:
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/models tests/schedulers tests/others
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests
|
||||
|
||||
- name: Run example PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_examples' }}
|
||||
run: |
|
||||
|
||||
96
.github/workflows/push_tests.yml
vendored
96
.github/workflows/push_tests.yml
vendored
@@ -159,102 +159,6 @@ jobs:
|
||||
name: torch_cuda_test_reports_${{ matrix.module }}
|
||||
path: reports
|
||||
|
||||
flax_tpu_tests:
|
||||
name: Flax TPU Tests
|
||||
runs-on:
|
||||
group: gcp-ct5lp-hightpu-8t
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host --privileged ${{ vars.V5_LITEPOD_8_ENV}} -v /mnt/hf_cache:/mnt/hf_cache
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
onnx_cuda_tests:
|
||||
name: ONNX CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: onnx_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_torch_compile_tests:
|
||||
name: PyTorch Compile CUDA tests
|
||||
|
||||
|
||||
28
.github/workflows/push_tests_fast.yml
vendored
28
.github/workflows/push_tests_fast.yml
vendored
@@ -33,16 +33,6 @@ jobs:
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_cpu
|
||||
- name: Fast Flax CPU tests on Ubuntu
|
||||
framework: flax
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-flax-cpu
|
||||
report: flax_cpu
|
||||
- name: Fast ONNXRuntime CPU tests on Ubuntu
|
||||
framework: onnxruntime
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-onnxruntime-cpu
|
||||
report: onnx_cpu
|
||||
- name: PyTorch Example CPU tests on Ubuntu
|
||||
framework: pytorch_examples
|
||||
runner: aws-general-8-plus
|
||||
@@ -87,24 +77,6 @@ jobs:
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast Flax TPU tests
|
||||
if: ${{ matrix.config.framework == 'flax' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run fast ONNXRuntime CPU tests
|
||||
if: ${{ matrix.config.framework == 'onnxruntime' }}
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_${{ matrix.config.report }} \
|
||||
tests/
|
||||
|
||||
- name: Run example PyTorch CPU tests
|
||||
if: ${{ matrix.config.framework == 'pytorch_examples' }}
|
||||
run: |
|
||||
|
||||
7
.github/workflows/push_tests_mps.yml
vendored
7
.github/workflows/push_tests_mps.yml
vendored
@@ -1,12 +1,7 @@
|
||||
name: Fast mps tests on main
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
paths:
|
||||
- "src/diffusers/**.py"
|
||||
- "tests/**.py"
|
||||
workflow_dispatch:
|
||||
|
||||
env:
|
||||
DIFFUSERS_IS_CI: yes
|
||||
|
||||
95
.github/workflows/release_tests_fast.yml
vendored
95
.github/workflows/release_tests_fast.yml
vendored
@@ -213,101 +213,6 @@ jobs:
|
||||
with:
|
||||
name: torch_minimum_version_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
flax_tpu_tests:
|
||||
name: Flax TPU Tests
|
||||
runs-on: docker-tpu
|
||||
container:
|
||||
image: diffusers/diffusers-flax-tpu
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --privileged
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow Flax TPU tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 0 \
|
||||
-s -v -k "Flax" \
|
||||
--make-reports=tests_flax_tpu \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_flax_tpu_stats.txt
|
||||
cat reports/tests_flax_tpu_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: flax_tpu_test_reports
|
||||
path: reports
|
||||
|
||||
onnx_cuda_tests:
|
||||
name: ONNX CUDA Tests
|
||||
runs-on:
|
||||
group: aws-g4dn-2xlarge
|
||||
container:
|
||||
image: diffusers/diffusers-onnxruntime-cuda
|
||||
options: --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/ --gpus 0
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
steps:
|
||||
- name: Checkout diffusers
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 2
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
|
||||
python -m uv pip install -e [quality,test]
|
||||
pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
- name: Run slow ONNXRuntime CUDA tests
|
||||
env:
|
||||
HF_TOKEN: ${{ secrets.DIFFUSERS_HF_HUB_READ_TOKEN }}
|
||||
run: |
|
||||
python -m pytest -n 1 --max-worker-restart=0 --dist=loadfile \
|
||||
-s -v -k "Onnx" \
|
||||
--make-reports=tests_onnx_cuda \
|
||||
tests/
|
||||
|
||||
- name: Failure short reports
|
||||
if: ${{ failure() }}
|
||||
run: |
|
||||
cat reports/tests_onnx_cuda_stats.txt
|
||||
cat reports/tests_onnx_cuda_failures_short.txt
|
||||
|
||||
- name: Test suite reports artifacts
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: onnx_cuda_test_reports
|
||||
path: reports
|
||||
|
||||
run_torch_compile_tests:
|
||||
name: PyTorch Compile CUDA tests
|
||||
|
||||
69
benchmarks/README.md
Normal file
69
benchmarks/README.md
Normal file
@@ -0,0 +1,69 @@
|
||||
# Diffusers Benchmarks
|
||||
|
||||
Welcome to Diffusers Benchmarks. These benchmarks are use to obtain latency and memory information of the most popular models across different scenarios such as:
|
||||
|
||||
* Base case i.e., when using `torch.bfloat16` and `torch.nn.functional.scaled_dot_product_attention`.
|
||||
* Base + `torch.compile()`
|
||||
* NF4 quantization
|
||||
* Layerwise upcasting
|
||||
|
||||
Instead of full diffusion pipelines, only the forward pass of the respective model classes (such as `FluxTransformer2DModel`) is tested with the real checkpoints (such as `"black-forest-labs/FLUX.1-dev"`).
|
||||
|
||||
The entrypoint to running all the currently available benchmarks is in `run_all.py`. However, one can run the individual benchmarks, too, e.g., `python benchmarking_flux.py`. It should produce a CSV file containing various information about the benchmarks run.
|
||||
|
||||
The benchmarks are run on a weekly basis and the CI is defined in [benchmark.yml](../.github/workflows/benchmark.yml).
|
||||
|
||||
## Running the benchmarks manually
|
||||
|
||||
First set up `torch` and install `diffusers` from the root of the directory:
|
||||
|
||||
```py
|
||||
pip install -e ".[quality,test]"
|
||||
```
|
||||
|
||||
Then make sure the other dependencies are installed:
|
||||
|
||||
```sh
|
||||
cd benchmarks/
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
We need to be authenticated to access some of the checkpoints used during benchmarking:
|
||||
|
||||
```sh
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
We use an L40 GPU with 128GB RAM to run the benchmark CI. As such, the benchmarks are configured to run on NVIDIA GPUs. So, make sure you have access to a similar machine (or modify the benchmarking scripts accordingly).
|
||||
|
||||
Then you can either launch the entire benchmarking suite by running:
|
||||
|
||||
```sh
|
||||
python run_all.py
|
||||
```
|
||||
|
||||
Or, you can run the individual benchmarks.
|
||||
|
||||
## Customizing the benchmarks
|
||||
|
||||
We define "scenarios" to cover the most common ways in which these models are used. You can
|
||||
define a new scenario, modifying an existing benchmark file:
|
||||
|
||||
```py
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-bnb-8bit",
|
||||
model_cls=FluxTransformer2DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
"quantization_config": BitsAndBytesConfig(load_in_8bit=True),
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=model_init_fn,
|
||||
)
|
||||
```
|
||||
|
||||
You can also configure a new model-level benchmark and add it to the existing suite. To do so, just defining a valid benchmarking file like `benchmarking_flux.py` should be enough.
|
||||
|
||||
Happy benchmarking 🧨
|
||||
@@ -1,346 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import (
|
||||
AutoPipelineForImage2Image,
|
||||
AutoPipelineForInpainting,
|
||||
AutoPipelineForText2Image,
|
||||
ControlNetModel,
|
||||
LCMScheduler,
|
||||
StableDiffusionAdapterPipeline,
|
||||
StableDiffusionControlNetPipeline,
|
||||
StableDiffusionXLAdapterPipeline,
|
||||
StableDiffusionXLControlNetPipeline,
|
||||
T2IAdapter,
|
||||
WuerstchenCombinedPipeline,
|
||||
)
|
||||
from diffusers.utils import load_image
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
|
||||
from utils import ( # noqa: E402
|
||||
BASE_PATH,
|
||||
PROMPT,
|
||||
BenchmarkInfo,
|
||||
benchmark_fn,
|
||||
bytes_to_giga_bytes,
|
||||
flush,
|
||||
generate_csv_dict,
|
||||
write_to_csv,
|
||||
)
|
||||
|
||||
|
||||
RESOLUTION_MAPPING = {
|
||||
"Lykon/DreamShaper": (512, 512),
|
||||
"lllyasviel/sd-controlnet-canny": (512, 512),
|
||||
"diffusers/controlnet-canny-sdxl-1.0": (1024, 1024),
|
||||
"TencentARC/t2iadapter_canny_sd14v1": (512, 512),
|
||||
"TencentARC/t2i-adapter-canny-sdxl-1.0": (1024, 1024),
|
||||
"stabilityai/stable-diffusion-2-1": (768, 768),
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": (1024, 1024),
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0": (1024, 1024),
|
||||
"stabilityai/sdxl-turbo": (512, 512),
|
||||
}
|
||||
|
||||
|
||||
class BaseBenchmak:
|
||||
pipeline_class = None
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
|
||||
def run_inference(self, args):
|
||||
raise NotImplementedError
|
||||
|
||||
def benchmark(self, args):
|
||||
raise NotImplementedError
|
||||
|
||||
def get_result_filepath(self, args):
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
name = (
|
||||
args.ckpt.replace("/", "_")
|
||||
+ "_"
|
||||
+ pipeline_class_name
|
||||
+ f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
|
||||
)
|
||||
filepath = os.path.join(BASE_PATH, name)
|
||||
return filepath
|
||||
|
||||
|
||||
class TextToImageBenchmark(BaseBenchmak):
|
||||
pipeline_class = AutoPipelineForText2Image
|
||||
|
||||
def __init__(self, args):
|
||||
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
if args.run_compile:
|
||||
if not isinstance(pipe, WuerstchenCombinedPipeline):
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
if hasattr(pipe, "movq") and getattr(pipe, "movq", None) is not None:
|
||||
pipe.movq.to(memory_format=torch.channels_last)
|
||||
pipe.movq = torch.compile(pipe.movq, mode="reduce-overhead", fullgraph=True)
|
||||
else:
|
||||
print("Run torch compile")
|
||||
pipe.decoder = torch.compile(pipe.decoder, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.vqgan = torch.compile(pipe.vqgan, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
def benchmark(self, args):
|
||||
flush()
|
||||
|
||||
print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
|
||||
|
||||
time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
|
||||
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
|
||||
benchmark_info = BenchmarkInfo(time=time, memory=memory)
|
||||
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
flush()
|
||||
csv_dict = generate_csv_dict(
|
||||
pipeline_cls=pipeline_class_name, ckpt=args.ckpt, args=args, benchmark_info=benchmark_info
|
||||
)
|
||||
filepath = self.get_result_filepath(args)
|
||||
write_to_csv(filepath, csv_dict)
|
||||
print(f"Logs written to: {filepath}")
|
||||
flush()
|
||||
|
||||
|
||||
class TurboTextToImageBenchmark(TextToImageBenchmark):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
guidance_scale=0.0,
|
||||
)
|
||||
|
||||
|
||||
class LCMLoRATextToImageBenchmark(TextToImageBenchmark):
|
||||
lora_id = "latent-consistency/lcm-lora-sdxl"
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.pipe.load_lora_weights(self.lora_id)
|
||||
self.pipe.fuse_lora()
|
||||
self.pipe.unload_lora_weights()
|
||||
self.pipe.scheduler = LCMScheduler.from_config(self.pipe.scheduler.config)
|
||||
|
||||
def get_result_filepath(self, args):
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
name = (
|
||||
self.lora_id.replace("/", "_")
|
||||
+ "_"
|
||||
+ pipeline_class_name
|
||||
+ f"-bs@{args.batch_size}-steps@{args.num_inference_steps}-mco@{args.model_cpu_offload}-compile@{args.run_compile}.csv"
|
||||
)
|
||||
filepath = os.path.join(BASE_PATH, name)
|
||||
return filepath
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
guidance_scale=1.0,
|
||||
)
|
||||
|
||||
def benchmark(self, args):
|
||||
flush()
|
||||
|
||||
print(f"[INFO] {self.pipe.__class__.__name__}: Running benchmark with: {vars(args)}\n")
|
||||
|
||||
time = benchmark_fn(self.run_inference, self.pipe, args) # in seconds.
|
||||
memory = bytes_to_giga_bytes(torch.cuda.max_memory_allocated()) # in GBs.
|
||||
benchmark_info = BenchmarkInfo(time=time, memory=memory)
|
||||
|
||||
pipeline_class_name = str(self.pipe.__class__.__name__)
|
||||
flush()
|
||||
csv_dict = generate_csv_dict(
|
||||
pipeline_cls=pipeline_class_name, ckpt=self.lora_id, args=args, benchmark_info=benchmark_info
|
||||
)
|
||||
filepath = self.get_result_filepath(args)
|
||||
write_to_csv(filepath, csv_dict)
|
||||
print(f"Logs written to: {filepath}")
|
||||
flush()
|
||||
|
||||
|
||||
class ImageToImageBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = AutoPipelineForImage2Image
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/1665_Girl_with_a_Pearl_Earring.jpg"
|
||||
image = load_image(url).convert("RGB")
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class TurboImageToImageBenchmark(ImageToImageBenchmark):
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
guidance_scale=0.0,
|
||||
strength=0.5,
|
||||
)
|
||||
|
||||
|
||||
class InpaintingBenchmark(ImageToImageBenchmark):
|
||||
pipeline_class = AutoPipelineForInpainting
|
||||
mask_url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/overture-creations-5sI6fQgYIuo_mask.png"
|
||||
mask = load_image(mask_url).convert("RGB")
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
self.mask = self.mask.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
mask_image=self.mask,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class IPAdapterTextToImageBenchmark(TextToImageBenchmark):
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/load_neg_embed.png"
|
||||
image = load_image(url)
|
||||
|
||||
def __init__(self, args):
|
||||
pipe = self.pipeline_class.from_pretrained(args.ckpt, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.load_ip_adapter(
|
||||
args.ip_adapter_id[0],
|
||||
subfolder="models" if "sdxl" not in args.ip_adapter_id[1] else "sdxl_models",
|
||||
weight_name=args.ip_adapter_id[1],
|
||||
)
|
||||
|
||||
if args.run_compile:
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
ip_adapter_image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetBenchmark(TextToImageBenchmark):
|
||||
pipeline_class = StableDiffusionControlNetPipeline
|
||||
aux_network_class = ControlNetModel
|
||||
root_ckpt = "Lykon/DreamShaper"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_image_condition.png"
|
||||
image = load_image(url).convert("RGB")
|
||||
|
||||
def __init__(self, args):
|
||||
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
|
||||
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, controlnet=aux_network, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
if args.run_compile:
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.controlnet.to(memory_format=torch.channels_last)
|
||||
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.controlnet = torch.compile(pipe.controlnet, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
def run_inference(self, pipe, args):
|
||||
_ = pipe(
|
||||
prompt=PROMPT,
|
||||
image=self.image,
|
||||
num_inference_steps=args.num_inference_steps,
|
||||
num_images_per_prompt=args.batch_size,
|
||||
)
|
||||
|
||||
|
||||
class ControlNetSDXLBenchmark(ControlNetBenchmark):
|
||||
pipeline_class = StableDiffusionXLControlNetPipeline
|
||||
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
|
||||
|
||||
class T2IAdapterBenchmark(ControlNetBenchmark):
|
||||
pipeline_class = StableDiffusionAdapterPipeline
|
||||
aux_network_class = T2IAdapter
|
||||
root_ckpt = "Lykon/DreamShaper"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter.png"
|
||||
image = load_image(url).convert("L")
|
||||
|
||||
def __init__(self, args):
|
||||
aux_network = self.aux_network_class.from_pretrained(args.ckpt, torch_dtype=torch.float16)
|
||||
pipe = self.pipeline_class.from_pretrained(self.root_ckpt, adapter=aux_network, torch_dtype=torch.float16)
|
||||
pipe = pipe.to("cuda")
|
||||
|
||||
pipe.set_progress_bar_config(disable=True)
|
||||
self.pipe = pipe
|
||||
|
||||
if args.run_compile:
|
||||
pipe.unet.to(memory_format=torch.channels_last)
|
||||
pipe.adapter.to(memory_format=torch.channels_last)
|
||||
|
||||
print("Run torch compile")
|
||||
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
||||
pipe.adapter = torch.compile(pipe.adapter, mode="reduce-overhead", fullgraph=True)
|
||||
|
||||
self.image = self.image.resize(RESOLUTION_MAPPING[args.ckpt])
|
||||
|
||||
|
||||
class T2IAdapterSDXLBenchmark(T2IAdapterBenchmark):
|
||||
pipeline_class = StableDiffusionXLAdapterPipeline
|
||||
root_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
|
||||
url = "https://huggingface.co/datasets/diffusers/docs-images/resolve/main/benchmarking/canny_for_adapter_sdxl.png"
|
||||
image = load_image(url)
|
||||
|
||||
def __init__(self, args):
|
||||
super().__init__(args)
|
||||
@@ -1,26 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import ControlNetBenchmark, ControlNetSDXLBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="lllyasviel/sd-controlnet-canny",
|
||||
choices=["lllyasviel/sd-controlnet-canny", "diffusers/controlnet-canny-sdxl-1.0"],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = (
|
||||
ControlNetBenchmark(args) if args.ckpt == "lllyasviel/sd-controlnet-canny" else ControlNetSDXLBenchmark(args)
|
||||
)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -1,33 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import IPAdapterTextToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
IP_ADAPTER_CKPTS = {
|
||||
# because original SD v1.5 has been taken down.
|
||||
"Lykon/DreamShaper": ("h94/IP-Adapter", "ip-adapter_sd15.bin"),
|
||||
"stabilityai/stable-diffusion-xl-base-1.0": ("h94/IP-Adapter", "ip-adapter_sdxl.bin"),
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="rstabilityai/stable-diffusion-xl-base-1.0",
|
||||
choices=list(IP_ADAPTER_CKPTS.keys()),
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
args.ip_adapter_id = IP_ADAPTER_CKPTS[args.ckpt]
|
||||
benchmark_pipe = IPAdapterTextToImageBenchmark(args)
|
||||
args.ckpt = f"{args.ckpt} (IP-Adapter)"
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -1,29 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import ImageToImageBenchmark, TurboImageToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="Lykon/DreamShaper",
|
||||
choices=[
|
||||
"Lykon/DreamShaper",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-refiner-1.0",
|
||||
"stabilityai/sdxl-turbo",
|
||||
],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = ImageToImageBenchmark(args) if "turbo" not in args.ckpt else TurboImageToImageBenchmark(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -1,28 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import InpaintingBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="Lykon/DreamShaper",
|
||||
choices=[
|
||||
"Lykon/DreamShaper",
|
||||
"stabilityai/stable-diffusion-2-1",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = InpaintingBenchmark(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -1,28 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import T2IAdapterBenchmark, T2IAdapterSDXLBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="TencentARC/t2iadapter_canny_sd14v1",
|
||||
choices=["TencentARC/t2iadapter_canny_sd14v1", "TencentARC/t2i-adapter-canny-sdxl-1.0"],
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = (
|
||||
T2IAdapterBenchmark(args)
|
||||
if args.ckpt == "TencentARC/t2iadapter_canny_sd14v1"
|
||||
else T2IAdapterSDXLBenchmark(args)
|
||||
)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -1,23 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import LCMLoRATextToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="stabilityai/stable-diffusion-xl-base-1.0",
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=4)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_pipe = LCMLoRATextToImageBenchmark(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
@@ -1,40 +0,0 @@
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from base_classes import TextToImageBenchmark, TurboTextToImageBenchmark # noqa: E402
|
||||
|
||||
|
||||
ALL_T2I_CKPTS = [
|
||||
"Lykon/DreamShaper",
|
||||
"segmind/SSD-1B",
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"kandinsky-community/kandinsky-2-2-decoder",
|
||||
"warp-ai/wuerstchen",
|
||||
"stabilityai/sdxl-turbo",
|
||||
]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--ckpt",
|
||||
type=str,
|
||||
default="Lykon/DreamShaper",
|
||||
choices=ALL_T2I_CKPTS,
|
||||
)
|
||||
parser.add_argument("--batch_size", type=int, default=1)
|
||||
parser.add_argument("--num_inference_steps", type=int, default=50)
|
||||
parser.add_argument("--model_cpu_offload", action="store_true")
|
||||
parser.add_argument("--run_compile", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
benchmark_cls = None
|
||||
if "turbo" in args.ckpt:
|
||||
benchmark_cls = TurboTextToImageBenchmark
|
||||
else:
|
||||
benchmark_cls = TextToImageBenchmark
|
||||
|
||||
benchmark_pipe = benchmark_cls(args)
|
||||
benchmark_pipe.benchmark(args)
|
||||
98
benchmarks/benchmarking_flux.py
Normal file
98
benchmarks/benchmarking_flux.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
||||
|
||||
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
CKPT_ID = "black-forest-labs/FLUX.1-dev"
|
||||
RESULT_FILENAME = "flux.csv"
|
||||
|
||||
|
||||
def get_input_dict(**device_dtype_kwargs):
|
||||
# resolution: 1024x1024
|
||||
# maximum sequence length 512
|
||||
hidden_states = torch.randn(1, 4096, 64, **device_dtype_kwargs)
|
||||
encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
|
||||
pooled_prompt_embeds = torch.randn(1, 768, **device_dtype_kwargs)
|
||||
image_ids = torch.ones(512, 3, **device_dtype_kwargs)
|
||||
text_ids = torch.ones(4096, 3, **device_dtype_kwargs)
|
||||
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
||||
guidance = torch.tensor([1.0], **device_dtype_kwargs)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"pooled_projections": pooled_prompt_embeds,
|
||||
"timestep": timestep,
|
||||
"guidance": guidance,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
scenarios = [
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-bf16",
|
||||
model_cls=FluxTransformer2DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=model_init_fn,
|
||||
compile_kwargs={"fullgraph": True},
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-bnb-nf4",
|
||||
model_cls=FluxTransformer2DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
"quantization_config": BitsAndBytesConfig(
|
||||
load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_quant_type="nf4"
|
||||
),
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=model_init_fn,
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-layerwise-upcasting",
|
||||
model_cls=FluxTransformer2DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-group-offload-leaf",
|
||||
model_cls=FluxTransformer2DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(
|
||||
model_init_fn,
|
||||
group_offload_kwargs={
|
||||
"onload_device": torch_device,
|
||||
"offload_device": torch.device("cpu"),
|
||||
"offload_type": "leaf_level",
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
runner = BenchmarkMixin()
|
||||
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
||||
80
benchmarks/benchmarking_ltx.py
Normal file
80
benchmarks/benchmarking_ltx.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
||||
|
||||
from diffusers import LTXVideoTransformer3DModel
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
CKPT_ID = "Lightricks/LTX-Video-0.9.7-dev"
|
||||
RESULT_FILENAME = "ltx.csv"
|
||||
|
||||
|
||||
def get_input_dict(**device_dtype_kwargs):
|
||||
# 512x704 (161 frames)
|
||||
# `max_sequence_length`: 256
|
||||
hidden_states = torch.randn(1, 7392, 128, **device_dtype_kwargs)
|
||||
encoder_hidden_states = torch.randn(1, 256, 4096, **device_dtype_kwargs)
|
||||
encoder_attention_mask = torch.ones(1, 256, **device_dtype_kwargs)
|
||||
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
||||
video_coords = torch.randn(1, 3, 7392, **device_dtype_kwargs)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"encoder_attention_mask": encoder_attention_mask,
|
||||
"timestep": timestep,
|
||||
"video_coords": video_coords,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
scenarios = [
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-bf16",
|
||||
model_cls=LTXVideoTransformer3DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=model_init_fn,
|
||||
compile_kwargs={"fullgraph": True},
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-layerwise-upcasting",
|
||||
model_cls=LTXVideoTransformer3DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-group-offload-leaf",
|
||||
model_cls=LTXVideoTransformer3DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(
|
||||
model_init_fn,
|
||||
group_offload_kwargs={
|
||||
"onload_device": torch_device,
|
||||
"offload_device": torch.device("cpu"),
|
||||
"offload_type": "leaf_level",
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
runner = BenchmarkMixin()
|
||||
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
||||
82
benchmarks/benchmarking_sdxl.py
Normal file
82
benchmarks/benchmarking_sdxl.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
CKPT_ID = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
RESULT_FILENAME = "sdxl.csv"
|
||||
|
||||
|
||||
def get_input_dict(**device_dtype_kwargs):
|
||||
# height: 1024
|
||||
# width: 1024
|
||||
# max_sequence_length: 77
|
||||
hidden_states = torch.randn(1, 4, 128, 128, **device_dtype_kwargs)
|
||||
encoder_hidden_states = torch.randn(1, 77, 2048, **device_dtype_kwargs)
|
||||
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.randn(1, 1280, **device_dtype_kwargs),
|
||||
"time_ids": torch.ones(1, 6, **device_dtype_kwargs),
|
||||
}
|
||||
|
||||
return {
|
||||
"sample": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"timestep": timestep,
|
||||
"added_cond_kwargs": added_cond_kwargs,
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
scenarios = [
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-bf16",
|
||||
model_cls=UNet2DConditionModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "unet",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=model_init_fn,
|
||||
compile_kwargs={"fullgraph": True},
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-layerwise-upcasting",
|
||||
model_cls=UNet2DConditionModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "unet",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-group-offload-leaf",
|
||||
model_cls=UNet2DConditionModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "unet",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(
|
||||
model_init_fn,
|
||||
group_offload_kwargs={
|
||||
"onload_device": torch_device,
|
||||
"offload_device": torch.device("cpu"),
|
||||
"offload_type": "leaf_level",
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
runner = BenchmarkMixin()
|
||||
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
||||
244
benchmarks/benchmarking_utils.py
Normal file
244
benchmarks/benchmarking_utils.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import gc
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import queue
|
||||
import threading
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
from diffusers.models.modeling_utils import ModelMixin
|
||||
from diffusers.utils.testing_utils import require_torch_gpu, torch_device
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
NUM_WARMUP_ROUNDS = 5
|
||||
|
||||
|
||||
def benchmark_fn(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)",
|
||||
globals={"args": args, "kwargs": kwargs, "f": f},
|
||||
num_threads=1,
|
||||
)
|
||||
return float(f"{(t0.blocked_autorange().mean):.3f}")
|
||||
|
||||
|
||||
def flush():
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
# Adapted from https://github.com/lucasb-eyer/cnn_vit_benchmarks/blob/15b665ff758e8062131353076153905cae00a71f/main.py
|
||||
def calculate_flops(model, input_dict):
|
||||
try:
|
||||
from torchprofile import profile_macs
|
||||
except ModuleNotFoundError:
|
||||
raise
|
||||
|
||||
# This is a hacky way to convert the kwargs to args as `profile_macs` cries about kwargs.
|
||||
sig = inspect.signature(model.forward)
|
||||
param_names = [
|
||||
p.name
|
||||
for p in sig.parameters.values()
|
||||
if p.kind
|
||||
in (
|
||||
inspect.Parameter.POSITIONAL_ONLY,
|
||||
inspect.Parameter.POSITIONAL_OR_KEYWORD,
|
||||
)
|
||||
and p.name != "self"
|
||||
]
|
||||
bound = sig.bind_partial(**input_dict)
|
||||
bound.apply_defaults()
|
||||
args = tuple(bound.arguments[name] for name in param_names)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
macs = profile_macs(model, args)
|
||||
flops = 2 * macs # 1 MAC operation = 2 FLOPs (1 multiplication + 1 addition)
|
||||
return flops
|
||||
|
||||
|
||||
def calculate_params(model):
|
||||
return sum(p.numel() for p in model.parameters())
|
||||
|
||||
|
||||
# Users can define their own in case this doesn't suffice. For most cases,
|
||||
# it should be sufficient.
|
||||
def model_init_fn(model_cls, group_offload_kwargs=None, layerwise_upcasting=False, **init_kwargs):
|
||||
model = model_cls.from_pretrained(**init_kwargs).eval()
|
||||
if group_offload_kwargs and isinstance(group_offload_kwargs, dict):
|
||||
model.enable_group_offload(**group_offload_kwargs)
|
||||
else:
|
||||
model.to(torch_device)
|
||||
if layerwise_upcasting:
|
||||
model.enable_layerwise_casting(
|
||||
storage_dtype=torch.float8_e4m3fn, compute_dtype=init_kwargs.get("torch_dtype", torch.bfloat16)
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkScenario:
|
||||
name: str
|
||||
model_cls: ModelMixin
|
||||
model_init_kwargs: Dict[str, Any]
|
||||
model_init_fn: Callable
|
||||
get_model_input_dict: Callable
|
||||
compile_kwargs: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
class BenchmarkMixin:
|
||||
def pre_benchmark(self):
|
||||
flush()
|
||||
torch.compiler.reset()
|
||||
|
||||
def post_benchmark(self, model):
|
||||
model.cpu()
|
||||
flush()
|
||||
torch.compiler.reset()
|
||||
|
||||
@torch.no_grad()
|
||||
def run_benchmark(self, scenario: BenchmarkScenario):
|
||||
# 0) Basic stats
|
||||
logger.info(f"Running scenario: {scenario.name}.")
|
||||
try:
|
||||
model = model_init_fn(scenario.model_cls, **scenario.model_init_kwargs)
|
||||
num_params = round(calculate_params(model) / 1e9, 2)
|
||||
try:
|
||||
flops = round(calculate_flops(model, input_dict=scenario.get_model_input_dict()) / 1e9, 2)
|
||||
except Exception as e:
|
||||
logger.info(f"Problem in calculating FLOPs:\n{e}")
|
||||
flops = None
|
||||
model.cpu()
|
||||
del model
|
||||
except Exception as e:
|
||||
logger.info(f"Error while initializing the model and calculating FLOPs:\n{e}")
|
||||
return {}
|
||||
self.pre_benchmark()
|
||||
|
||||
# 1) plain stats
|
||||
results = {}
|
||||
plain = None
|
||||
try:
|
||||
plain = self._run_phase(
|
||||
model_cls=scenario.model_cls,
|
||||
init_fn=scenario.model_init_fn,
|
||||
init_kwargs=scenario.model_init_kwargs,
|
||||
get_input_fn=scenario.get_model_input_dict,
|
||||
compile_kwargs=None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"Benchmark could not be run with the following error:\n{e}")
|
||||
return results
|
||||
|
||||
# 2) compiled stats (if any)
|
||||
compiled = {"time": None, "memory": None}
|
||||
if scenario.compile_kwargs:
|
||||
try:
|
||||
compiled = self._run_phase(
|
||||
model_cls=scenario.model_cls,
|
||||
init_fn=scenario.model_init_fn,
|
||||
init_kwargs=scenario.model_init_kwargs,
|
||||
get_input_fn=scenario.get_model_input_dict,
|
||||
compile_kwargs=scenario.compile_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info(f"Compilation benchmark could not be run with the following error\n: {e}")
|
||||
if plain is None:
|
||||
return results
|
||||
|
||||
# 3) merge
|
||||
result = {
|
||||
"scenario": scenario.name,
|
||||
"model_cls": scenario.model_cls.__name__,
|
||||
"num_params_B": num_params,
|
||||
"flops_G": flops,
|
||||
"time_plain_s": plain["time"],
|
||||
"mem_plain_GB": plain["memory"],
|
||||
"time_compile_s": compiled["time"],
|
||||
"mem_compile_GB": compiled["memory"],
|
||||
}
|
||||
if scenario.compile_kwargs:
|
||||
result["fullgraph"] = scenario.compile_kwargs.get("fullgraph", False)
|
||||
result["mode"] = scenario.compile_kwargs.get("mode", "default")
|
||||
else:
|
||||
result["fullgraph"], result["mode"] = None, None
|
||||
return result
|
||||
|
||||
def run_bencmarks_and_collate(self, scenarios: Union[BenchmarkScenario, list[BenchmarkScenario]], filename: str):
|
||||
if not isinstance(scenarios, list):
|
||||
scenarios = [scenarios]
|
||||
record_queue = queue.Queue()
|
||||
stop_signal = object()
|
||||
|
||||
def _writer_thread():
|
||||
while True:
|
||||
item = record_queue.get()
|
||||
if item is stop_signal:
|
||||
break
|
||||
df_row = pd.DataFrame([item])
|
||||
write_header = not os.path.exists(filename)
|
||||
df_row.to_csv(filename, mode="a", header=write_header, index=False)
|
||||
record_queue.task_done()
|
||||
|
||||
record_queue.task_done()
|
||||
|
||||
writer = threading.Thread(target=_writer_thread, daemon=True)
|
||||
writer.start()
|
||||
|
||||
for s in scenarios:
|
||||
try:
|
||||
record = self.run_benchmark(s)
|
||||
if record:
|
||||
record_queue.put(record)
|
||||
else:
|
||||
logger.info(f"Record empty from scenario: {s.name}.")
|
||||
except Exception as e:
|
||||
logger.info(f"Running scenario ({s.name}) led to error:\n{e}")
|
||||
record_queue.put(stop_signal)
|
||||
logger.info(f"Results serialized to {filename=}.")
|
||||
|
||||
def _run_phase(
|
||||
self,
|
||||
*,
|
||||
model_cls: ModelMixin,
|
||||
init_fn: Callable,
|
||||
init_kwargs: Dict[str, Any],
|
||||
get_input_fn: Callable,
|
||||
compile_kwargs: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, float]:
|
||||
# setup
|
||||
self.pre_benchmark()
|
||||
|
||||
# init & (optional) compile
|
||||
model = init_fn(model_cls, **init_kwargs)
|
||||
if compile_kwargs:
|
||||
model.compile(**compile_kwargs)
|
||||
|
||||
# build inputs
|
||||
inp = get_input_fn()
|
||||
|
||||
# measure
|
||||
run_ctx = torch._inductor.utils.fresh_inductor_cache() if compile_kwargs else nullcontext()
|
||||
with run_ctx:
|
||||
for _ in range(NUM_WARMUP_ROUNDS):
|
||||
_ = model(**inp)
|
||||
time_s = benchmark_fn(lambda m, d: m(**d), model, inp)
|
||||
mem_gb = torch.cuda.max_memory_allocated() / (1024**3)
|
||||
mem_gb = round(mem_gb, 2)
|
||||
|
||||
# teardown
|
||||
self.post_benchmark(model)
|
||||
del model
|
||||
return {"time": time_s, "memory": mem_gb}
|
||||
74
benchmarks/benchmarking_wan.py
Normal file
74
benchmarks/benchmarking_wan.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from benchmarking_utils import BenchmarkMixin, BenchmarkScenario, model_init_fn
|
||||
|
||||
from diffusers import WanTransformer3DModel
|
||||
from diffusers.utils.testing_utils import torch_device
|
||||
|
||||
|
||||
CKPT_ID = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
|
||||
RESULT_FILENAME = "wan.csv"
|
||||
|
||||
|
||||
def get_input_dict(**device_dtype_kwargs):
|
||||
# height: 480
|
||||
# width: 832
|
||||
# num_frames: 81
|
||||
# max_sequence_length: 512
|
||||
hidden_states = torch.randn(1, 16, 21, 60, 104, **device_dtype_kwargs)
|
||||
encoder_hidden_states = torch.randn(1, 512, 4096, **device_dtype_kwargs)
|
||||
timestep = torch.tensor([1.0], **device_dtype_kwargs)
|
||||
|
||||
return {"hidden_states": hidden_states, "encoder_hidden_states": encoder_hidden_states, "timestep": timestep}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
scenarios = [
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-bf16",
|
||||
model_cls=WanTransformer3DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=model_init_fn,
|
||||
compile_kwargs={"fullgraph": True},
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-layerwise-upcasting",
|
||||
model_cls=WanTransformer3DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(model_init_fn, layerwise_upcasting=True),
|
||||
),
|
||||
BenchmarkScenario(
|
||||
name=f"{CKPT_ID}-group-offload-leaf",
|
||||
model_cls=WanTransformer3DModel,
|
||||
model_init_kwargs={
|
||||
"pretrained_model_name_or_path": CKPT_ID,
|
||||
"torch_dtype": torch.bfloat16,
|
||||
"subfolder": "transformer",
|
||||
},
|
||||
get_model_input_dict=partial(get_input_dict, device=torch_device, dtype=torch.bfloat16),
|
||||
model_init_fn=partial(
|
||||
model_init_fn,
|
||||
group_offload_kwargs={
|
||||
"onload_device": torch_device,
|
||||
"offload_device": torch.device("cpu"),
|
||||
"offload_type": "leaf_level",
|
||||
"use_stream": True,
|
||||
"non_blocking": True,
|
||||
},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
runner = BenchmarkMixin()
|
||||
runner.run_bencmarks_and_collate(scenarios, filename=RESULT_FILENAME)
|
||||
166
benchmarks/populate_into_db.py
Normal file
166
benchmarks/populate_into_db.py
Normal file
@@ -0,0 +1,166 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gpustat
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.extensions import register_adapter
|
||||
from psycopg2.extras import Json
|
||||
|
||||
|
||||
register_adapter(dict, Json)
|
||||
|
||||
FINAL_CSV_FILENAME = "collated_results.csv"
|
||||
# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
|
||||
BENCHMARKS_TABLE_NAME = "benchmarks"
|
||||
MEASUREMENTS_TABLE_NAME = "model_measurements"
|
||||
|
||||
|
||||
def _init_benchmark(conn, branch, commit_id, commit_msg):
|
||||
gpu_stats = gpustat.GPUStatCollection.new_query()
|
||||
metadata = {"gpu_name": gpu_stats[0]["name"]}
|
||||
repository = "huggingface/diffusers"
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
|
||||
(repository, branch, commit_id, commit_msg, metadata),
|
||||
)
|
||||
benchmark_id = cur.fetchone()[0]
|
||||
print(f"Initialised benchmark #{benchmark_id}")
|
||||
return benchmark_id
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"branch",
|
||||
type=str,
|
||||
help="The branch name on which the benchmarking is performed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"commit_id",
|
||||
type=str,
|
||||
help="The commit hash on which the benchmarking is performed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"commit_msg",
|
||||
type=str,
|
||||
help="The commit message associated with the commit, truncated to 70 characters.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
try:
|
||||
conn = psycopg2.connect(
|
||||
host=os.getenv("PGHOST"),
|
||||
database=os.getenv("PGDATABASE"),
|
||||
user=os.getenv("PGUSER"),
|
||||
password=os.getenv("PGPASSWORD"),
|
||||
)
|
||||
print("DB connection established successfully.")
|
||||
except Exception as e:
|
||||
print(f"Problem during DB init: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
benchmark_id = _init_benchmark(
|
||||
conn=conn,
|
||||
branch=args.branch,
|
||||
commit_id=args.commit_id,
|
||||
commit_msg=args.commit_msg,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Problem during initializing benchmark: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
cur = conn.cursor()
|
||||
|
||||
df = pd.read_csv(FINAL_CSV_FILENAME)
|
||||
|
||||
# Helper to cast values (or None) given a dtype
|
||||
def _cast_value(val, dtype: str):
|
||||
if pd.isna(val):
|
||||
return None
|
||||
|
||||
if dtype == "text":
|
||||
return str(val).strip()
|
||||
|
||||
if dtype == "float":
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if dtype == "bool":
|
||||
s = str(val).strip().lower()
|
||||
if s in ("true", "t", "yes", "1"):
|
||||
return True
|
||||
if s in ("false", "f", "no", "0"):
|
||||
return False
|
||||
if val in (1, 1.0):
|
||||
return True
|
||||
if val in (0, 0.0):
|
||||
return False
|
||||
return None
|
||||
|
||||
return val
|
||||
|
||||
try:
|
||||
rows_to_insert = []
|
||||
for _, row in df.iterrows():
|
||||
scenario = _cast_value(row.get("scenario"), "text")
|
||||
model_cls = _cast_value(row.get("model_cls"), "text")
|
||||
num_params_B = _cast_value(row.get("num_params_B"), "float")
|
||||
flops_G = _cast_value(row.get("flops_G"), "float")
|
||||
time_plain_s = _cast_value(row.get("time_plain_s"), "float")
|
||||
mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
|
||||
time_compile_s = _cast_value(row.get("time_compile_s"), "float")
|
||||
mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
|
||||
fullgraph = _cast_value(row.get("fullgraph"), "bool")
|
||||
mode = _cast_value(row.get("mode"), "text")
|
||||
|
||||
# If "github_sha" column exists in the CSV, cast it; else default to None
|
||||
if "github_sha" in df.columns:
|
||||
github_sha = _cast_value(row.get("github_sha"), "text")
|
||||
else:
|
||||
github_sha = None
|
||||
|
||||
measurements = {
|
||||
"scenario": scenario,
|
||||
"model_cls": model_cls,
|
||||
"num_params_B": num_params_B,
|
||||
"flops_G": flops_G,
|
||||
"time_plain_s": time_plain_s,
|
||||
"mem_plain_GB": mem_plain_GB,
|
||||
"time_compile_s": time_compile_s,
|
||||
"mem_compile_GB": mem_compile_GB,
|
||||
"fullgraph": fullgraph,
|
||||
"mode": mode,
|
||||
"github_sha": github_sha,
|
||||
}
|
||||
rows_to_insert.append((benchmark_id, measurements))
|
||||
|
||||
# Batch-insert all rows
|
||||
insert_sql = f"""
|
||||
INSERT INTO {MEASUREMENTS_TABLE_NAME} (
|
||||
benchmark_id,
|
||||
measurements
|
||||
)
|
||||
VALUES (%s, %s);
|
||||
"""
|
||||
|
||||
psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
|
||||
conn.commit()
|
||||
|
||||
cur.close()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
sys.exit(1)
|
||||
@@ -1,19 +1,19 @@
|
||||
import glob
|
||||
import sys
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from huggingface_hub.utils import EntryNotFoundError
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from utils import BASE_PATH, FINAL_CSV_FILE, GITHUB_SHA, REPO_ID, collate_csv # noqa: E402
|
||||
REPO_ID = "diffusers/benchmarks"
|
||||
|
||||
|
||||
def has_previous_benchmark() -> str:
|
||||
from run_all import FINAL_CSV_FILENAME
|
||||
|
||||
csv_path = None
|
||||
try:
|
||||
csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILE)
|
||||
csv_path = hf_hub_download(repo_id=REPO_ID, repo_type="dataset", filename=FINAL_CSV_FILENAME)
|
||||
except EntryNotFoundError:
|
||||
csv_path = None
|
||||
return csv_path
|
||||
@@ -26,46 +26,50 @@ def filter_float(value):
|
||||
|
||||
|
||||
def push_to_hf_dataset():
|
||||
all_csvs = sorted(glob.glob(f"{BASE_PATH}/*.csv"))
|
||||
collate_csv(all_csvs, FINAL_CSV_FILE)
|
||||
from run_all import FINAL_CSV_FILENAME, GITHUB_SHA
|
||||
|
||||
# If there's an existing benchmark file, we should report the changes.
|
||||
csv_path = has_previous_benchmark()
|
||||
if csv_path is not None:
|
||||
current_results = pd.read_csv(FINAL_CSV_FILE)
|
||||
current_results = pd.read_csv(FINAL_CSV_FILENAME)
|
||||
previous_results = pd.read_csv(csv_path)
|
||||
|
||||
numeric_columns = current_results.select_dtypes(include=["float64", "int64"]).columns
|
||||
numeric_columns = [
|
||||
c for c in numeric_columns if c not in ["batch_size", "num_inference_steps", "actual_gpu_memory (gbs)"]
|
||||
]
|
||||
|
||||
for column in numeric_columns:
|
||||
previous_results[column] = previous_results[column].map(lambda x: filter_float(x))
|
||||
# get previous values as floats, aligned to current index
|
||||
prev_vals = previous_results[column].map(filter_float).reindex(current_results.index)
|
||||
|
||||
# Calculate the percentage change
|
||||
current_results[column] = current_results[column].astype(float)
|
||||
previous_results[column] = previous_results[column].astype(float)
|
||||
percent_change = ((current_results[column] - previous_results[column]) / previous_results[column]) * 100
|
||||
# get current values as floats
|
||||
curr_vals = current_results[column].astype(float)
|
||||
|
||||
# Format the values with '+' or '-' sign and append to original values
|
||||
current_results[column] = current_results[column].map(str) + percent_change.map(
|
||||
lambda x: f" ({'+' if x > 0 else ''}{x:.2f}%)"
|
||||
# stringify the current values
|
||||
curr_str = curr_vals.map(str)
|
||||
|
||||
# build an appendage only when prev exists and differs
|
||||
append_str = prev_vals.where(prev_vals.notnull() & (prev_vals != curr_vals), other=pd.NA).map(
|
||||
lambda x: f" ({x})" if pd.notnull(x) else ""
|
||||
)
|
||||
# There might be newly added rows. So, filter out the NaNs.
|
||||
current_results[column] = current_results[column].map(lambda x: x.replace(" (nan%)", ""))
|
||||
|
||||
# Overwrite the current result file.
|
||||
current_results.to_csv(FINAL_CSV_FILE, index=False)
|
||||
# combine
|
||||
current_results[column] = curr_str + append_str
|
||||
os.remove(FINAL_CSV_FILENAME)
|
||||
current_results.to_csv(FINAL_CSV_FILENAME, index=False)
|
||||
|
||||
commit_message = f"upload from sha: {GITHUB_SHA}" if GITHUB_SHA is not None else "upload benchmark results"
|
||||
upload_file(
|
||||
repo_id=REPO_ID,
|
||||
path_in_repo=FINAL_CSV_FILE,
|
||||
path_or_fileobj=FINAL_CSV_FILE,
|
||||
path_in_repo=FINAL_CSV_FILENAME,
|
||||
path_or_fileobj=FINAL_CSV_FILENAME,
|
||||
repo_type="dataset",
|
||||
commit_message=commit_message,
|
||||
)
|
||||
upload_file(
|
||||
repo_id="diffusers/benchmark-analyzer",
|
||||
path_in_repo=FINAL_CSV_FILENAME,
|
||||
path_or_fileobj=FINAL_CSV_FILENAME,
|
||||
repo_type="space",
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
6
benchmarks/requirements.txt
Normal file
6
benchmarks/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
pandas
|
||||
psutil
|
||||
gpustat
|
||||
torchprofile
|
||||
bitsandbytes
|
||||
psycopg2==2.9.9
|
||||
@@ -1,101 +1,84 @@
|
||||
import glob
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
sys.path.append(".")
|
||||
from benchmark_text_to_image import ALL_T2I_CKPTS # noqa: E402
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
PATTERN = "benchmark_*.py"
|
||||
PATTERN = "benchmarking_*.py"
|
||||
FINAL_CSV_FILENAME = "collated_results.csv"
|
||||
GITHUB_SHA = os.getenv("GITHUB_SHA", None)
|
||||
|
||||
|
||||
class SubprocessCallException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
# Taken from `test_examples_utils.py`
|
||||
def run_command(command: List[str], return_stdout=False):
|
||||
"""
|
||||
Runs `command` with `subprocess.check_output` and will potentially return the `stdout`. Will also properly capture
|
||||
if an error occurred while running `command`
|
||||
"""
|
||||
def run_command(command: list[str], return_stdout=False):
|
||||
try:
|
||||
output = subprocess.check_output(command, stderr=subprocess.STDOUT)
|
||||
if return_stdout:
|
||||
if hasattr(output, "decode"):
|
||||
output = output.decode("utf-8")
|
||||
return output
|
||||
if return_stdout and hasattr(output, "decode"):
|
||||
return output.decode("utf-8")
|
||||
except subprocess.CalledProcessError as e:
|
||||
raise SubprocessCallException(
|
||||
f"Command `{' '.join(command)}` failed with the following error:\n\n{e.output.decode()}"
|
||||
) from e
|
||||
raise SubprocessCallException(f"Command `{' '.join(command)}` failed with:\n{e.output.decode()}") from e
|
||||
|
||||
|
||||
def main():
|
||||
python_files = glob.glob(PATTERN)
|
||||
def merge_csvs(final_csv: str = "collated_results.csv"):
|
||||
all_csvs = glob.glob("*.csv")
|
||||
all_csvs = [f for f in all_csvs if f != final_csv]
|
||||
if not all_csvs:
|
||||
logger.info("No result CSVs found to merge.")
|
||||
return
|
||||
|
||||
for file in python_files:
|
||||
print(f"****** Running file: {file} ******")
|
||||
|
||||
# Run with canonical settings.
|
||||
if file != "benchmark_text_to_image.py" and file != "benchmark_ip_adapters.py":
|
||||
command = f"python {file}"
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
# Run variants.
|
||||
for file in python_files:
|
||||
# See: https://github.com/pytorch/pytorch/issues/129637
|
||||
if file == "benchmark_ip_adapters.py":
|
||||
df_list = []
|
||||
for f in all_csvs:
|
||||
try:
|
||||
d = pd.read_csv(f)
|
||||
except pd.errors.EmptyDataError:
|
||||
# If a file existed but was zero‐bytes or corrupted, skip it
|
||||
continue
|
||||
df_list.append(d)
|
||||
|
||||
if file == "benchmark_text_to_image.py":
|
||||
for ckpt in ALL_T2I_CKPTS:
|
||||
command = f"python {file} --ckpt {ckpt}"
|
||||
if not df_list:
|
||||
logger.info("All result CSVs were empty or invalid; nothing to merge.")
|
||||
return
|
||||
|
||||
if "turbo" in ckpt:
|
||||
command += " --num_inference_steps 1"
|
||||
final_df = pd.concat(df_list, ignore_index=True)
|
||||
if GITHUB_SHA is not None:
|
||||
final_df["github_sha"] = GITHUB_SHA
|
||||
final_df.to_csv(final_csv, index=False)
|
||||
logger.info(f"Merged {len(all_csvs)} partial CSVs → {final_csv}.")
|
||||
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
def run_scripts():
|
||||
python_files = sorted(glob.glob(PATTERN))
|
||||
python_files = [f for f in python_files if f != "benchmarking_utils.py"]
|
||||
|
||||
elif file == "benchmark_sd_img.py":
|
||||
for ckpt in ["stabilityai/stable-diffusion-xl-refiner-1.0", "stabilityai/sdxl-turbo"]:
|
||||
command = f"python {file} --ckpt {ckpt}"
|
||||
for file in python_files:
|
||||
script_name = file.split(".py")[0].split("_")[-1] # example: benchmarking_foo.py -> foo
|
||||
logger.info(f"\n****** Running file: {file} ******")
|
||||
|
||||
if ckpt == "stabilityai/sdxl-turbo":
|
||||
command += " --num_inference_steps 2"
|
||||
partial_csv = f"{script_name}.csv"
|
||||
if os.path.exists(partial_csv):
|
||||
logger.info(f"Found {partial_csv}. Removing for safer numbers and duplication.")
|
||||
os.remove(partial_csv)
|
||||
|
||||
run_command(command.split())
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
command = ["python", file]
|
||||
try:
|
||||
run_command(command)
|
||||
logger.info(f"→ {file} finished normally.")
|
||||
except SubprocessCallException as e:
|
||||
logger.info(f"Error running {file}:\n{e}")
|
||||
finally:
|
||||
logger.info(f"→ Merging partial CSVs after {file} …")
|
||||
merge_csvs(final_csv=FINAL_CSV_FILENAME)
|
||||
|
||||
elif file in ["benchmark_sd_inpainting.py", "benchmark_ip_adapters.py"]:
|
||||
sdxl_ckpt = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
command = f"python {file} --ckpt {sdxl_ckpt}"
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
|
||||
elif file in ["benchmark_controlnet.py", "benchmark_t2i_adapter.py"]:
|
||||
sdxl_ckpt = (
|
||||
"diffusers/controlnet-canny-sdxl-1.0"
|
||||
if "controlnet" in file
|
||||
else "TencentARC/t2i-adapter-canny-sdxl-1.0"
|
||||
)
|
||||
command = f"python {file} --ckpt {sdxl_ckpt}"
|
||||
run_command(command.split())
|
||||
|
||||
command += " --run_compile"
|
||||
run_command(command.split())
|
||||
logger.info(f"\nAll scripts attempted. Final collated CSV: {FINAL_CSV_FILENAME}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
run_scripts()
|
||||
|
||||
@@ -1,98 +0,0 @@
|
||||
import argparse
|
||||
import csv
|
||||
import gc
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
|
||||
GITHUB_SHA = os.getenv("GITHUB_SHA", None)
|
||||
BENCHMARK_FIELDS = [
|
||||
"pipeline_cls",
|
||||
"ckpt_id",
|
||||
"batch_size",
|
||||
"num_inference_steps",
|
||||
"model_cpu_offload",
|
||||
"run_compile",
|
||||
"time (secs)",
|
||||
"memory (gbs)",
|
||||
"actual_gpu_memory (gbs)",
|
||||
"github_sha",
|
||||
]
|
||||
|
||||
PROMPT = "ghibli style, a fantasy landscape with castles"
|
||||
BASE_PATH = os.getenv("BASE_PATH", ".")
|
||||
TOTAL_GPU_MEMORY = float(os.getenv("TOTAL_GPU_MEMORY", torch.cuda.get_device_properties(0).total_memory / (1024**3)))
|
||||
|
||||
REPO_ID = "diffusers/benchmarks"
|
||||
FINAL_CSV_FILE = "collated_results.csv"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkInfo:
|
||||
time: float
|
||||
memory: float
|
||||
|
||||
|
||||
def flush():
|
||||
"""Wipes off memory."""
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_max_memory_allocated()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
|
||||
def bytes_to_giga_bytes(bytes):
|
||||
return f"{(bytes / 1024 / 1024 / 1024):.3f}"
|
||||
|
||||
|
||||
def benchmark_fn(f, *args, **kwargs):
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)",
|
||||
globals={"args": args, "kwargs": kwargs, "f": f},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
return f"{(t0.blocked_autorange().mean):.3f}"
|
||||
|
||||
|
||||
def generate_csv_dict(
|
||||
pipeline_cls: str, ckpt: str, args: argparse.Namespace, benchmark_info: BenchmarkInfo
|
||||
) -> Dict[str, Union[str, bool, float]]:
|
||||
"""Packs benchmarking data into a dictionary for latter serialization."""
|
||||
data_dict = {
|
||||
"pipeline_cls": pipeline_cls,
|
||||
"ckpt_id": ckpt,
|
||||
"batch_size": args.batch_size,
|
||||
"num_inference_steps": args.num_inference_steps,
|
||||
"model_cpu_offload": args.model_cpu_offload,
|
||||
"run_compile": args.run_compile,
|
||||
"time (secs)": benchmark_info.time,
|
||||
"memory (gbs)": benchmark_info.memory,
|
||||
"actual_gpu_memory (gbs)": f"{(TOTAL_GPU_MEMORY):.3f}",
|
||||
"github_sha": GITHUB_SHA,
|
||||
}
|
||||
return data_dict
|
||||
|
||||
|
||||
def write_to_csv(file_name: str, data_dict: Dict[str, Union[str, bool, float]]):
|
||||
"""Serializes a dictionary into a CSV file."""
|
||||
with open(file_name, mode="w", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=BENCHMARK_FIELDS)
|
||||
writer.writeheader()
|
||||
writer.writerow(data_dict)
|
||||
|
||||
|
||||
def collate_csv(input_files: List[str], output_file: str):
|
||||
"""Collates multiple identically structured CSVs into a single CSV file."""
|
||||
with open(output_file, mode="w", newline="") as outfile:
|
||||
writer = csv.DictWriter(outfile, fieldnames=BENCHMARK_FIELDS)
|
||||
writer.writeheader()
|
||||
|
||||
for file in input_files:
|
||||
with open(file, mode="r") as infile:
|
||||
reader = csv.DictReader(infile)
|
||||
for row in reader:
|
||||
writer.writerow(row)
|
||||
@@ -64,6 +64,8 @@
|
||||
title: Overview
|
||||
- local: using-diffusers/create_a_server
|
||||
title: Create a server
|
||||
- local: using-diffusers/batched_inference
|
||||
title: Batch inference
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference
|
||||
- local: using-diffusers/scheduler_features
|
||||
@@ -91,6 +93,16 @@
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
title: Hybrid Inference
|
||||
- sections:
|
||||
- local: modular_diffusers/getting_started
|
||||
title: Getting Started
|
||||
- local: modular_diffusers/components_manager
|
||||
title: Components Manager
|
||||
- local: modular_diffusers/write_own_pipeline_block
|
||||
title: Write your own pipeline block
|
||||
- local: modular_diffusers/end_to_end_guide
|
||||
title: End-to-End Developer Guide
|
||||
title: Modular Diffusers
|
||||
- sections:
|
||||
- local: using-diffusers/consisid
|
||||
title: ConsisID
|
||||
|
||||
@@ -28,3 +28,9 @@ Cache methods speedup diffusion transformers by storing and reusing intermediate
|
||||
[[autodoc]] FasterCacheConfig
|
||||
|
||||
[[autodoc]] apply_faster_cache
|
||||
|
||||
### FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] FirstBlockCacheConfig
|
||||
|
||||
[[autodoc]] apply_first_block_cache
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# aMUSEd
|
||||
|
||||
aMUSEd was introduced in [aMUSEd: An Open MUSE Reproduction](https://huggingface.co/papers/2401.01808) by Suraj Patil, William Berman, Robin Rombach, and Patrick von Platen.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Attend-and-Excite
|
||||
|
||||
Attend-and-Excite for Stable Diffusion was proposed in [Attend-and-Excite: Attention-Based Semantic Guidance for Text-to-Image Diffusion Models](https://attendandexcite.github.io/Attend-and-Excite/) and provides textual attention control over image generation.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# AudioLDM
|
||||
|
||||
AudioLDM was proposed in [AudioLDM: Text-to-Audio Generation with Latent Diffusion Models](https://huggingface.co/papers/2301.12503) by Haohe Liu et al. Inspired by [Stable Diffusion](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/overview), AudioLDM
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# BLIP-Diffusion
|
||||
|
||||
BLIP-Diffusion was proposed in [BLIP-Diffusion: Pre-trained Subject Representation for Controllable Text-to-Image Generation and Editing](https://huggingface.co/papers/2305.14720). It enables zero-shot subject-driven generation and control-guided zero-shot generation.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# ControlNet-XS
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# ControlNet-XS with Stable Diffusion XL
|
||||
|
||||
ControlNet-XS was introduced in [ControlNet-XS](https://vislearn.github.io/ControlNet-XS/) by Denis Zavadski and Carsten Rother. It is based on the observation that the control model in the [original ControlNet](https://huggingface.co/papers/2302.05543) can be made much smaller and still produce good results.
|
||||
|
||||
@@ -24,6 +24,31 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
|
||||
|
||||
</Tip>
|
||||
|
||||
## Loading original format checkpoints
|
||||
|
||||
Original format checkpoints that have not been converted to diffusers-expected format can be loaded using the `from_single_file` method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import Cosmos2TextToImagePipeline, CosmosTransformer3DModel
|
||||
|
||||
model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
|
||||
transformer = CosmosTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/nvidia/Cosmos-Predict2-2B-Text2Image/blob/main/model.pt",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to("cuda")
|
||||
pipe = Cosmos2TextToImagePipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
|
||||
negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
|
||||
|
||||
output = pipe(
|
||||
prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
|
||||
).images[0]
|
||||
output.save("output.png")
|
||||
```
|
||||
|
||||
## CosmosTextToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosTextToWorldPipeline
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Dance Diffusion
|
||||
|
||||
[Dance Diffusion](https://github.com/Harmonai-org/sample-generator) is by Zach Evans.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# DiffEdit
|
||||
|
||||
[DiffEdit: Diffusion-based semantic image editing with mask guidance](https://huggingface.co/papers/2210.11427) is by Guillaume Couairon, Jakob Verbeek, Holger Schwenk, and Matthieu Cord.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# I2VGen-XL
|
||||
|
||||
[I2VGen-XL: High-Quality Image-to-Video Synthesis via Cascaded Diffusion Models](https://hf.co/papers/2311.04145.pdf) by Shiwei Zhang, Jiayu Wang, Yingya Zhang, Kang Zhao, Hangjie Yuan, Zhiwu Qin, Xiang Wang, Deli Zhao, and Jingren Zhou.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# MusicLDM
|
||||
|
||||
MusicLDM was proposed in [MusicLDM: Enhancing Novelty in Text-to-Music Generation Using Beat-Synchronous Mixup Strategies](https://huggingface.co/papers/2308.01546) by Ke Chen, Yusong Wu, Haohe Liu, Marianna Nezhurina, Taylor Berg-Kirkpatrick, Shlomo Dubnov.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Paint by Example
|
||||
|
||||
[Paint by Example: Exemplar-based Image Editing with Diffusion Models](https://huggingface.co/papers/2211.13227) is by Binxin Yang, Shuyang Gu, Bo Zhang, Ting Zhang, Xuejin Chen, Xiaoyan Sun, Dong Chen, Fang Wen.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# MultiDiffusion
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Image-to-Video Generation with PIA (Personalized Image Animator)
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Self-Attention Guidance
|
||||
|
||||
[Improving Sample Quality of Diffusion Models Using Self-Attention Guidance](https://huggingface.co/papers/2210.00939) is by Susung Hong et al.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Semantic Guidance
|
||||
|
||||
Semantic Guidance for Diffusion Models was proposed in [SEGA: Instructing Text-to-Image Models using Semantic Guidance](https://huggingface.co/papers/2301.12247) and provides strong semantic control over image generation.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# GLIGEN (Grounded Language-to-Image Generation)
|
||||
|
||||
The GLIGEN model was created by researchers and engineers from [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] and [`StableDiffusionGLIGENTextImagePipeline`] can generate photorealistic images conditioned on grounding inputs. Along with text and bounding boxes with [`StableDiffusionGLIGENPipeline`], if input images are given, [`StableDiffusionGLIGENTextImagePipeline`] can insert objects described by text at the region defined by bounding boxes. Otherwise, it'll generate an image described by the caption/prompt and insert objects described by text at the region defined by bounding boxes. It's trained on COCO2014D and COCO2014CD datasets, and the model uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# K-Diffusion
|
||||
|
||||
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Text-to-(RGB, depth)
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Safe Stable Diffusion
|
||||
|
||||
Safe Stable Diffusion was proposed in [Safe Latent Diffusion: Mitigating Inappropriate Degeneration in Diffusion Models](https://huggingface.co/papers/2211.05105) and mitigates inappropriate degeneration from Stable Diffusion models because they're trained on unfiltered web-crawled datasets. For instance Stable Diffusion may unexpectedly generate nudity, violence, images depicting self-harm, and otherwise offensive content. Safe Stable Diffusion is an extension of Stable Diffusion that drastically reduces this type of content.
|
||||
|
||||
@@ -10,11 +10,8 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 This pipeline is for research purposes only.
|
||||
|
||||
</Tip>
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Text-to-video
|
||||
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# Text2Video-Zero
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
|
||||
@@ -7,6 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# unCLIP
|
||||
|
||||
[Hierarchical Text-Conditional Image Generation with CLIP Latents](https://huggingface.co/papers/2204.06125) is by Aditya Ramesh, Prafulla Dhariwal, Alex Nichol, Casey Chu, Mark Chen. The unCLIP model in 🤗 Diffusers comes from kakaobrain's [karlo](https://github.com/kakaobrain/karlo).
|
||||
|
||||
@@ -10,6 +10,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# UniDiffuser
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
|
||||
@@ -302,12 +302,12 @@ The general rule of thumb to keep in mind when preparing inputs for the VACE pip
|
||||
```py
|
||||
# pip install ftfy
|
||||
import torch
|
||||
from diffusers import WanPipeline, AutoModel
|
||||
from diffusers import WanPipeline, WanTransformer3DModel, AutoencoderKLWan
|
||||
|
||||
vae = AutoModel.from_single_file(
|
||||
vae = AutoencoderKLWan.from_single_file(
|
||||
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/vae/wan_2.1_vae.safetensors"
|
||||
)
|
||||
transformer = AutoModel.from_single_file(
|
||||
transformer = WanTransformer3DModel.from_single_file(
|
||||
"https://huggingface.co/Comfy-Org/Wan_2.1_ComfyUI_repackaged/blob/main/split_files/diffusion_models/wan2.1_t2v_1.3B_bf16.safetensors",
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
|
||||
@@ -12,6 +12,9 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# Würstchen
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</div>
|
||||
|
||||
510
docs/source/en/modular_diffusers/components_manager.md
Normal file
510
docs/source/en/modular_diffusers/components_manager.md
Normal file
@@ -0,0 +1,510 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Components Manager
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 **Experimental Feature**: This is an experimental feature we are actively developing. The API may be subject to breaking changes.
|
||||
|
||||
</Tip>
|
||||
|
||||
The Components Manager is a central model registry and management system in diffusers. It lets you add models then reuse them across multiple pipelines and workflows. It tracks all models in one place with useful metadata such as model size, device placement and loaded adapters (LoRA, IP-Adapter). It has mechanisms in place to prevent duplicate model instances, enables memory-efficient sharing. Most significantly, it offers offloading that works across pipelines — unlike regular DiffusionPipeline offloading which is limited to one pipeline with predefined sequences, the Components Manager automatically manages your device memory across all your models and workflows.
|
||||
|
||||
|
||||
## Basic Operations
|
||||
|
||||
Let's start with the fundamental operations. First, create a Components Manager:
|
||||
|
||||
```py
|
||||
from diffusers import ComponentsManager
|
||||
comp = ComponentsManager()
|
||||
```
|
||||
|
||||
Use the `add(name, component)` method to register a component. It returns a unique ID that combines the component name with the object's unique identifier (using Python's `id()` function):
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
text_encoder = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
|
||||
# Returns component_id like 'text_encoder_139917733042864'
|
||||
component_id = comp.add("text_encoder", text_encoder)
|
||||
```
|
||||
|
||||
You can view all registered components and their metadata:
|
||||
|
||||
```py
|
||||
>>> comp
|
||||
Components:
|
||||
===============================================================================================================================================
|
||||
Models:
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Additional Component Info:
|
||||
==================================================
|
||||
```
|
||||
|
||||
And remove components using their unique ID:
|
||||
|
||||
```py
|
||||
comp.remove("text_encoder_139917733042864")
|
||||
```
|
||||
|
||||
## Duplicate Detection
|
||||
|
||||
The Components Manager automatically detects and prevents duplicate model instances to save memory and avoid confusion. Let's walk through how this works in practice.
|
||||
|
||||
When you try to add the same object twice, the manager will warn you and return the existing ID:
|
||||
|
||||
```py
|
||||
>>> comp.add("text_encoder", text_encoder)
|
||||
'text_encoder_139917733042864'
|
||||
>>> comp.add("text_encoder", text_encoder)
|
||||
ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917733042864'
|
||||
'text_encoder_139917733042864'
|
||||
```
|
||||
|
||||
Even if you add the same object under a different name, it will still be detected as a duplicate:
|
||||
|
||||
```py
|
||||
>>> comp.add("clip", text_encoder)
|
||||
ComponentsManager: adding component 'clip' as 'clip_139917733042864', but it is duplicate of 'text_encoder_139917733042864'
|
||||
To remove a duplicate, call `components_manager.remove('<component_id>')`.
|
||||
'clip_139917733042864'
|
||||
```
|
||||
|
||||
However, there's a more subtle case where duplicate detection becomes tricky. When you load the same model into different objects, the manager can't detect duplicates unless you use `ComponentSpec`. For example:
|
||||
|
||||
```py
|
||||
>>> text_encoder_2 = AutoModel.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder")
|
||||
>>> comp.add("text_encoder", text_encoder_2)
|
||||
'text_encoder_139917732983664'
|
||||
```
|
||||
|
||||
This creates a problem - you now have two copies of the same model consuming double the memory:
|
||||
|
||||
```py
|
||||
>>> comp
|
||||
Components:
|
||||
===============================================================================================================================================
|
||||
Models:
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
text_encoder_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
|
||||
clip_139917733042864 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
|
||||
text_encoder_139917732983664 | CLIPTextModel | cpu | torch.float32 | 0.46 | N/A | N/A
|
||||
-----------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Additional Component Info:
|
||||
==================================================
|
||||
```
|
||||
|
||||
We recommend using `ComponentSpec` to load your models. Models loaded with `ComponentSpec` get tagged with a unique ID that encodes their loading parameters, allowing the Components Manager to detect when different objects represent the same underlying checkpoint:
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, ComponentsManager
|
||||
from transformers import CLIPTextModel
|
||||
comp = ComponentsManager()
|
||||
|
||||
# Create ComponentSpec for the first text encoder
|
||||
spec = ComponentSpec(name="text_encoder", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=AutoModel)
|
||||
# Create ComponentSpec for a duplicate text encoder (it is same checkpoint, from same repo/subfolder)
|
||||
spec_duplicated = ComponentSpec(name="text_encoder_duplicated", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="text_encoder", type_hint=CLIPTextModel)
|
||||
|
||||
# Load and add both components - the manager will detect they're the same model
|
||||
comp.add("text_encoder", spec.load())
|
||||
comp.add("text_encoder_duplicated", spec_duplicated.load())
|
||||
```
|
||||
|
||||
Now the manager detects the duplicate and warns you:
|
||||
|
||||
```out
|
||||
ComponentsManager: adding component 'text_encoder_duplicated_139917580682672', but it has duplicate load_id 'stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null' with existing components: text_encoder_139918506246832. To remove a duplicate, call `components_manager.remove('<component_id>')`.
|
||||
'text_encoder_duplicated_139917580682672'
|
||||
```
|
||||
|
||||
Both models now show the same `load_id`, making it clear they're the same model:
|
||||
|
||||
```py
|
||||
>>> comp
|
||||
Components:
|
||||
======================================================================================================================================================================================================
|
||||
Models:
|
||||
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
|
||||
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
text_encoder_139918506246832 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
|
||||
text_encoder_duplicated_139917580682672 | CLIPTextModel | cpu | torch.float32 | 0.46 | stabilityai/stable-diffusion-xl-base-1.0|text_encoder|null|null | N/A
|
||||
------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Additional Component Info:
|
||||
==================================================
|
||||
```
|
||||
|
||||
## Collections
|
||||
|
||||
Collections are labels you can assign to components for better organization and management. You add a component under a collection by passing the `collection=` parameter when you add the component to the manager, i.e. `add(name, component, collection=...)`. Within each collection, only one component per name is allowed - if you add a second component with the same name, the first one is automatically removed.
|
||||
|
||||
Here's how collections work in practice:
|
||||
|
||||
```py
|
||||
comp = ComponentsManager()
|
||||
# Create ComponentSpec for the first UNet (SDXL base)
|
||||
spec = ComponentSpec(name="unet", repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", type_hint=AutoModel)
|
||||
# Create ComponentSpec for a different UNet (Juggernaut-XL)
|
||||
spec2 = ComponentSpec(name="unet", repo="RunDiffusion/Juggernaut-XL-v9", subfolder="unet", type_hint=AutoModel, variant="fp16")
|
||||
|
||||
# Add both UNets to the same collection - the second one will replace the first
|
||||
comp.add("unet", spec.load(), collection="sdxl")
|
||||
comp.add("unet", spec2.load(), collection="sdxl")
|
||||
```
|
||||
|
||||
The manager automatically removes the old UNet and adds the new one:
|
||||
|
||||
```out
|
||||
ComponentsManager: removing existing unet from collection 'sdxl': unet_139917723891888
|
||||
'unet_139917723893136'
|
||||
```
|
||||
|
||||
Only one UNet remains in the collection:
|
||||
|
||||
```py
|
||||
>>> comp
|
||||
Components:
|
||||
====================================================================================================================================================================
|
||||
Models:
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
unet_139917723893136 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | RunDiffusion/Juggernaut-XL-v9|unet|fp16|null | sdxl
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Additional Component Info:
|
||||
==================================================
|
||||
```
|
||||
|
||||
For example, in node-based systems, you can mark all models loaded from one node with the same collection label, automatically replace models when user loads new checkpoints under same name, batch delete all models in a collection when a node is removed.
|
||||
|
||||
## Retrieving Components
|
||||
|
||||
The Components Manager provides several methods to retrieve registered components.
|
||||
|
||||
The `get_one()` method returns a single component and supports pattern matching for the `name` parameter. You can use:
|
||||
- exact matches like `comp.get_one(name="unet")`
|
||||
- wildcards like `comp.get_one(name="unet*")` for components starting with "unet"
|
||||
- exclusion patterns like `comp.get_one(name="!unet")` to exclude components named "unet"
|
||||
- OR patterns like `comp.get_one(name="unet|vae")` to match either "unet" OR "vae".
|
||||
|
||||
You can also filter by collection with `comp.get_one(name="unet", collection="sdxl")` or by load_id. If multiple components match, `get_one()` throws an error.
|
||||
|
||||
Another useful method is `get_components_by_names()`, which takes a list of names and returns a dictionary mapping names to components. This is particularly helpful with modular pipelines since they provide lists of required component names, and the returned dictionary can be directly passed to `pipeline.update_components()`.
|
||||
|
||||
```py
|
||||
# Get components by name list
|
||||
component_dict = comp.get_components_by_names(names=["text_encoder", "unet", "vae"])
|
||||
# Returns: {"text_encoder": component1, "unet": component2, "vae": component3}
|
||||
```
|
||||
|
||||
## Using Components Manager with Modular Pipelines
|
||||
|
||||
The Components Manager integrates seamlessly with Modular Pipelines. All you need to do is pass a Components Manager instance to `from_pretrained()` or `init_pipeline()` with an optional `collection` parameter:
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
comp = ComponentsManager()
|
||||
pipe = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test1")
|
||||
```
|
||||
|
||||
By default, modular pipelines don't load components immediately, so both the pipeline and Components Manager start empty:
|
||||
|
||||
```py
|
||||
>>> comp
|
||||
Components:
|
||||
==================================================
|
||||
No components registered.
|
||||
==================================================
|
||||
```
|
||||
|
||||
When you load components on the pipeline, they are automatically registered in the Components Manager:
|
||||
|
||||
```py
|
||||
>>> pipe.load_components(names="unet")
|
||||
>>> comp
|
||||
Components:
|
||||
==============================================================================================================================================================
|
||||
Models:
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
unet_139917726686304 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
|
||||
--------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Additional Component Info:
|
||||
==================================================
|
||||
```
|
||||
|
||||
Now let's load all default components and then create a second pipeline that reuses all components from the first one. We pass the same Components Manager to the second pipeline but with a different collection:
|
||||
|
||||
```py
|
||||
# Load all default components
|
||||
>>> pipe.load_default_components()`
|
||||
|
||||
# Create a second pipeline using the same Components Manager but with a different collection
|
||||
>>> pipe2 = ModularPipeline.from_pretrained("YiYiXu/modular-demo-auto", components_manager=comp, collection="test2")
|
||||
```
|
||||
|
||||
As mentioned earlier, `ModularPipeline` has a property `null_component_names` that returns a list of component names it needs to load. We can conveniently use this list with the `get_components_by_names` method on the Components Manager:
|
||||
|
||||
```py
|
||||
# Get the list of components that pipe2 needs to load
|
||||
>>> pipe2.null_component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'image_encoder', 'unet', 'vae', 'scheduler', 'controlnet']
|
||||
|
||||
# Retrieve all required components from the Components Manager
|
||||
>>> comp_dict = comp.get_components_by_names(names=pipe2.null_component_names)
|
||||
|
||||
# Update the pipeline with the retrieved components
|
||||
>>> pipe2.update_components(**comp_dict)
|
||||
```
|
||||
|
||||
The warnings that follow are expected and indicate that the Components Manager is correctly identifying that these components already exist and will be reused rather than creating duplicates:
|
||||
|
||||
```
|
||||
ComponentsManager: component 'text_encoder' already exists as 'text_encoder_139917586016400'
|
||||
ComponentsManager: component 'text_encoder_2' already exists as 'text_encoder_2_139917699973424'
|
||||
ComponentsManager: component 'tokenizer' already exists as 'tokenizer_139917580599504'
|
||||
ComponentsManager: component 'tokenizer_2' already exists as 'tokenizer_2_139915763443904'
|
||||
ComponentsManager: component 'image_encoder' already exists as 'image_encoder_139917722468304'
|
||||
ComponentsManager: component 'unet' already exists as 'unet_139917580609632'
|
||||
ComponentsManager: component 'vae' already exists as 'vae_139917722459040'
|
||||
ComponentsManager: component 'scheduler' already exists as 'scheduler_139916266559408'
|
||||
ComponentsManager: component 'controlnet' already exists as 'controlnet_139917722454432'
|
||||
```
|
||||
```
|
||||
|
||||
The pipeline is now fully loaded:
|
||||
|
||||
```py
|
||||
# null_component_names return empty list, meaning everything are loaded
|
||||
>>> pipe2.null_component_names
|
||||
[]
|
||||
```
|
||||
|
||||
No new components were added to the Components Manager - we're reusing everything. All models are now associated with both `test1` and `test2` collections, showing that these components are shared across multiple pipelines:
|
||||
```py
|
||||
>>> comp
|
||||
Components:
|
||||
========================================================================================================================================================================================
|
||||
Models:
|
||||
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
Name_ID | Class | Device: act(exec) | Dtype | Size (GB) | Load ID | Collection
|
||||
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
text_encoder_139917586016400 | CLIPTextModel | cpu | torch.float32 | 0.46 | SG161222/RealVisXL_V4.0|text_encoder|null|null | test1
|
||||
| | | | | | test2
|
||||
text_encoder_2_139917699973424 | CLIPTextModelWithProjection | cpu | torch.float32 | 2.59 | SG161222/RealVisXL_V4.0|text_encoder_2|null|null | test1
|
||||
| | | | | | test2
|
||||
unet_139917580609632 | UNet2DConditionModel | cpu | torch.float32 | 9.56 | SG161222/RealVisXL_V4.0|unet|null|null | test1
|
||||
| | | | | | test2
|
||||
controlnet_139917722454432 | ControlNetModel | cpu | torch.float32 | 4.66 | diffusers/controlnet-canny-sdxl-1.0|null|null|null | test1
|
||||
| | | | | | test2
|
||||
vae_139917722459040 | AutoencoderKL | cpu | torch.float32 | 0.31 | SG161222/RealVisXL_V4.0|vae|null|null | test1
|
||||
| | | | | | test2
|
||||
image_encoder_139917722468304 | CLIPVisionModelWithProjection | cpu | torch.float32 | 6.87 | h94/IP-Adapter|sdxl_models/image_encoder|null|null | test1
|
||||
| | | | | | test2
|
||||
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Other Components:
|
||||
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
ID | Class | Collection
|
||||
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
tokenizer_139917580599504 | CLIPTokenizer | test1
|
||||
| | test2
|
||||
scheduler_139916266559408 | EulerDiscreteScheduler | test1
|
||||
| | test2
|
||||
tokenizer_2_139915763443904 | CLIPTokenizer | test1
|
||||
| | test2
|
||||
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
Additional Component Info:
|
||||
==================================================
|
||||
```
|
||||
|
||||
|
||||
## Automatic Memory Management
|
||||
|
||||
The Components Manager provides a global offloading strategy across all models, regardless of which pipeline is using them:
|
||||
|
||||
```py
|
||||
comp.enable_auto_cpu_offload(device="cuda")
|
||||
```
|
||||
|
||||
When enabled, all models start on CPU. The manager moves models to the device right before they're used and moves other models back to CPU when GPU memory runs low. You can set your own rules for which models to offload first. This works smoothly as you add or remove components. Once it's on, you don't need to worry about device placement - you can focus on your workflow.
|
||||
|
||||
|
||||
|
||||
## Practical Example: Building Modular Workflows with Component Reuse
|
||||
|
||||
Now that we've covered the basics of the Components Manager, let's walk through a practical example that shows how to build workflows in a modular setting and use the Components Manager to reuse components across multiple pipelines. This example demonstrates the true power of Modular Diffusers by working with multiple pipelines that can share components.
|
||||
|
||||
In this example, we'll generate latents from a text-to-image pipeline, then refine them with an image-to-image pipeline. We will also use Lora and IP-Adapter.
|
||||
|
||||
Let's create a modular text-to-image workflow by separating it into three components: `text_blocks` for encoding prompts, `t2i_blocks` for generating latents, and `decoder_blocks` for creating final images.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS
|
||||
|
||||
# Create modular blocks and separate text encoding and decoding steps
|
||||
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["text2img"])
|
||||
text_blocks = t2i_blocks.sub_blocks.pop("text_encoder")
|
||||
decoder_blocks = t2i_blocks.sub_blocks.pop("decode")
|
||||
```
|
||||
|
||||
Now we will convert them into runnalbe pipelines and set up the Components Manager with auto offloading and organize components under a "t2i" collection:
|
||||
|
||||
```py
|
||||
from diffusers import ComponentsManager, ModularPipeline
|
||||
|
||||
# Set up Components Manager with auto offloading
|
||||
components = ComponentsManager()
|
||||
components.enable_auto_cpu_offload(device="cuda")
|
||||
|
||||
# Create pipelines and load components
|
||||
t2i_repo = "YiYiXu/modular-demo-auto"
|
||||
t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i")
|
||||
|
||||
text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components)
|
||||
decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components)
|
||||
t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components)
|
||||
```
|
||||
|
||||
Load all components into the Components Manager under the "t2i" collection:
|
||||
|
||||
```py
|
||||
# Load all components (including IP-Adapter and ControlNet for later use)
|
||||
t2i_loader_pipe.load_components(names=t2i_loader_pipe.pretrained_component_names, torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
Now distribute the loaded components to each pipeline:
|
||||
|
||||
```py
|
||||
# Get VAE for decoder (using get_one since there's only one)
|
||||
vae = components.get_one(load_id="SG161222/RealVisXL_V4.0|vae|null|null")
|
||||
decoder_node.update_components(vae=vae)
|
||||
|
||||
# Get text components for text node (using get_components_by_names for multiple components)
|
||||
text_components = components.get_components_by_names(text_node.null_component_names)
|
||||
text_node.update_components(**text_components)
|
||||
|
||||
# Get remaining components for t2i pipeline
|
||||
t2i_components = components.get_components_by_names(t2i_pipe.null_component_names)
|
||||
t2i_pipe.update_components(**t2i_components)
|
||||
```
|
||||
|
||||
Now we can generate images using our modular workflow:
|
||||
|
||||
```py
|
||||
# Generate text embeddings
|
||||
prompt = "an astronaut"
|
||||
text_embeddings = text_node(prompt=prompt, output=["prompt_embeds","negative_prompt_embeds", "pooled_prompt_embeds", "negative_pooled_prompt_embeds"])
|
||||
|
||||
# Generate latents and decode to image
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
latents_t2i = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
|
||||
image = decoder_node(latents=latents_t2i, output="images")[0]
|
||||
image.save("modular_part2_t2i.png")
|
||||
```
|
||||
|
||||
Let's add a LoRA:
|
||||
|
||||
```py
|
||||
# Load LoRA weights - only the UNet gets the adapter
|
||||
>>> t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face")
|
||||
>>> components
|
||||
Components:
|
||||
============================================================================================================================================================
|
||||
...
|
||||
Additional Component Info:
|
||||
==================================================
|
||||
|
||||
unet:
|
||||
Adapters: ['toy_face']
|
||||
```
|
||||
|
||||
You can see that the Components Manager tracks adapters metadata for all models it manages, and in our case, only Unet has lora loaded. This means we can reuse existing text embeddings.
|
||||
|
||||
```py
|
||||
# Generate with LoRA (reusing existing text embeddings)
|
||||
generator = torch.Generator(device="cuda").manual_seed(0)
|
||||
latents_lora = t2i_pipe(**text_embeddings, num_inference_steps=25, generator=generator, output="latents")
|
||||
image = decoder_node(latents=latents_lora, output="images")[0]
|
||||
image.save("modular_part2_lora.png")
|
||||
```
|
||||
|
||||
|
||||
Now let's create a refiner pipeline that reuses components from our text-to-image workflow:
|
||||
|
||||
```py
|
||||
# Create refiner blocks (removing image_encoder and decode since we work with latents)
|
||||
refiner_blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["img2img"])
|
||||
refiner_blocks.sub_blocks.pop("image_encoder")
|
||||
refiner_blocks.sub_blocks.pop("decode")
|
||||
|
||||
# Create refiner pipeline with different repo and collection
|
||||
refiner_repo = "YiYiXu/modular_refiner"
|
||||
refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=components, collection="refiner")
|
||||
```
|
||||
|
||||
We pass the **same Components Manager** (`components`) to the refiner pipeline, but with a **different collection** (`"refiner"`). This allows the refiner to access and reuse components from the "t2i" collection while organizing its own components (like the refiner UNet) under the "refiner" collection.
|
||||
|
||||
```py
|
||||
# Load only the refiner UNet (different from t2i UNet)
|
||||
refiner_pipe.load_components(names="unet", torch_dtype=torch.float16)
|
||||
|
||||
# Reuse components from t2i pipeline using pattern matching
|
||||
reuse_components = components.search_components("text_encoder_2|scheduler|vae|tokenizer_2")
|
||||
refiner_pipe.update_components(**reuse_components)
|
||||
```
|
||||
|
||||
When we reuse components from the "t2i" collection, they automatically get added to the "refiner" collection as well. You can verify this by checking the Components Manager - you'll see components like `vae`, `scheduler`, etc. listed under both collections, indicating they're shared between workflows.
|
||||
|
||||
Now we can refine any of our generated latents:
|
||||
|
||||
```py
|
||||
# Refine all our different latents
|
||||
refined_latents = refiner_pipe(image_latents=latents_t2i, prompt=prompt, num_inference_steps=10, output="latents")
|
||||
refined_image = decoder_node(latents=refined_latents, output="images")[0]
|
||||
refined_image.save("modular_part2_t2i_refine_out.png")
|
||||
|
||||
refined_latents = refiner_pipe(image_latents=latents_lora, prompt=prompt, num_inference_steps=10, output="latents")
|
||||
refined_image = decoder_node(latents=refined_latents, output="images")[0]
|
||||
refined_image.save("modular_part2_lora_refine_out.png")
|
||||
```
|
||||
|
||||
|
||||
Here are the results from our modular pipeline examples.
|
||||
|
||||
#### Base Text-to-Image Generation
|
||||
| Base Text-to-Image | Base Text-to-Image (Refined) |
|
||||
|-------------------|------------------------------|
|
||||
|  |  |
|
||||
|
||||
#### LoRA
|
||||
| LoRA | LoRA (Refined) |
|
||||
|-------------------|------------------------------|
|
||||
|  |  |
|
||||
|
||||
648
docs/source/en/modular_diffusers/end_to_end_guide.md
Normal file
648
docs/source/en/modular_diffusers/end_to_end_guide.md
Normal file
@@ -0,0 +1,648 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# End-to-End Developer Guide: Building with Modular Diffusers
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
|
||||
|
||||
</Tip>
|
||||
|
||||
|
||||
In this tutorial we will walk through the process of adding a new pipeline to the modular framework using differential diffusion as our example. We'll cover the complete workflow from implementation to deployment: implementing the new pipeline, ensuring compatibility with existing tools, sharing the code on Hugging Face Hub, and deploying it as a UI node.
|
||||
|
||||
We'll also demonstrate the 4-step framework process we use for implementing new basic pipelines in the modular system.
|
||||
|
||||
1. **Start with an existing pipeline as a base**
|
||||
- Identify which existing pipeline is most similar to the one you want to implement
|
||||
- Determine what part of the pipeline needs modification
|
||||
|
||||
2. **Build a working pipeline structure first**
|
||||
- Assemble the complete pipeline structure
|
||||
- Use existing blocks wherever possible
|
||||
- For new blocks, create placeholders (e.g. you can copy from similar blocks and change the name) without implementing custom logic just yet
|
||||
|
||||
3. **Set up an example**
|
||||
- Create a simple inference script with expected inputs/outputs
|
||||
|
||||
4. **Implement your custom logic and test incrementally**
|
||||
- Add the custom logics the blocks you want to change
|
||||
- Test incrementally, and inspect pipeline states and debug as needed
|
||||
|
||||
Let's see how this works with the Differential Diffusion example.
|
||||
|
||||
|
||||
## Differential Diffusion Pipeline
|
||||
|
||||
### Start with an existing pipeline
|
||||
|
||||
Differential diffusion (https://differential-diffusion.github.io/) is an image-to-image workflow, so it makes sense for us to start with the preset of pipeline blocks used to build img2img pipeline (`IMAGE2IMAGE_BLOCKS`) and see how we can build this new pipeline with them.
|
||||
|
||||
```py
|
||||
>>> from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
|
||||
>>> IMAGE2IMAGE_BLOCKS = InsertableDict([
|
||||
... ("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
... ("image_encoder", StableDiffusionXLVaeEncoderStep),
|
||||
... ("input", StableDiffusionXLInputStep),
|
||||
... ("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
... ("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
|
||||
... ("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||
... ("denoise", StableDiffusionXLDenoiseStep),
|
||||
... ("decode", StableDiffusionXLDecodeStep)
|
||||
... ])
|
||||
```
|
||||
|
||||
Note that "denoise" (`StableDiffusionXLDenoiseStep`) is a `LoopSequentialPipelineBlocks` that contains 3 loop blocks (more on LoopSequentialPipelineBlocks [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#loopsequentialpipelineblocks))
|
||||
|
||||
```py
|
||||
>>> denoise_blocks = IMAGE2IMAGE_BLOCKS["denoise"]()
|
||||
>>> print(denoise_blocks)
|
||||
```
|
||||
|
||||
```out
|
||||
StableDiffusionXLDenoiseStep(
|
||||
Class: StableDiffusionXLDenoiseLoopWrapper
|
||||
|
||||
Description: Denoise step that iteratively denoise the latents.
|
||||
Its loop logic is defined in `StableDiffusionXLDenoiseLoopWrapper.__call__` method
|
||||
At each iteration, it runs blocks defined in `sub_blocks` sequencially:
|
||||
- `StableDiffusionXLLoopBeforeDenoiser`
|
||||
- `StableDiffusionXLLoopDenoiser`
|
||||
- `StableDiffusionXLLoopAfterDenoiser`
|
||||
This block supports both text2img and img2img tasks.
|
||||
|
||||
|
||||
Components:
|
||||
scheduler (`EulerDiscreteScheduler`)
|
||||
guider (`ClassifierFreeGuidance`)
|
||||
unet (`UNet2DConditionModel`)
|
||||
|
||||
Sub-Blocks:
|
||||
[0] before_denoiser (StableDiffusionXLLoopBeforeDenoiser)
|
||||
Description: step within the denoising loop that prepare the latent input for the denoiser. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
|
||||
|
||||
[1] denoiser (StableDiffusionXLLoopDenoiser)
|
||||
Description: Step within the denoising loop that denoise the latents with guidance. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
|
||||
|
||||
[2] after_denoiser (StableDiffusionXLLoopAfterDenoiser)
|
||||
Description: step within the denoising loop that update the latents. This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` object (e.g. `StableDiffusionXLDenoiseLoopWrapper`)
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
Let's compare standard image-to-image and differential diffusion! The key difference in algorithm is that standard image-to-image diffusion applies uniform noise across all pixels based on a single `strength` parameter, but differential diffusion uses a change map where each pixel value determines when that region starts denoising. Regions with lower values get "frozen" earlier by replacing them with noised original latents, preserving more of the original image.
|
||||
|
||||
Therefore, the key differences when it comes to pipeline implementation would be:
|
||||
1. The `prepare_latents` step (which prepares the change map and pre-computes noised latents for all timesteps)
|
||||
2. The `denoise` step (which selectively applies denoising based on the change map)
|
||||
3. Since differential diffusion doesn't use the `strength` parameter, we'll use the text-to-image `set_timesteps` step instead of the image-to-image version
|
||||
|
||||
To implement differntial diffusion, we can reuse most blocks from image-to-image and text-to-image workflows, only modifying the `prepare_latents` step and the first part of the `denoise` step (i.e. `before_denoiser (StableDiffusionXLLoopBeforeDenoiser)`).
|
||||
|
||||
Here's a flowchart showing the pipeline structure and the changes we need to make:
|
||||
|
||||
|
||||

|
||||
|
||||
|
||||
### Build a Working Pipeline Structure
|
||||
|
||||
ok now we've identified the blocks to modify, let's build the pipeline skeleton first - at this stage, our goal is to get the pipeline struture working end-to-end (even though it's just doing the img2img behavior). I would simply create placeholder blocks by copying from existing ones:
|
||||
|
||||
```py
|
||||
>>> # Copy existing blocks as placeholders
|
||||
>>> class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
|
||||
... """Copied from StableDiffusionXLImg2ImgPrepareLatentsStep - will modify later"""
|
||||
... # ... same implementation as StableDiffusionXLImg2ImgPrepareLatentsStep
|
||||
...
|
||||
>>> class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
|
||||
... """Copied from StableDiffusionXLLoopBeforeDenoiser - will modify later"""
|
||||
... # ... same implementation as StableDiffusionXLLoopBeforeDenoiser
|
||||
```
|
||||
|
||||
`SDXLDiffDiffLoopBeforeDenoiser` is the be part of the denoise loop we need to change. Let's use it to assemble a `SDXLDiffDiffDenoiseStep`.
|
||||
|
||||
```py
|
||||
>>> class SDXLDiffDiffDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLLoopDenoiser, StableDiffusionXLLoopAfterDenoiser]
|
||||
... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
```
|
||||
|
||||
Now we can put together our differential diffusion pipeline.
|
||||
|
||||
```py
|
||||
>>> DIFFDIFF_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
|
||||
>>> DIFFDIFF_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
|
||||
>>> DIFFDIFF_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
|
||||
>>> DIFFDIFF_BLOCKS["denoise"] = SDXLDiffDiffDenoiseStep
|
||||
>>>
|
||||
>>> dd_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_BLOCKS)
|
||||
>>> print(dd_blocks)
|
||||
>>> # At this point, the pipeline works exactly like img2img since our blocks are just copies
|
||||
```
|
||||
|
||||
### Set up an example
|
||||
|
||||
ok, so now our blocks should be able to compile without an error, we can move on to the next step. Let's setup a simple example so we can run the pipeline as we build it. diff-diff use same model checkpoints as SDXL so we can fetch the models from a regular SDXL repo.
|
||||
|
||||
```py
|
||||
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
>>> dd_pipeline.load_default_componenets(torch_dtype=torch.float16)
|
||||
>>> dd_pipeline.to("cuda")
|
||||
```
|
||||
|
||||
We will use this example script:
|
||||
|
||||
```py
|
||||
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
|
||||
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
|
||||
>>>
|
||||
>>> prompt = "a green pear"
|
||||
>>> negative_prompt = "blurry"
|
||||
>>>
|
||||
>>> image = dd_pipeline(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_inference_steps=25,
|
||||
... diffdiff_map=mask,
|
||||
... image=image,
|
||||
... output="images"
|
||||
... )[0]
|
||||
>>>
|
||||
>>> image.save("diffdiff_out.png")
|
||||
```
|
||||
|
||||
If you run the script right now, you will get a complaint about unexpected input `diffdiff_map`.
|
||||
and you would get the same result as the original img2img pipeline.
|
||||
|
||||
### implement your custom logic and test incrementally
|
||||
|
||||
Let's modify the pipeline so that we can get expected result with this example script.
|
||||
|
||||
We'll start with the `prepare_latents` step. The main changes are:
|
||||
- Requires a new user input `diffdiff_map`
|
||||
- Requires new component `mask_processor` to process the `diffdiff_map`
|
||||
- Requires new intermediate inputs:
|
||||
- Need `timestep` instead of `latent_timestep` to precompute all the latents
|
||||
- Need `num_inference_steps` to create the `diffdiff_masks`
|
||||
- create a new output `diffdiff_masks` and `original_latents`
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 use `print(dd_pipeline.doc)` to check compiled inputs and outputs of the built piepline.
|
||||
|
||||
e.g. after we added `diffdiff_map` as an input in this step, we can run `print(dd_pipeline.doc)` to verify that it shows up in the docstring as a user input.
|
||||
|
||||
</Tip>
|
||||
|
||||
Once we make sure all the variables we need are available in the block state, we can implement the diff-diff logic inside `__call__`. We created 2 new variables: the change map `diffdiff_mask` and the pre-computed noised latents for all timesteps `original_latents`.
|
||||
|
||||
<Tip>
|
||||
|
||||
💡 Implement incrementally! Run the example script as you go, and insert `print(state)` and `print(block_state)` everywhere inside the `__call__` method to inspect the intermediate results. This helps you understand what's going on and what each line you just added does.
|
||||
|
||||
</Tip>
|
||||
|
||||
Here are the key changes we made to implement differential diffusion:
|
||||
|
||||
**1. Modified `prepare_latents` step:**
|
||||
```diff
|
||||
class SDXLDiffDiffPrepareLatentsStep(PipelineBlock):
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
return [
|
||||
ComponentSpec("vae", AutoencoderKL),
|
||||
ComponentSpec("scheduler", EulerDiscreteScheduler),
|
||||
+ ComponentSpec("mask_processor", VaeImageProcessor, config=FrozenDict({"do_normalize": False, "do_convert_grayscale": True}))
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
+ InputParam("diffdiff_map", required=True),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[InputParam]:
|
||||
return [
|
||||
InputParam("generator"),
|
||||
- InputParam("latent_timestep", required=True, type_hint=torch.Tensor),
|
||||
+ InputParam("timesteps", type_hint=torch.Tensor),
|
||||
+ InputParam("num_inference_steps", type_hint=int),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self) -> List[OutputParam]:
|
||||
return [
|
||||
+ OutputParam("original_latents", type_hint=torch.Tensor),
|
||||
+ OutputParam("diffdiff_masks", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
def __call__(self, components, state: PipelineState):
|
||||
# ... existing logic ...
|
||||
+ # Process change map and create masks
|
||||
+ diffdiff_map = components.mask_processor.preprocess(block_state.diffdiff_map, height=latent_height, width=latent_width)
|
||||
+ thresholds = torch.arange(block_state.num_inference_steps, dtype=diffdiff_map.dtype) / block_state.num_inference_steps
|
||||
+ block_state.diffdiff_masks = diffdiff_map > (thresholds + (block_state.denoising_start or 0))
|
||||
+ block_state.original_latents = block_state.latents
|
||||
```
|
||||
|
||||
**2. Modified `before_denoiser` step:**
|
||||
```diff
|
||||
class SDXLDiffDiffLoopBeforeDenoiser(PipelineBlock):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Step within the denoising loop for differential diffusion that prepare the latent input for the denoiser"
|
||||
)
|
||||
|
||||
@property
|
||||
def inputs(self) -> List[Tuple[str, Any]]:
|
||||
return [
|
||||
InputParam("denoising_start"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self) -> List[str]:
|
||||
return [
|
||||
InputParam("latents", required=True, type_hint=torch.Tensor),
|
||||
InputParam("original_latents", type_hint=torch.Tensor),
|
||||
InputParam("diffdiff_masks", type_hint=torch.Tensor),
|
||||
]
|
||||
|
||||
def __call__(self, components, block_state, i, t):
|
||||
# Apply differential diffusion logic
|
||||
if i == 0 and block_state.denoising_start is None:
|
||||
block_state.latents = block_state.original_latents[:1]
|
||||
else:
|
||||
block_state.mask = block_state.diffdiff_masks[i].unsqueeze(0).unsqueeze(1)
|
||||
block_state.latents = block_state.original_latents[i] * block_state.mask + block_state.latents * (1 - block_state.mask)
|
||||
|
||||
# ... rest of existing logic ...
|
||||
```
|
||||
|
||||
That's all there is to it! We've just created a simple sequential pipeline by mix-and-match some existing and new pipeline blocks.
|
||||
|
||||
Now we use the process we've prepred in step2 to build the pipeline and inspect it.
|
||||
|
||||
|
||||
```py
|
||||
>> dd_pipeline
|
||||
SequentialPipelineBlocks(
|
||||
Class: ModularPipelineBlocks
|
||||
|
||||
Description:
|
||||
|
||||
|
||||
Components:
|
||||
text_encoder (`CLIPTextModel`)
|
||||
text_encoder_2 (`CLIPTextModelWithProjection`)
|
||||
tokenizer (`CLIPTokenizer`)
|
||||
tokenizer_2 (`CLIPTokenizer`)
|
||||
guider (`ClassifierFreeGuidance`)
|
||||
vae (`AutoencoderKL`)
|
||||
image_processor (`VaeImageProcessor`)
|
||||
scheduler (`EulerDiscreteScheduler`)
|
||||
mask_processor (`VaeImageProcessor`)
|
||||
unet (`UNet2DConditionModel`)
|
||||
|
||||
Configs:
|
||||
force_zeros_for_empty_prompt (default: True)
|
||||
requires_aesthetics_score (default: False)
|
||||
|
||||
Blocks:
|
||||
[0] text_encoder (StableDiffusionXLTextEncoderStep)
|
||||
Description: Text Encoder step that generate text_embeddings to guide the image generation
|
||||
|
||||
[1] image_encoder (StableDiffusionXLVaeEncoderStep)
|
||||
Description: Vae Encoder step that encode the input image into a latent representation
|
||||
|
||||
[2] input (StableDiffusionXLInputStep)
|
||||
Description: Input processing step that:
|
||||
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
|
||||
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
|
||||
|
||||
All input tensors are expected to have either batch_size=1 or match the batch_size
|
||||
of prompt_embeds. The tensors will be duplicated across the batch dimension to
|
||||
have a final batch_size of batch_size * num_images_per_prompt.
|
||||
|
||||
[3] set_timesteps (StableDiffusionXLSetTimestepsStep)
|
||||
Description: Step that sets the scheduler's timesteps for inference
|
||||
|
||||
[4] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
|
||||
Description: Step that prepares the latents for the differential diffusion generation process
|
||||
|
||||
[5] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
|
||||
Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
|
||||
|
||||
[6] denoise (SDXLDiffDiffDenoiseStep)
|
||||
Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
|
||||
|
||||
[7] decode (StableDiffusionXLDecodeStep)
|
||||
Description: Step that decodes the denoised latents into images
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
Run the example now, you should see an apple with its right half transformed into a green pear.
|
||||
|
||||

|
||||
|
||||
|
||||
## Adding IP-adapter
|
||||
|
||||
We provide an auto IP-adapter block that you can plug-and-play into your modular workflow. It's an `AutoPipelineBlocks`, so it will only run when the user passes an IP adapter image. In this tutorial, we'll focus on how to package it into your differential diffusion workflow. To learn more about `AutoPipelineBlocks`, see [here](https://huggingface.co/docs/diffusers/modular_diffusers/write_own_pipeline_block#autopipelineblocks)
|
||||
|
||||
We talked about how to add IP-adapter into your workflow in the [getting-started guide](https://huggingface.co/docs/diffusers/modular_diffusers/quicktour#ip-adapter). Let's just go ahead to create the IP-adapter block.
|
||||
|
||||
```py
|
||||
>>> from diffusers.modular_pipelines.stable_diffusion_xl.encoders import StableDiffusionXLAutoIPAdapterStep
|
||||
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
|
||||
```
|
||||
|
||||
We can directly add the ip-adapter block instance to the `diffdiff_blocks` that we created before. The `sub_blocks` attribute is a `InsertableDict`, so we're able to insert the it at specific position (index `0` here).
|
||||
|
||||
```py
|
||||
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
|
||||
```
|
||||
|
||||
Take a look at the new diff-diff pipeline with ip-adapter!
|
||||
|
||||
```py
|
||||
>>> print(dd_blocks)
|
||||
```
|
||||
|
||||
The pipeline now lists ip-adapter as its first block, and tells you that it will run only if `ip_adapter_image` is provided. It also includes the two new components from ip-adpater: `image_encoder` and `feature_extractor`
|
||||
|
||||
```out
|
||||
SequentialPipelineBlocks(
|
||||
Class: ModularPipelineBlocks
|
||||
|
||||
====================================================================================================
|
||||
This pipeline contains blocks that are selected at runtime based on inputs.
|
||||
Trigger Inputs: {'ip_adapter_image'}
|
||||
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('ip_adapter_image')`).
|
||||
====================================================================================================
|
||||
|
||||
|
||||
Description:
|
||||
|
||||
|
||||
Components:
|
||||
image_encoder (`CLIPVisionModelWithProjection`)
|
||||
feature_extractor (`CLIPImageProcessor`)
|
||||
unet (`UNet2DConditionModel`)
|
||||
guider (`ClassifierFreeGuidance`)
|
||||
text_encoder (`CLIPTextModel`)
|
||||
text_encoder_2 (`CLIPTextModelWithProjection`)
|
||||
tokenizer (`CLIPTokenizer`)
|
||||
tokenizer_2 (`CLIPTokenizer`)
|
||||
vae (`AutoencoderKL`)
|
||||
image_processor (`VaeImageProcessor`)
|
||||
scheduler (`EulerDiscreteScheduler`)
|
||||
mask_processor (`VaeImageProcessor`)
|
||||
|
||||
Configs:
|
||||
force_zeros_for_empty_prompt (default: True)
|
||||
requires_aesthetics_score (default: False)
|
||||
|
||||
Blocks:
|
||||
[0] ip_adapter (StableDiffusionXLAutoIPAdapterStep)
|
||||
Description: Run IP Adapter step if `ip_adapter_image` is provided.
|
||||
|
||||
[1] text_encoder (StableDiffusionXLTextEncoderStep)
|
||||
Description: Text Encoder step that generate text_embeddings to guide the image generation
|
||||
|
||||
[2] image_encoder (StableDiffusionXLVaeEncoderStep)
|
||||
Description: Vae Encoder step that encode the input image into a latent representation
|
||||
|
||||
[3] input (StableDiffusionXLInputStep)
|
||||
Description: Input processing step that:
|
||||
1. Determines `batch_size` and `dtype` based on `prompt_embeds`
|
||||
2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`
|
||||
|
||||
All input tensors are expected to have either batch_size=1 or match the batch_size
|
||||
of prompt_embeds. The tensors will be duplicated across the batch dimension to
|
||||
have a final batch_size of batch_size * num_images_per_prompt.
|
||||
|
||||
[4] set_timesteps (StableDiffusionXLSetTimestepsStep)
|
||||
Description: Step that sets the scheduler's timesteps for inference
|
||||
|
||||
[5] prepare_latents (SDXLDiffDiffPrepareLatentsStep)
|
||||
Description: Step that prepares the latents for the differential diffusion generation process
|
||||
|
||||
[6] prepare_add_cond (StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep)
|
||||
Description: Step that prepares the additional conditioning for the image-to-image/inpainting generation process
|
||||
|
||||
[7] denoise (SDXLDiffDiffDenoiseStep)
|
||||
Description: Pipeline block that iteratively denoise the latents over `timesteps`. The specific steps with each iteration can be customized with `sub_blocks` attributes
|
||||
|
||||
[8] decode (StableDiffusionXLDecodeStep)
|
||||
Description: Step that decodes the denoised latents into images
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
Let's test it out. We used an orange image to condition the generation via ip-addapter and we can see a slight orange color and texture in the final output.
|
||||
|
||||
|
||||
```py
|
||||
>>> ip_adapter_block = StableDiffusionXLAutoIPAdapterStep()
|
||||
>>> dd_blocks.sub_blocks.insert("ip_adapter", ip_adapter_block, 0)
|
||||
>>>
|
||||
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
>>> dd_pipeline.loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
|
||||
>>> dd_pipeline.loader.set_ip_adapter_scale(0.6)
|
||||
>>> dd_pipeline = dd_pipeline.to(device)
|
||||
>>>
|
||||
>>> ip_adapter_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_orange.jpeg")
|
||||
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
|
||||
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
|
||||
>>>
|
||||
>>> prompt = "a green pear"
|
||||
>>> negative_prompt = "blurry"
|
||||
>>> generator = torch.Generator(device=device).manual_seed(42)
|
||||
>>>
|
||||
>>> image = dd_pipeline(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_inference_steps=25,
|
||||
... generator=generator,
|
||||
... ip_adapter_image=ip_adapter_image,
|
||||
... diffdiff_map=mask,
|
||||
... image=image,
|
||||
... output="images"
|
||||
... )[0]
|
||||
```
|
||||
|
||||
## Working with ControlNets
|
||||
|
||||
What about controlnet? Can differential diffusion work with controlnet? The key differences between a regular pipeline and a ControlNet pipeline are:
|
||||
1. A ControlNet input step that prepares the control condition
|
||||
2. Inside the denoising loop, a modified denoiser step where the control image is first processed through ControlNet, then control information is injected into the UNet
|
||||
|
||||
From looking at the code workflow: differential diffusion only modifies the "before denoiser" step, while ControlNet operates within the "denoiser" itself. Since they intervene at different points in the pipeline, they should work together without conflicts.
|
||||
|
||||
Intuitively, these two techniques are orthogonal and should combine naturally: differential diffusion controls how much the inference process can deviate from the original in each region, while ControlNet controls in what direction that change occurs.
|
||||
|
||||
With this understanding, let's assemble the `SDXLDiffDiffControlNetDenoiseStep`:
|
||||
|
||||
```py
|
||||
>>> class SDXLDiffDiffControlNetDenoiseStep(StableDiffusionXLDenoiseLoopWrapper):
|
||||
... block_classes = [SDXLDiffDiffLoopBeforeDenoiser, StableDiffusionXLControlNetLoopDenoiser, StableDiffusionXLDenoiseLoopAfterDenoiser]
|
||||
... block_names = ["before_denoiser", "denoiser", "after_denoiser"]
|
||||
>>>
|
||||
>>> controlnet_denoise_block = SDXLDiffDiffControlNetDenoiseStep()
|
||||
>>> # print(controlnet_denoise)
|
||||
```
|
||||
|
||||
We provide a auto controlnet input block that you can directly put into your workflow to proceess the `control_image`: similar to auto ip-adapter block, this step will only run if `control_image` input is passed from user. It work with both controlnet and controlnet union.
|
||||
|
||||
|
||||
```py
|
||||
>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep
|
||||
>>> control_input_block = StableDiffusionXLAutoControlNetInputStep()
|
||||
>>> print(control_input_block)
|
||||
```
|
||||
|
||||
```out
|
||||
StableDiffusionXLAutoControlNetInputStep(
|
||||
Class: AutoPipelineBlocks
|
||||
|
||||
====================================================================================================
|
||||
This pipeline contains blocks that are selected at runtime based on inputs.
|
||||
Trigger Inputs: ['control_image', 'control_mode']
|
||||
====================================================================================================
|
||||
|
||||
|
||||
Description: Controlnet Input step that prepare the controlnet input.
|
||||
This is an auto pipeline block that works for both controlnet and controlnet_union.
|
||||
(it should be called right before the denoise step) - `StableDiffusionXLControlNetUnionInputStep` is called to prepare the controlnet input when `control_mode` and `control_image` are provided.
|
||||
- `StableDiffusionXLControlNetInputStep` is called to prepare the controlnet input when `control_image` is provided. - if neither `control_mode` nor `control_image` is provided, step will be skipped.
|
||||
|
||||
|
||||
Components:
|
||||
controlnet (`ControlNetUnionModel`)
|
||||
control_image_processor (`VaeImageProcessor`)
|
||||
|
||||
Sub-Blocks:
|
||||
• controlnet_union [trigger: control_mode] (StableDiffusionXLControlNetUnionInputStep)
|
||||
Description: step that prepares inputs for the ControlNetUnion model
|
||||
|
||||
• controlnet [trigger: control_image] (StableDiffusionXLControlNetInputStep)
|
||||
Description: step that prepare inputs for controlnet
|
||||
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
Let's assemble the blocks and run an example using controlnet + differential diffusion. We used a tomato as `control_image`, so you can see that in the output, the right half that transformed into a pear had a tomato-like shape.
|
||||
|
||||
```py
|
||||
>>> dd_blocks.sub_blocks.insert("controlnet_input", control_input_block, 7)
|
||||
>>> dd_blocks.sub_blocks["denoise"] = controlnet_denoise_block
|
||||
>>>
|
||||
>>> dd_pipeline = dd_blocks.init_pipeline("YiYiXu/modular-demo-auto", collection="diffdiff")
|
||||
>>> dd_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
>>> dd_pipeline = dd_pipeline.to(device)
|
||||
>>>
|
||||
>>> control_image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/diffdiff_tomato_canny.jpeg")
|
||||
>>> image = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/20240329211129_4024911930.png?download=true")
|
||||
>>> mask = load_image("https://huggingface.co/datasets/OzzyGT/testing-resources/resolve/main/differential/gradient_mask.png?download=true")
|
||||
>>>
|
||||
>>> prompt = "a green pear"
|
||||
>>> negative_prompt = "blurry"
|
||||
>>> generator = torch.Generator(device=device).manual_seed(42)
|
||||
>>>
|
||||
>>> image = dd_pipeline(
|
||||
... prompt=prompt,
|
||||
... negative_prompt=negative_prompt,
|
||||
... num_inference_steps=25,
|
||||
... generator=generator,
|
||||
... control_image=control_image,
|
||||
... controlnet_conditioning_scale=0.5,
|
||||
... diffdiff_map=mask,
|
||||
... image=image,
|
||||
... output="images"
|
||||
... )[0]
|
||||
```
|
||||
|
||||
Optionally, We can combine `SDXLDiffDiffControlNetDenoiseStep` and `SDXLDiffDiffDenoiseStep` into a `AutoPipelineBlocks` so that same workflow can work with or without controlnet.
|
||||
|
||||
|
||||
```py
|
||||
>>> class SDXLDiffDiffAutoDenoiseStep(AutoPipelineBlocks):
|
||||
... block_classes = [SDXLDiffDiffControlNetDenoiseStep, SDXLDiffDiffDenoiseStep]
|
||||
... block_names = ["controlnet_denoise", "denoise"]
|
||||
... block_trigger_inputs = ["controlnet_cond", None]
|
||||
```
|
||||
|
||||
`SDXLDiffDiffAutoDenoiseStep` will run the ControlNet denoise step if `control_image` input is provided, otherwise it will run the regular denoise step.
|
||||
|
||||
<Tip>
|
||||
|
||||
Note that it's perfectly fine not to use `AutoPipelineBlocks`. In fact, we recommend only using `AutoPipelineBlocks` to package your workflow at the end once you've verified all your pipelines work as expected.
|
||||
|
||||
</Tip>
|
||||
|
||||
Now you can create the differential diffusion preset that works with ip-adapter & controlnet.
|
||||
|
||||
```py
|
||||
>>> DIFFDIFF_AUTO_BLOCKS = IMAGE2IMAGE_BLOCKS.copy()
|
||||
>>> DIFFDIFF_AUTO_BLOCKS["prepare_latents"] = SDXLDiffDiffPrepareLatentsStep
|
||||
>>> DIFFDIFF_AUTO_BLOCKS["set_timesteps"] = TEXT2IMAGE_BLOCKS["set_timesteps"]
|
||||
>>> DIFFDIFF_AUTO_BLOCKS["denoise"] = SDXLDiffDiffAutoDenoiseStep
|
||||
>>> DIFFDIFF_AUTO_BLOCKS.insert("ip_adapter", StableDiffusionXLAutoIPAdapterStep, 0)
|
||||
>>> DIFFDIFF_AUTO_BLOCKS.insert("controlnet_input",StableDiffusionXLControlNetAutoInput, 7)
|
||||
>>>
|
||||
>>> print(DIFFDIFF_AUTO_BLOCKS)
|
||||
```
|
||||
|
||||
to use
|
||||
|
||||
```py
|
||||
>>> dd_auto_blocks = SequentialPipelineBlocks.from_blocks_dict(DIFFDIFF_AUTO_BLOCKS)
|
||||
>>> dd_pipeline = dd_auto_blocks.init_pipeline(...)
|
||||
```
|
||||
## Creating a Modular Repo
|
||||
|
||||
You can easily share your differential diffusion workflow on the hub, by creating a modular repo like this https://huggingface.co/YiYiXu/modular-diffdiff
|
||||
|
||||
To create a Modular Repo and share on hub, you just need to run `save_pretrained()` along with the `push_to_hub=True` flag. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily.
|
||||
|
||||
```py
|
||||
dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True)
|
||||
```
|
||||
|
||||
With a modular repo, it is very easy for the community to use the workflow you just created! Here is an example to use the differential-diffusion pipeline we just created and shared.
|
||||
|
||||
```py
|
||||
>>> from diffusers.modular_pipelines import ModularPipeline, ComponentsManager
|
||||
>>> import torch
|
||||
>>> from diffusers.utils import load_image
|
||||
>>>
|
||||
>>> repo_id = "YiYiXu/modular-diffdiff-0704"
|
||||
>>>
|
||||
>>> components = ComponentsManager()
|
||||
>>>
|
||||
>>> diffdiff_pipeline = ModularPipeline.from_pretrained(repo_id, trust_remote_code=True, components_manager=components, collection="diffdiff")
|
||||
>>> diffdiff_pipeline.load_default_components(torch_dtype=torch.float16)
|
||||
>>> components.enable_auto_cpu_offload()
|
||||
```
|
||||
|
||||
see more usage example on model card
|
||||
|
||||
## deploy a mellon node
|
||||
|
||||
[YIYI TODO: for now, here is an example of mellon node https://huggingface.co/YiYiXu/diff-diff-mellon]
|
||||
1210
docs/source/en/modular_diffusers/getting_started.md
Normal file
1210
docs/source/en/modular_diffusers/getting_started.md
Normal file
File diff suppressed because it is too large
Load Diff
817
docs/source/en/modular_diffusers/write_own_pipeline_block.md
Normal file
817
docs/source/en/modular_diffusers/write_own_pipeline_block.md
Normal file
@@ -0,0 +1,817 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Writing Your Own Pipeline Blocks
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
🧪 **Experimental Feature**: Modular Diffusers is an experimental feature we are actively developing. The API may be subject to breaking changes.
|
||||
|
||||
</Tip>
|
||||
|
||||
In Modular Diffusers, you build your workflow using `ModularPipelineBlocks`. We support 4 different types of blocks: `PipelineBlock`, `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`. Among them, `PipelineBlock` is the most fundamental building block of the whole system - it's like a brick in a Lego system. These blocks are designed to easily connect with each other, allowing for modular construction of creative and potentially very complex workflows.
|
||||
|
||||
In this tutorial, we will focus on how to write a basic `PipelineBlock` and how it interacts with other components in the system. We will also cover how to connect them together using the multi-blocks: `SequentialPipelineBlocks`, `LoopSequentialPipelineBlocks`, and `AutoPipelineBlocks`.
|
||||
|
||||
|
||||
## Understanding the Foundation: `PipelineState`
|
||||
|
||||
Before we dive into creating `PipelineBlock`s, we need to have a basic understanding of `PipelineState` - the core data structure that all blocks operate on. This concept is fundamental to understanding how blocks interact with each other and the pipeline system.
|
||||
|
||||
In the modular diffusers system, `PipelineState` acts as the global state container that `PipelineBlock`s operate on - each block gets a local view (`BlockState`) of the relevant variables it needs from `PipelineState`, performs its operations, and then updates `PipelineState` with any changes.
|
||||
|
||||
While `PipelineState` maintains the complete runtime state of the pipeline, `PipelineBlock`s define what parts of that state they can read from and write to through their `input`s, `intermediates_inputs`, and `intermediates_outputs` properties.
|
||||
|
||||
A `PipelineState` consists of two distinct states:
|
||||
- The **immutable state** (i.e. the `inputs` dict) contains a copy of values provided by users. Once a value is added to the immutable state, it cannot be changed. Blocks can read from the immutable state but cannot write to it.
|
||||
- The **mutable state** (i.e. the `intermediates` dict) contains variables that are passed between blocks and can be modified by them.
|
||||
|
||||
Here's an example of what a `PipelineState` looks like:
|
||||
|
||||
```
|
||||
PipelineState(
|
||||
inputs={
|
||||
prompt: 'a cat'
|
||||
guidance_scale: 7.0
|
||||
num_inference_steps: 25
|
||||
},
|
||||
intermediates={
|
||||
prompt_embeds: Tensor(dtype=torch.float32, shape=torch.Size([1, 1, 1, 1]))
|
||||
negative_prompt_embeds: None
|
||||
},
|
||||
```
|
||||
|
||||
## Creating a `PipelineBlock`
|
||||
|
||||
To write a `PipelineBlock` class, you need to define a few properties that determine how your block interacts with the pipeline state. Understanding these properties is crucial - they define what data your block can access and what it can produce.
|
||||
|
||||
The three main properties you need to define are:
|
||||
- `inputs`: Immutable values from the user that cannot be modified
|
||||
- `intermediate_inputs`: Mutable values from previous blocks that can be read and modified
|
||||
- `intermediate_outputs`: New values your block creates for subsequent blocks
|
||||
|
||||
Let's explore each one and understand how they work with the pipeline state.
|
||||
|
||||
**Inputs: Immutable User Values**
|
||||
|
||||
Inputs are variables your block needs from the immutable pipeline state - these are user-provided values that cannot be modified by any block. You define them using `InputParam`:
|
||||
|
||||
```py
|
||||
user_inputs = [
|
||||
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
|
||||
]
|
||||
```
|
||||
|
||||
When you list something as an input, you're saying "I need this value directly from the end user, and I will talk to them directly, telling them what I need in the 'description' field. They will provide it and it will come to me unchanged."
|
||||
|
||||
This is especially useful for raw values that serve as the "source of truth" in your workflow. For example, with a raw image, many workflows require preprocessing steps like resizing that a previous block might have performed. But in many cases, you also want the raw PIL image. In some inpainting workflows, you need the original image to overlay with the generated result for better control and consistency.
|
||||
|
||||
**Intermediate Inputs: Mutable Values from Previous Blocks**
|
||||
|
||||
Intermediate inputs are variables your block needs from the mutable pipeline state - these are values that can be read and modified. They're typically created by previous blocks, but could also be directly provided by the user if not the case:
|
||||
|
||||
```py
|
||||
user_intermediate_inputs = [
|
||||
InputParam(name="processed_image", type_hint="torch.Tensor", description="image that has been preprocessed and normalized"),
|
||||
]
|
||||
```
|
||||
|
||||
When you list something as an intermediate input, you're saying "I need this value, but I want to work with a different block that has already created it. I already know for sure that I can get it from this other block, but it's okay if other developers want use something different."
|
||||
|
||||
**Intermediate Outputs: New Values for Subsequent Blocks**
|
||||
|
||||
Intermediate outputs are new variables your block creates and adds to the mutable pipeline state so they can be used by subsequent blocks:
|
||||
|
||||
```py
|
||||
user_intermediate_outputs = [
|
||||
OutputParam(name="image_latents", description="latents representing the image")
|
||||
]
|
||||
```
|
||||
|
||||
Intermediate inputs and intermediate outputs work together like Lego studs and anti-studs - they're the connection points that make blocks modular. When one block produces an intermediate output, it becomes available as an intermediate input for subsequent blocks. This is where the "modular" nature of the system really shines - blocks can be connected and reconnected in different ways as long as their inputs and outputs match. We will see more how they connect when we talk about multi-blocks.
|
||||
|
||||
**The `__call__` Method Structure**
|
||||
|
||||
Your `PipelineBlock`'s `__call__` method should follow this structure:
|
||||
|
||||
```py
|
||||
def __call__(self, components, state):
|
||||
# Get a local view of the state variables this block needs
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Your computation logic here
|
||||
# block_state contains all your inputs and intermediate_inputs
|
||||
# You can access them like: block_state.image, block_state.processed_image
|
||||
|
||||
# Update the pipeline state with your updated block_states
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
```
|
||||
|
||||
The `block_state` object contains all the variables you defined in `inputs` and `intermediate_inputs`, making them easily accessible for your computation.
|
||||
|
||||
**Components and Configs**
|
||||
|
||||
You can define the components and pipeline-level configs your block needs using `ComponentSpec` and `ConfigSpec`:
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, ConfigSpec
|
||||
|
||||
# Define components your block needs
|
||||
expected_components = [
|
||||
ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
|
||||
ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
|
||||
]
|
||||
|
||||
# Define pipeline-level configs
|
||||
expected_config = [
|
||||
ConfigSpec("force_zeros_for_empty_prompt", True)
|
||||
]
|
||||
```
|
||||
|
||||
**Components**: In the `ComponentSpec`, You must provide a `name` and ideally a `type_hint`. The actual loading details (`repo`, `subfolder`, `variant` and `revision` fields) are typically specified when creating the pipeline, as we covered in the [Getting Started Guide](https://huggingface.co/docs/diffusers/en/modular_diffusers/getting_started#loading-components-into-a-modularpipeline).
|
||||
|
||||
**Configs**: Simple pipeline-level settings that control behavior across all blocks.
|
||||
|
||||
When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the pipeline collects all component requirements from the blocks and fetches the loading specs from the modular repository. The components are then made available to your block in the `components` argument of the `__call__` method.
|
||||
|
||||
That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks
|
||||
|
||||
Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that.
|
||||
|
||||
**Helper Function**
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import PipelineBlock, InputParam, OutputParam
|
||||
import torch
|
||||
|
||||
def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block_fn=None, description=None):
|
||||
class TestBlock(PipelineBlock):
|
||||
model_name = "test"
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return inputs
|
||||
|
||||
@property
|
||||
def intermediate_inputs(self):
|
||||
return intermediate_inputs
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return intermediate_outputs
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return description if description is not None else ""
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
if block_fn is not None:
|
||||
block_state = block_fn(block_state, state)
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
|
||||
return TestBlock
|
||||
```
|
||||
|
||||
|
||||
Let's create a simple block to see how these definitions interact with the pipeline state. To better understand what's happening, we'll print out the states before and after updates to inspect them:
|
||||
|
||||
```py
|
||||
inputs = [
|
||||
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
|
||||
]
|
||||
|
||||
intermediate_inputs = [InputParam(name="batch_size", type_hint=int)]
|
||||
|
||||
intermediate_outputs = [
|
||||
OutputParam(name="image_latents", description="latents representing the image")
|
||||
]
|
||||
|
||||
def image_encoder_block_fn(block_state, pipeline_state):
|
||||
print(f"pipeline_state (before update): {pipeline_state}")
|
||||
print(f"block_state (before update): {block_state}")
|
||||
|
||||
# Simulate processing the image
|
||||
block_state.image = torch.randn(1, 3, 512, 512)
|
||||
block_state.batch_size = block_state.batch_size * 2
|
||||
block_state.processed_image = [torch.randn(1, 3, 512, 512)] * block_state.batch_size
|
||||
block_state.image_latents = torch.randn(1, 4, 64, 64)
|
||||
|
||||
print(f"block_state (after update): {block_state}")
|
||||
return block_state
|
||||
|
||||
# Create a block with our definitions
|
||||
image_encoder_block_cls = make_block(
|
||||
inputs=inputs,
|
||||
intermediate_inputs=intermediate_inputs,
|
||||
intermediate_outputs=intermediate_outputs,
|
||||
block_fn=image_encoder_block_fn,
|
||||
description=" Encode raw image into its latent presentation"
|
||||
)
|
||||
image_encoder_block = image_encoder_block_cls()
|
||||
pipe = image_encoder_block.init_pipeline()
|
||||
```
|
||||
|
||||
Let's check the pipeline's docstring to see what inputs it expects:
|
||||
```py
|
||||
>>> print(pipe.doc)
|
||||
class TestBlock
|
||||
|
||||
Encode raw image into its latent presentation
|
||||
|
||||
Inputs:
|
||||
|
||||
image (`PIL.Image`, *optional*):
|
||||
raw input image to process
|
||||
|
||||
batch_size (`int`, *optional*):
|
||||
|
||||
Outputs:
|
||||
|
||||
image_latents (`None`):
|
||||
latents representing the image
|
||||
```
|
||||
|
||||
Notice that `batch_size` appears as an input even though we defined it as an intermediate input. This happens because no previous block provided it, so the pipeline makes it available as a user input. However, unlike regular inputs, this value goes directly into the mutable intermediate state.
|
||||
|
||||
Now let's run the pipeline:
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/image_of_squirrel_painting.png")
|
||||
state = pipe(image=image, batch_size=2)
|
||||
print(f"pipeline_state (after update): {state}")
|
||||
```
|
||||
```out
|
||||
pipeline_state (before update): PipelineState(
|
||||
inputs={
|
||||
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494550>
|
||||
},
|
||||
intermediates={
|
||||
batch_size: 2
|
||||
},
|
||||
)
|
||||
block_state (before update): BlockState(
|
||||
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494640>
|
||||
batch_size: 2
|
||||
)
|
||||
|
||||
block_state (after update): BlockState(
|
||||
image: Tensor(dtype=torch.float32, shape=torch.Size([1, 3, 512, 512]))
|
||||
batch_size: 4
|
||||
processed_image: List[4] of Tensors with shapes [torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512]), torch.Size([1, 3, 512, 512])]
|
||||
image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
|
||||
)
|
||||
pipeline_state (after update): PipelineState(
|
||||
inputs={
|
||||
image: <PIL.Image.Image image mode=RGB size=512x512 at 0x7F3ECC494550>
|
||||
},
|
||||
intermediates={
|
||||
batch_size: 4
|
||||
image_latents: Tensor(dtype=torch.float32, shape=torch.Size([1, 4, 64, 64]))
|
||||
},
|
||||
)
|
||||
```
|
||||
**Key Observations:**
|
||||
|
||||
1. **Before the update**: `image` (the input) goes to the immutable inputs dict, while `batch_size` (the intermediate_input) goes to the mutable intermediates dict, and both are available in `block_state`.
|
||||
|
||||
2. **After the update**:
|
||||
- **`image` (inputs)** changed in `block_state` but not in `pipeline_state` - this change is local to the block only.
|
||||
- **`batch_size (intermediate_inputs)`** was updated in both `block_state` and `pipeline_state` - this change affects subsequent blocks (we didn't need to declare it as an intermediate output since it was already in the intermediates dict)
|
||||
- **`image_latents (intermediate_outputs)`** was added to `pipeline_state` because it was declared as an intermediate output
|
||||
- **`processed_image`** was not added to `pipeline_state` because it wasn't declared as an intermediate output
|
||||
|
||||
I hope by now you have a basic idea about how `PipelineBlock` manages state through inputs, intermediate inputs, and intermediate outputs. The real power comes when we connect multiple blocks together - their intermediate outputs become intermediate inputs for subsequent blocks, creating modular workflows. Let's explore how to build these connections using multi-blocks like `SequentialPipelineBlocks`.
|
||||
|
||||
## Create a `SequentialPipelineBlocks`
|
||||
|
||||
I assume that you're already familiar with `SequentialPipelineBlocks` and how to create them with the `from_blocks_dict` API. It's one of the most common ways to use Modular Diffusers, and we've covered it pretty well in the [Getting Started Guide](https://huggingface.co/docs/diffusers/pr_9672/en/modular_diffusers/getting_started#modularpipelineblocks).
|
||||
|
||||
But how do blocks actually connect and work together? Understanding this is crucial for building effective modular workflows. Let's explore this through an example.
|
||||
|
||||
**How Blocks Connect in SequentialPipelineBlocks:**
|
||||
|
||||
The key insight is that blocks connect through their intermediate inputs and outputs - the "studs and anti-studs" we discussed earlier. Let's expand on our example to create a new block that produces `batch_size`, which we'll call "input_block":
|
||||
|
||||
```py
|
||||
def input_block_fn(block_state, pipeline_state):
|
||||
|
||||
batch_size = len(block_state.prompt)
|
||||
block_state.batch_size = batch_size * block_state.num_images_per_prompt
|
||||
|
||||
return block_state
|
||||
|
||||
input_block_cls = make_block(
|
||||
inputs=[
|
||||
InputParam(name="prompt", type_hint=list, description="list of text prompts"),
|
||||
InputParam(name="num_images_per_prompt", type_hint=int, description="number of images per prompt")
|
||||
],
|
||||
intermediate_outputs=[
|
||||
OutputParam(name="batch_size", description="calculated batch size")
|
||||
],
|
||||
block_fn=input_block_fn,
|
||||
description="A block that determines batch_size based on the number of prompts and num_images_per_prompt argument."
|
||||
)
|
||||
input_block = input_block_cls()
|
||||
```
|
||||
|
||||
Now let's connect these blocks to create a pipeline:
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
|
||||
# define a dict map block names to block class
|
||||
blocks_dict = InsertableDict()
|
||||
blocks_dict["input"] = input_block
|
||||
blocks_dict["image_encoder"] = image_encoder_block
|
||||
# create the multi-block
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
|
||||
# convert it to a runnable pipeline
|
||||
pipeline = blocks.init_pipeline()
|
||||
```
|
||||
|
||||
Now you have a pipeline with 2 blocks.
|
||||
|
||||
```py
|
||||
>>> pipeline.blocks
|
||||
SequentialPipelineBlocks(
|
||||
Class: ModularPipelineBlocks
|
||||
|
||||
Description:
|
||||
|
||||
|
||||
Sub-Blocks:
|
||||
[0] input (TestBlock)
|
||||
Description: A block that determines batch_size based on the number of prompts and num_images_per_prompt argument.
|
||||
|
||||
[1] image_encoder (TestBlock)
|
||||
Description: Encode raw image into its latent presentation
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
When you inspect `pipeline.doc`, you can see that `batch_size` is not listed as an input. The pipeline automatically detects that the `input_block` can produce `batch_size` for the `image_encoder_block`, so it doesn't ask the user to provide it.
|
||||
|
||||
```py
|
||||
>>> print(pipeline.doc)
|
||||
class SequentialPipelineBlocks
|
||||
|
||||
Inputs:
|
||||
|
||||
prompt (`None`, *optional*):
|
||||
|
||||
num_images_per_prompt (`None`, *optional*):
|
||||
|
||||
image (`PIL.Image`, *optional*):
|
||||
raw input image to process
|
||||
|
||||
Outputs:
|
||||
|
||||
batch_size (`None`):
|
||||
|
||||
image_latents (`None`):
|
||||
latents representing the image
|
||||
```
|
||||
|
||||
At runtime, you have data flow like this:
|
||||
|
||||

|
||||
|
||||
**How SequentialPipelineBlocks Works:**
|
||||
|
||||
1. Blocks are executed in the order they're registered in the `blocks_dict`
|
||||
2. Outputs from one block become available as intermediate inputs to all subsequent blocks
|
||||
3. The pipeline automatically figures out which values need to be provided by the user and which will be generated by previous blocks
|
||||
4. Each block maintains its own behavior and operates through its defined interface, while collectively these interfaces determine what the entire pipeline accepts and produces
|
||||
|
||||
What happens within each block follows the same pattern we described earlier: each block gets its own `block_state` with the relevant inputs and intermediate inputs, performs its computation, and updates the pipeline state with its intermediate outputs.
|
||||
|
||||
## `LoopSequentialPipelineBlocks`
|
||||
|
||||
To create a loop in Modular Diffusers, you could use a single `PipelineBlock` like this:
|
||||
|
||||
```python
|
||||
class DenoiseLoop(PipelineBlock):
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
for t in range(block_state.num_inference_steps):
|
||||
# ... loop logic here
|
||||
pass
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
```
|
||||
|
||||
Or you could create a `LoopSequentialPipelineBlocks`. The key difference is that with `LoopSequentialPipelineBlocks`, the loop itself is modular: you can add or remove blocks within the loop or reuse the same loop structure with different block combinations.
|
||||
|
||||
It involves two parts: a **loop wrapper** and **loop blocks**
|
||||
|
||||
* The **loop wrapper** (`LoopSequentialPipelineBlocks`) defines the loop structure, e.g. it defines the iteration variables, and loop configurations such as progress bar.
|
||||
|
||||
* The **loop blocks** are basically standard pipeline blocks you add to the loop wrapper.
|
||||
- they run sequentially for each iteration of the loop
|
||||
- they receive the current iteration index as an additional parameter
|
||||
- they share the same block_state throughout the entire loop
|
||||
|
||||
Unlike regular `SequentialPipelineBlocks` where each block gets its own state, loop blocks share a single state that persists and evolves across iterations.
|
||||
|
||||
We will build a simple loop block to demonstrate these concepts. Creating a loop block involves three steps:
|
||||
1. defining the loop wrapper class
|
||||
2. creating the loop blocks
|
||||
3. adding the loop blocks to the loop wrapper class to create the loop wrapper instance
|
||||
|
||||
**Step 1: Define the Loop Wrapper**
|
||||
|
||||
To create a `LoopSequentialPipelineBlocks` class, you need to define:
|
||||
|
||||
* `loop_inputs`: User input variables (equivalent to `PipelineBlock.inputs`)
|
||||
* `loop_intermediate_inputs`: Intermediate variables needed from the mutable pipeline state (equivalent to `PipelineBlock.intermediates_inputs`)
|
||||
* `loop_intermediate_outputs`: New intermediate variables this block will add to the mutable pipeline state (equivalent to `PipelineBlock.intermediates_outputs`)
|
||||
* `__call__` method: Defines the loop structure and iteration logic
|
||||
|
||||
Here is an example of a loop wrapper:
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import LoopSequentialPipelineBlocks, PipelineBlock, InputParam, OutputParam
|
||||
|
||||
class LoopWrapper(LoopSequentialPipelineBlocks):
|
||||
model_name = "test"
|
||||
@property
|
||||
def description(self):
|
||||
return "I'm a loop!!"
|
||||
@property
|
||||
def loop_inputs(self):
|
||||
return [InputParam(name="num_steps")]
|
||||
@torch.no_grad()
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
# Loop structure - can be customized to your needs
|
||||
for i in range(block_state.num_steps):
|
||||
# loop_step executes all registered blocks in sequence
|
||||
components, block_state = self.loop_step(components, block_state, i=i)
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
```
|
||||
|
||||
**Step 2: Create Loop Blocks**
|
||||
|
||||
Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently:
|
||||
* It receives the iteration variable (e.g., `i`) passed by the loop wrapper
|
||||
* It works directly with `block_state` instead of pipeline state
|
||||
* No need to call `self.get_block_state()` or `self.set_block_state()`
|
||||
|
||||
```py
|
||||
class LoopBlock(PipelineBlock):
|
||||
# this is used to identify the model family, we won't worry about it in this example
|
||||
model_name = "test"
|
||||
@property
|
||||
def inputs(self):
|
||||
return [InputParam(name="x")]
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
# outputs produced by this block
|
||||
return [OutputParam(name="x")]
|
||||
@property
|
||||
def description(self):
|
||||
return "I'm a block used inside the `LoopWrapper` class"
|
||||
def __call__(self, components, block_state, i: int):
|
||||
block_state.x += 1
|
||||
return components, block_state
|
||||
```
|
||||
|
||||
**Step 3: Combine Everything**
|
||||
|
||||
Finally, assemble your loop by adding the block(s) to the wrapper:
|
||||
|
||||
```py
|
||||
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock})
|
||||
```
|
||||
|
||||
Now you've created a loop with one step:
|
||||
|
||||
```py
|
||||
>>> loop
|
||||
LoopWrapper(
|
||||
Class: LoopSequentialPipelineBlocks
|
||||
|
||||
Description: I'm a loop!!
|
||||
|
||||
Sub-Blocks:
|
||||
[0] block1 (LoopBlock)
|
||||
Description: I'm a block used inside the `LoopWrapper` class
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
It has two inputs: `x` (used at each step within the loop) and `num_steps` used to define the loop.
|
||||
|
||||
```py
|
||||
>>> print(loop.doc)
|
||||
class LoopWrapper
|
||||
|
||||
I'm a loop!!
|
||||
|
||||
Inputs:
|
||||
|
||||
x (`None`, *optional*):
|
||||
|
||||
num_steps (`None`, *optional*):
|
||||
|
||||
Outputs:
|
||||
|
||||
x (`None`):
|
||||
```
|
||||
|
||||
**Running the Loop:**
|
||||
|
||||
```py
|
||||
# run the loop
|
||||
loop_pipeline = loop.init_pipeline()
|
||||
x = loop_pipeline(num_steps=10, x=0, output="x")
|
||||
assert x == 10
|
||||
```
|
||||
|
||||
**Adding Multiple Blocks:**
|
||||
|
||||
We can add multiple blocks to run within each iteration. Let's run the loop block twice within each iteration:
|
||||
|
||||
```py
|
||||
loop = LoopWrapper.from_blocks_dict({"block1": LoopBlock(), "block2": LoopBlock})
|
||||
loop_pipeline = loop.init_pipeline()
|
||||
x = loop_pipeline(num_steps=10, x=0, output="x")
|
||||
assert x == 20 # Each iteration runs 2 blocks, so 10 iterations * 2 = 20
|
||||
```
|
||||
|
||||
**Key Differences from SequentialPipelineBlocks:**
|
||||
|
||||
The main difference is that loop blocks share the same `block_state` across all iterations, allowing values to accumulate and evolve throughout the loop. Loop blocks could receive additional arguments (like the current iteration index) depending on the loop wrapper's implementation, since the wrapper defines how loop blocks are called. You can easily add, remove, or reorder blocks within the loop without changing the loop logic itself.
|
||||
|
||||
The officially supported denoising loops in Modular Diffusers are implemented using `LoopSequentialPipelineBlocks`. You can explore the actual implementation to see how these concepts work in practice:
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl.denoise import StableDiffusionXLDenoiseStep
|
||||
StableDiffusionXLDenoiseStep()
|
||||
```
|
||||
|
||||
## `AutoPipelineBlocks`
|
||||
|
||||
`AutoPipelineBlocks` allows you to pack different pipelines into one and automatically select which one to run at runtime based on the inputs. The main purpose is convenience and portability - for developers, you can package everything into one workflow, making it easier to share and use.
|
||||
|
||||
For example, you might want to support text-to-image and image-to-image tasks. Instead of creating two separate pipelines, you can create an `AutoPipelineBlocks` that automatically chooses the workflow based on whether an `image` input is provided.
|
||||
|
||||
Let's see an example. Here we'll create a dummy `AutoPipelineBlocks` that includes dummy text-to-image, image-to-image, and inpaint pipelines.
|
||||
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import AutoPipelineBlocks
|
||||
|
||||
# These are dummy blocks and we only focus on "inputs" for our purpose
|
||||
inputs = [InputParam(name="prompt")]
|
||||
# block_fn prints out which workflow is running so we can see the execution order at runtime
|
||||
block_fn = lambda x, y: print("running the text-to-image workflow")
|
||||
block_t2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a text-to-image workflow!")
|
||||
|
||||
inputs = [InputParam(name="prompt"), InputParam(name="image")]
|
||||
block_fn = lambda x, y: print("running the image-to-image workflow")
|
||||
block_i2i_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a image-to-image workflow!")
|
||||
|
||||
inputs = [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
|
||||
block_fn = lambda x, y: print("running the inpaint workflow")
|
||||
block_inpaint_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a inpaint workflow!")
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
# List of sub-block classes to choose from
|
||||
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
|
||||
# Names for each block in the same order
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
# Trigger inputs that determine which block to run
|
||||
# - "mask" triggers inpaint workflow
|
||||
# - "image" triggers img2img workflow (but only if mask is not provided)
|
||||
# - if none of above, runs the text2img workflow (default)
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
# Description is extremely important for AutoPipelineBlocks
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pipeline generates images given different types of conditions!\n"
|
||||
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
|
||||
+ " - inpaint workflow is run when `mask` is provided.\n"
|
||||
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
|
||||
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
|
||||
)
|
||||
|
||||
# Create the blocks
|
||||
auto_blocks = AutoImageBlocks()
|
||||
# convert to pipeline
|
||||
auto_pipeline = auto_blocks.init_pipeline()
|
||||
```
|
||||
|
||||
Now we have created an `AutoPipelineBlocks` that contains 3 sub-blocks. Notice the warning message at the top - this automatically appears in every `ModularPipelineBlocks` that contains `AutoPipelineBlocks` to remind end users that dynamic block selection happens at runtime.
|
||||
|
||||
```py
|
||||
AutoImageBlocks(
|
||||
Class: AutoPipelineBlocks
|
||||
|
||||
====================================================================================================
|
||||
This pipeline contains blocks that are selected at runtime based on inputs.
|
||||
Trigger Inputs: ['mask', 'image']
|
||||
====================================================================================================
|
||||
|
||||
|
||||
Description: Pipeline generates images given different types of conditions!
|
||||
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
|
||||
- inpaint workflow is run when `mask` is provided.
|
||||
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
|
||||
- text2img workflow is run when neither `image` nor `mask` is provided.
|
||||
|
||||
|
||||
|
||||
Sub-Blocks:
|
||||
• inpaint [trigger: mask] (TestBlock)
|
||||
Description: I'm a inpaint workflow!
|
||||
|
||||
• img2img [trigger: image] (TestBlock)
|
||||
Description: I'm a image-to-image workflow!
|
||||
|
||||
• text2img [default] (TestBlock)
|
||||
Description: I'm a text-to-image workflow!
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
Check out the documentation with `print(auto_pipeline.doc)`:
|
||||
|
||||
```py
|
||||
>>> print(auto_pipeline.doc)
|
||||
class AutoImageBlocks
|
||||
|
||||
Pipeline generates images given different types of conditions!
|
||||
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
|
||||
- inpaint workflow is run when `mask` is provided.
|
||||
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
|
||||
- text2img workflow is run when neither `image` nor `mask` is provided.
|
||||
|
||||
Inputs:
|
||||
|
||||
prompt (`None`, *optional*):
|
||||
|
||||
image (`None`, *optional*):
|
||||
|
||||
mask (`None`, *optional*):
|
||||
```
|
||||
|
||||
There is a fundamental trade-off of AutoPipelineBlocks: it trades clarity for convenience. While it is really easy for packaging multiple workflows, it can become confusing without proper documentation. e.g. if we just throw a pipeline at you and tell you that it contains 3 sub-blocks and takes 3 inputs `prompt`, `image` and `mask`, and ask you to run an image-to-image workflow: if you don't have any prior knowledge on how these pipelines work, you would be pretty clueless, right?
|
||||
|
||||
This pipeline we just made though, has a docstring that shows all available inputs and workflows and explains how to use each with different inputs. So it's really helpful for users. For example, it's clear that you need to pass `image` to run img2img. This is why the description field is absolutely critical for AutoPipelineBlocks. We highly recommend you to explain the conditional logic very well for each `AutoPipelineBlocks` you would make. We also recommend to always test individual pipelines first before packaging them into AutoPipelineBlocks.
|
||||
|
||||
Let's run this auto pipeline with different inputs to see if the conditional logic works as described. Remember that we have added `print` in each `PipelineBlock`'s `__call__` method to print out its workflow name, so it should be easy to tell which one is running:
|
||||
|
||||
```py
|
||||
>>> _ = auto_pipeline(image="image", mask="mask")
|
||||
running the inpaint workflow
|
||||
>>> _ = auto_pipeline(image="image")
|
||||
running the image-to-image workflow
|
||||
>>> _ = auto_pipeline(prompt="prompt")
|
||||
running the text-to-image workflow
|
||||
>>> _ = auto_pipeline(image="prompt", mask="mask")
|
||||
running the inpaint workflow
|
||||
```
|
||||
|
||||
However, even with documentation, it can become very confusing when AutoPipelineBlocks are combined with other blocks. The complexity grows quickly when you have nested AutoPipelineBlocks or use them as sub-blocks in larger pipelines.
|
||||
|
||||
Let's make another `AutoPipelineBlocks` - this one only contains one block, and it does not include `None` in its `block_trigger_inputs` (which corresponds to the default block to run when none of the trigger inputs are provided). This means this block will be skipped if the trigger input (`ip_adapter_image`) is not provided at runtime.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
|
||||
inputs = [InputParam(name="ip_adapter_image")]
|
||||
block_fn = lambda x, y: print("running the ip-adapter workflow")
|
||||
block_ipa_cls = make_block(inputs=inputs, block_fn=block_fn, description="I'm a IP-adapter workflow!")
|
||||
|
||||
class AutoIPAdapter(AutoPipelineBlocks):
|
||||
block_classes = [block_ipa_cls]
|
||||
block_names = ["ip-adapter"]
|
||||
block_trigger_inputs = ["ip_adapter_image"]
|
||||
@property
|
||||
def description(self):
|
||||
return "Run IP Adapter step if `ip_adapter_image` is provided."
|
||||
```
|
||||
|
||||
Now let's combine these 2 auto blocks together into a `SequentialPipelineBlocks`:
|
||||
|
||||
```py
|
||||
auto_ipa_blocks = AutoIPAdapter()
|
||||
blocks_dict = InsertableDict()
|
||||
blocks_dict["ip-adapter"] = auto_ipa_blocks
|
||||
blocks_dict["image-generation"] = auto_blocks
|
||||
all_blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
|
||||
pipeline = all_blocks.init_pipeline()
|
||||
```
|
||||
|
||||
Let's take a look: now things get more confusing. In this particular example, you could still try to explain the conditional logic in the `description` field here - there are only 4 possible execution paths so it's doable. However, since this is a `SequentialPipelineBlocks` that could contain many more blocks, the complexity can quickly get out of hand as the number of blocks increases.
|
||||
|
||||
```py
|
||||
>>> all_blocks
|
||||
SequentialPipelineBlocks(
|
||||
Class: ModularPipelineBlocks
|
||||
|
||||
====================================================================================================
|
||||
This pipeline contains blocks that are selected at runtime based on inputs.
|
||||
Trigger Inputs: ['image', 'mask', 'ip_adapter_image']
|
||||
Use `get_execution_blocks()` with input names to see selected blocks (e.g. `get_execution_blocks('image')`).
|
||||
====================================================================================================
|
||||
|
||||
|
||||
Description:
|
||||
|
||||
|
||||
Sub-Blocks:
|
||||
[0] ip-adapter (AutoIPAdapter)
|
||||
Description: Run IP Adapter step if `ip_adapter_image` is provided.
|
||||
|
||||
|
||||
[1] image-generation (AutoImageBlocks)
|
||||
Description: Pipeline generates images given different types of conditions!
|
||||
This is an auto pipeline block that works for text2img, img2img and inpainting tasks.
|
||||
- inpaint workflow is run when `mask` is provided.
|
||||
- img2img workflow is run when `image` is provided (but only when `mask` is not provided).
|
||||
- text2img workflow is run when neither `image` nor `mask` is provided.
|
||||
|
||||
|
||||
)
|
||||
|
||||
```
|
||||
|
||||
This is when the `get_execution_blocks()` method comes in handy - it basically extracts a `SequentialPipelineBlocks` that only contains the blocks that are actually run based on your inputs.
|
||||
|
||||
Let's try some examples:
|
||||
|
||||
`mask`: we expect it to skip the first ip-adapter since `ip_adapter_image` is not provided, and then run the inpaint for the second block.
|
||||
|
||||
```py
|
||||
>>> all_blocks.get_execution_blocks('mask')
|
||||
SequentialPipelineBlocks(
|
||||
Class: ModularPipelineBlocks
|
||||
|
||||
Description:
|
||||
|
||||
|
||||
Sub-Blocks:
|
||||
[0] image-generation (TestBlock)
|
||||
Description: I'm a inpaint workflow!
|
||||
|
||||
)
|
||||
```
|
||||
|
||||
Let's also actually run the pipeline to confirm:
|
||||
|
||||
```py
|
||||
>>> _ = pipeline(mask="mask")
|
||||
skipping auto block: AutoIPAdapter
|
||||
running the inpaint workflow
|
||||
```
|
||||
|
||||
Try a few more:
|
||||
|
||||
```py
|
||||
print(f"inputs: ip_adapter_image:")
|
||||
blocks_select = all_blocks.get_execution_blocks('ip_adapter_image')
|
||||
print(f"expected_execution_blocks: {blocks_select}")
|
||||
print(f"actual execution blocks:")
|
||||
_ = pipeline(ip_adapter_image="ip_adapter_image", prompt="prompt")
|
||||
# expect to see ip-adapter + text2img
|
||||
|
||||
print(f"inputs: image:")
|
||||
blocks_select = all_blocks.get_execution_blocks('image')
|
||||
print(f"expected_execution_blocks: {blocks_select}")
|
||||
print(f"actual execution blocks:")
|
||||
_ = pipeline(image="image", prompt="prompt")
|
||||
# expect to see img2img
|
||||
|
||||
print(f"inputs: prompt:")
|
||||
blocks_select = all_blocks.get_execution_blocks('prompt')
|
||||
print(f"expected_execution_blocks: {blocks_select}")
|
||||
print(f"actual execution blocks:")
|
||||
_ = pipeline(prompt="prompt")
|
||||
# expect to see text2img (prompt is not a trigger input so fallback to default)
|
||||
|
||||
print(f"inputs: mask + ip_adapter_image:")
|
||||
blocks_select = all_blocks.get_execution_blocks('mask','ip_adapter_image')
|
||||
print(f"expected_execution_blocks: {blocks_select}")
|
||||
print(f"actual execution blocks:")
|
||||
_ = pipeline(mask="mask", ip_adapter_image="ip_adapter_image")
|
||||
# expect to see ip-adapter + inpaint
|
||||
```
|
||||
|
||||
In summary, `AutoPipelineBlocks` is a good tool for packaging multiple workflows into a single, convenient interface and it can greatly simplify the user experience. However, always provide clear descriptions explaining the conditional logic, test individual pipelines first before combining them, and use `get_execution_blocks()` to understand runtime behavior in complex compositions.
|
||||
@@ -315,6 +315,8 @@ pipeline.load_lora_weights(
|
||||
> [!TIP]
|
||||
> Move your code inside the `with torch._dynamo.config.patch(error_on_recompile=True)` context manager to detect if a model was recompiled. If a model is recompiled despite following all the steps above, please open an [issue](https://github.com/huggingface/diffusers/issues) with a reproducible example.
|
||||
|
||||
If you expect to varied resolutions during inference with this feature, then make sure set `dynamic=True` during compilation. Refer to [this document](../optimization/fp16#dynamic-shape-compilation) for more details.
|
||||
|
||||
There are still scenarios where recompulation is unavoidable, such as when the hotswapped LoRA targets more layers than the initial adapter. Try to load the LoRA that targets the most layers *first*. For more details about this limitation, refer to the PEFT [hotswapping](https://huggingface.co/docs/peft/main/en/package_reference/hotswap#peft.utils.hotswap.hotswap_adapter) docs.
|
||||
|
||||
## Merge
|
||||
|
||||
264
docs/source/en/using-diffusers/batched_inference.md
Normal file
264
docs/source/en/using-diffusers/batched_inference.md
Normal file
@@ -0,0 +1,264 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Batch inference
|
||||
|
||||
Batch inference processes multiple prompts at a time to increase throughput. It is more efficient because processing multiple prompts at once maximizes GPU usage versus processing a single prompt and underutilizing the GPU.
|
||||
|
||||
The downside is increased latency because you must wait for the entire batch to complete, and more GPU memory is required for large batches.
|
||||
|
||||
<hfoptions id="usage">
|
||||
<hfoption id="text-to-image">
|
||||
|
||||
For text-to-image, pass a list of prompts to the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
prompts = [
|
||||
"cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
|
||||
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
|
||||
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
|
||||
]
|
||||
|
||||
images = pipeline(
|
||||
prompt=prompts,
|
||||
).images
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, image in enumerate(images):
|
||||
axes[i].imshow(image)
|
||||
axes[i].set_title(f"Image {i+1}")
|
||||
axes[i].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
images = pipeline(
|
||||
prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
|
||||
num_images_per_prompt=4
|
||||
).images
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, image in enumerate(images):
|
||||
axes[i].imshow(image)
|
||||
axes[i].set_title(f"Image {i+1}")
|
||||
axes[i].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
Combine both approaches to generate different variations of different prompts.
|
||||
|
||||
```py
|
||||
images = pipeline(
|
||||
prompt=prompts,
|
||||
num_images_per_prompt=2,
|
||||
).images
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, image in enumerate(images):
|
||||
axes[i].imshow(image)
|
||||
axes[i].set_title(f"Image {i+1}")
|
||||
axes[i].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="image-to-image">
|
||||
|
||||
For image-to-image, pass a list of input images and prompts to the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.utils import load_image
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
input_images = [
|
||||
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png"),
|
||||
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
|
||||
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
|
||||
]
|
||||
|
||||
prompts = [
|
||||
"cinematic photo of a beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
|
||||
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
|
||||
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
|
||||
]
|
||||
|
||||
images = pipeline(
|
||||
prompt=prompts,
|
||||
image=input_images,
|
||||
guidance_scale=8.0,
|
||||
strength=0.5
|
||||
).images
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, image in enumerate(images):
|
||||
axes[i].imshow(image)
|
||||
axes[i].set_title(f"Image {i+1}")
|
||||
axes[i].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
To generate multiple variations of one prompt, use the `num_images_per_prompt` argument.
|
||||
|
||||
```py
|
||||
import torch
|
||||
import matplotlib.pyplot as plt
|
||||
from diffusers.utils import load_image
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
input_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
|
||||
|
||||
images = pipeline(
|
||||
prompt="pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics",
|
||||
image=input_image,
|
||||
num_images_per_prompt=4
|
||||
).images
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, image in enumerate(images):
|
||||
axes[i].imshow(image)
|
||||
axes[i].set_title(f"Image {i+1}")
|
||||
axes[i].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
Combine both approaches to generate different variations of different prompts.
|
||||
|
||||
```py
|
||||
input_images = [
|
||||
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"),
|
||||
load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/detail-prompt.png")
|
||||
]
|
||||
|
||||
prompts = [
|
||||
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
|
||||
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
|
||||
]
|
||||
|
||||
images = pipeline(
|
||||
prompt=prompts,
|
||||
image=input_images,
|
||||
num_images_per_prompt=2,
|
||||
).images
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, image in enumerate(images):
|
||||
axes[i].imshow(image)
|
||||
axes[i].set_title(f"Image {i+1}")
|
||||
axes[i].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Deterministic generation
|
||||
|
||||
Enable reproducible batch generation by passing a list of [Generator’s](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed to reuse it.
|
||||
|
||||
Use a list comprehension to iterate over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch.
|
||||
|
||||
Don't multiply the `Generator` by the batch size because that only creates one `Generator` object that is used sequentially for each image in the batch.
|
||||
|
||||
```py
|
||||
generator = [torch.Generator(device="cuda").manual_seed(0)] * 3
|
||||
```
|
||||
|
||||
Pass the `generator` to the pipeline.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
|
||||
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(3)]
|
||||
prompts = [
|
||||
"cinematic photo of A beautiful sunset over mountains, 35mm photograph, film, professional, 4k, highly detailed",
|
||||
"cinematic film still of a cat basking in the sun on a roof in Turkey, highly detailed, high budget hollywood movie, cinemascope, moody, epic, gorgeous, film grain",
|
||||
"pixel-art a cozy coffee shop interior, low-res, blocky, pixel art style, 8-bit graphics"
|
||||
]
|
||||
|
||||
images = pipeline(
|
||||
prompt=prompts,
|
||||
generator=generator
|
||||
).images
|
||||
|
||||
fig, axes = plt.subplots(2, 2, figsize=(12, 12))
|
||||
axes = axes.flatten()
|
||||
|
||||
for i, image in enumerate(images):
|
||||
axes[i].imshow(image)
|
||||
axes[i].set_title(f"Image {i+1}")
|
||||
axes[i].axis('off')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
```
|
||||
|
||||
You can use this to iteratively select an image associated with a seed and then improve on it by crafting a more detailed prompt.
|
||||
@@ -70,41 +70,32 @@ pipeline = StableDiffusionPipeline.from_single_file(
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
#### LoRA files
|
||||
#### LoRAs
|
||||
|
||||
[LoRA](https://hf.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora) is a lightweight adapter that is fast and easy to train, making them especially popular for generating images in a certain way or style. These adapters are commonly stored in a safetensors file, and are widely popular on model sharing platforms like [civitai](https://civitai.com/).
|
||||
[LoRAs](../tutorials/using_peft_for_inference) are lightweight checkpoints fine-tuned to generate images or video in a specific style. If you are using a checkpoint trained with a Diffusers training script, the LoRA configuration is automatically saved as metadata in a safetensors file. When the safetensors file is loaded, the metadata is parsed to correctly configure the LoRA and avoids missing or incorrect LoRA configurations.
|
||||
|
||||
LoRAs are loaded into a base model with the [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] method.
|
||||
|
||||
```py
|
||||
from diffusers import StableDiffusionXLPipeline
|
||||
import torch
|
||||
|
||||
# base model
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
"Lykon/dreamshaper-xl-1-0", torch_dtype=torch.float16, variant="fp16"
|
||||
).to("cuda")
|
||||
|
||||
# download LoRA weights
|
||||
!wget https://civitai.com/api/download/models/168776 -O blueprintify.safetensors
|
||||
|
||||
# load LoRA weights
|
||||
pipeline.load_lora_weights(".", weight_name="blueprintify.safetensors")
|
||||
prompt = "bl3uprint, a highly detailed blueprint of the empire state building, explaining how to build all parts, many txt, blueprint grid backdrop"
|
||||
negative_prompt = "lowres, cropped, worst quality, low quality, normal quality, artifacts, signature, watermark, username, blurry, more than one bridge, bad architecture"
|
||||
|
||||
image = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
generator=torch.manual_seed(0),
|
||||
).images[0]
|
||||
image
|
||||
```
|
||||
The easiest way to inspect the metadata, if available, is by clicking on the Safetensors logo next to the weights.
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/blueprint-lora.png"/>
|
||||
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/safetensors_lora.png"/>
|
||||
</div>
|
||||
|
||||
For LoRAs that aren't trained with Diffusers, you can still save metadata with the `transformer_lora_adapter_metadata` and `text_encoder_lora_adapter_metadata` arguments in [`~loaders.FluxLoraLoaderMixin.save_lora_weights`] as long as it is a safetensors file.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
|
||||
).to("cuda")
|
||||
pipeline.load_lora_weights("linoyts/yarn_art_Flux_LoRA")
|
||||
pipeline.save_lora_weights(
|
||||
transformer_lora_adapter_metadata={"r": 16, "lora_alpha": 16},
|
||||
text_encoder_lora_adapter_metadata={"r": 8, "lora_alpha": 8}
|
||||
)
|
||||
```
|
||||
|
||||
### ckpt
|
||||
|
||||
> [!WARNING]
|
||||
|
||||
@@ -136,53 +136,3 @@ result2 = pipe(prompt=prompt, num_inference_steps=50, generator=g, output_type="
|
||||
print("L_inf dist =", abs(result1 - result2).max())
|
||||
"L_inf dist = tensor(0., device='cuda:0')"
|
||||
```
|
||||
|
||||
## Deterministic batch generation
|
||||
|
||||
A practical application of creating reproducible pipelines is *deterministic batch generation*. You generate a batch of images and select one image to improve with a more detailed prompt. The main idea is to pass a list of [Generator's](https://pytorch.org/docs/stable/generated/torch.Generator.html) to the pipeline and tie each `Generator` to a seed so you can reuse it.
|
||||
|
||||
Let's use the [stable-diffusion-v1-5/stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5) checkpoint and generate a batch of images.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.utils import make_image_grid
|
||||
|
||||
pipeline = DiffusionPipeline.from_pretrained(
|
||||
"stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16, use_safetensors=True
|
||||
)
|
||||
pipeline = pipeline.to("cuda")
|
||||
```
|
||||
|
||||
Define four different `Generator`s and assign each `Generator` a seed (`0` to `3`). Then generate a batch of images and pick one to iterate on.
|
||||
|
||||
> [!WARNING]
|
||||
> Use a list comprehension that iterates over the batch size specified in `range()` to create a unique `Generator` object for each image in the batch. If you multiply the `Generator` by the batch size integer, it only creates *one* `Generator` object that is used sequentially for each image in the batch.
|
||||
>
|
||||
> ```py
|
||||
> [torch.Generator().manual_seed(seed)] * 4
|
||||
> ```
|
||||
|
||||
```python
|
||||
generator = [torch.Generator(device="cuda").manual_seed(i) for i in range(4)]
|
||||
prompt = "Labrador in the style of Vermeer"
|
||||
images = pipeline(prompt, generator=generator, num_images_per_prompt=4).images[0]
|
||||
make_image_grid(images, rows=2, cols=2)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds.jpg"/>
|
||||
</div>
|
||||
|
||||
Let's improve the first image (you can choose any image you want) which corresponds to the `Generator` with seed `0`. Add some additional text to your prompt and then make sure you reuse the same `Generator` with seed `0`. All the generated images should resemble the first image.
|
||||
|
||||
```python
|
||||
prompt = [prompt + t for t in [", highly realistic", ", artsy", ", trending", ", colorful"]]
|
||||
generator = [torch.Generator(device="cuda").manual_seed(0) for i in range(4)]
|
||||
images = pipeline(prompt, generator=generator).images
|
||||
make_image_grid(images, rows=2, cols=2)
|
||||
```
|
||||
|
||||
<div class="flex justify-center">
|
||||
<img src="https://huggingface.co/datasets/diffusers/diffusers-images-docs/resolve/main/reusabe_seeds_2.jpg"/>
|
||||
</div>
|
||||
|
||||
@@ -242,3 +242,15 @@ unet = UNet2DConditionModel.from_pretrained(
|
||||
)
|
||||
unet.save_pretrained("./local-unet", variant="non_ema")
|
||||
```
|
||||
|
||||
Use the `torch_dtype` argument in [`~ModelMixin.from_pretrained`] to specify the dtype to load a model in.
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
unet = AutoModel.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16
|
||||
)
|
||||
```
|
||||
|
||||
You can also use the [torch.Tensor.to](https://docs.pytorch.org/docs/stable/generated/torch.Tensor.to.html) method to convert to the specified dtype on the fly. It converts *all* weights unlike the `torch_dtype` argument that respects the `_keep_in_fp32_modules`. This is important for models whose layers must remain in fp32 for numerical stability and best generation quality (see example [here](https://github.com/huggingface/diffusers/blob/f864a9a352fa4a220d860bfdd1782e3e5af96382/src/diffusers/models/transformers/transformer_wan.py#L374)).
|
||||
|
||||
@@ -263,9 +263,19 @@ This reduces memory requirements significantly w/o a significant quality loss. N
|
||||
## Training Kontext
|
||||
|
||||
[Kontext](https://bfl.ai/announcements/flux-1-kontext) lets us perform image editing as well as image generation. Even though it can accept both image and text as inputs, one can use it for text-to-image (T2I) generation, too. We
|
||||
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for T2I. The optimizations discussed above apply this script, too.
|
||||
provide a simple script for LoRA fine-tuning Kontext in [train_dreambooth_lora_flux_kontext.py](./train_dreambooth_lora_flux_kontext.py) for both T2I and I2I. The optimizations discussed above apply this script, too.
|
||||
|
||||
Make sure to follow the [instructions to set up your environment](#running-locally-with-pytorch) before proceeding to the rest of the section.
|
||||
**important**
|
||||
|
||||
> [!NOTE]
|
||||
> To make sure you can successfully run the latest version of the kontext example script, we highly recommend installing from source, specifically from the commit mentioned below.
|
||||
> To do this, execute the following steps in a new virtual environment:
|
||||
> ```
|
||||
> git clone https://github.com/huggingface/diffusers
|
||||
> cd diffusers
|
||||
> git checkout 05e7a854d0a5661f5b433f6dd5954c224b104f0b
|
||||
> pip install -e .
|
||||
> ```
|
||||
|
||||
Below is an example training command:
|
||||
|
||||
@@ -294,6 +304,42 @@ accelerate launch train_dreambooth_lora_flux_kontext.py \
|
||||
Fine-tuning Kontext on the T2I task can be useful when working with specific styles/subjects where it may not
|
||||
perform as expected.
|
||||
|
||||
Image-guided fine-tuning (I2I) is also supported. To start, you must have a dataset containing triplets:
|
||||
|
||||
* Condition image
|
||||
* Target image
|
||||
* Instruction
|
||||
|
||||
[kontext-community/relighting](https://huggingface.co/datasets/kontext-community/relighting) is a good example of such a dataset. If you are using such a dataset, you can use the command below to launch training:
|
||||
|
||||
```bash
|
||||
accelerate launch train_dreambooth_lora_flux_kontext.py \
|
||||
--pretrained_model_name_or_path=black-forest-labs/FLUX.1-Kontext-dev \
|
||||
--output_dir="kontext-i2i" \
|
||||
--dataset_name="kontext-community/relighting" \
|
||||
--image_column="output" --cond_image_column="file_name" --caption_column="instruction" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--guidance_scale=1 \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--optimizer="adamw" \
|
||||
--use_8bit_adam \
|
||||
--cache_latents \
|
||||
--learning_rate=1e-4 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=200 \
|
||||
--max_train_steps=1000 \
|
||||
--rank=16\
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
More generally, when performing I2I fine-tuning, we expect you to:
|
||||
|
||||
* Have a dataset `kontext-community/relighting`
|
||||
* Supply `image_column`, `cond_image_column`, and `caption_column` values when launching training
|
||||
|
||||
### Misc notes
|
||||
|
||||
* By default, we use `mode` as the value of `--vae_encode_mode` argument. This is because Kontext uses `mode()` of the distribution predicted by the VAE instead of sampling from it.
|
||||
@@ -307,4 +353,4 @@ To enable aspect ratio bucketing, pass `--aspect_ratio_buckets` argument with a
|
||||
Since Flux Kontext finetuning is still an experimental phase, we encourage you to explore different settings and share your insights! 🤗
|
||||
|
||||
## Other notes
|
||||
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
|
||||
Thanks to `bghira` and `ostris` for their help with reviewing & insight sharing ♥️
|
||||
|
||||
@@ -40,7 +40,7 @@ from PIL.ImageOps import exif_transpose
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data.sampler import BatchSampler
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
from torchvision.transforms import functional as TF
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast
|
||||
|
||||
@@ -62,11 +62,7 @@ from diffusers.training_utils import (
|
||||
free_memory,
|
||||
parse_buckets_string,
|
||||
)
|
||||
from diffusers.utils import (
|
||||
check_min_version,
|
||||
convert_unet_state_dict_to_peft,
|
||||
is_wandb_available,
|
||||
)
|
||||
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, is_wandb_available, load_image
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_torch_npu_available
|
||||
from diffusers.utils.torch_utils import is_compiled_module
|
||||
@@ -186,6 +182,7 @@ def log_validation(
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
pipeline_args_cp = pipeline_args.copy()
|
||||
|
||||
# run inference
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed is not None else None
|
||||
@@ -193,14 +190,16 @@ def log_validation(
|
||||
|
||||
# pre-calculate prompt embeds, pooled prompt embeds, text ids because t5 does not support autocast
|
||||
with torch.no_grad():
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
|
||||
pipeline_args["prompt"], prompt_2=pipeline_args["prompt"]
|
||||
)
|
||||
prompt = pipeline_args_cp.pop("prompt")
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(prompt, prompt_2=None)
|
||||
images = []
|
||||
for _ in range(args.num_validation_images):
|
||||
with autocast_ctx:
|
||||
image = pipeline(
|
||||
prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds, generator=generator
|
||||
**pipeline_args_cp,
|
||||
prompt_embeds=prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
generator=generator,
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
@@ -310,6 +309,12 @@ def parse_args(input_args=None):
|
||||
"default, the standard Image Dataset maps out 'file_name' "
|
||||
"to 'image'.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cond_image_column",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Column in the dataset containing the condition image. Must be specified when performing I2I fine-tuning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
@@ -330,7 +335,6 @@ def parse_args(input_args=None):
|
||||
"--instance_prompt",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -351,6 +355,12 @@ def parse_args(input_args=None):
|
||||
default=None,
|
||||
help="A prompt that is used during validation to verify that the model is learning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_image",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Validation image to use (during I2I fine-tuning) to verify that the model is learning.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num_validation_images",
|
||||
type=int,
|
||||
@@ -399,7 +409,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="flux-dreambooth-lora",
|
||||
default="flux-kontext-lora",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
@@ -716,6 +726,8 @@ def parse_args(input_args=None):
|
||||
raise ValueError("You must specify a data directory for class images.")
|
||||
if args.class_prompt is None:
|
||||
raise ValueError("You must specify prompt for class images.")
|
||||
if args.cond_image_column is not None:
|
||||
raise ValueError("Prior preservation isn't supported with I2I training.")
|
||||
else:
|
||||
# logger is not available yet
|
||||
if args.class_data_dir is not None:
|
||||
@@ -723,6 +735,14 @@ def parse_args(input_args=None):
|
||||
if args.class_prompt is not None:
|
||||
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
|
||||
|
||||
if args.cond_image_column is not None:
|
||||
assert args.image_column is not None
|
||||
assert args.caption_column is not None
|
||||
assert args.dataset_name is not None
|
||||
assert not args.train_text_encoder
|
||||
if args.validation_prompt is not None:
|
||||
assert args.validation_image is None and os.path.exists(args.validation_image)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@@ -742,6 +762,7 @@ class DreamBoothDataset(Dataset):
|
||||
repeats=1,
|
||||
center_crop=False,
|
||||
buckets=None,
|
||||
args=None,
|
||||
):
|
||||
self.center_crop = center_crop
|
||||
|
||||
@@ -774,6 +795,10 @@ class DreamBoothDataset(Dataset):
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# 6. Get the column names for input/target.
|
||||
if args.cond_image_column is not None and args.cond_image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"`--cond_image_column` value '{args.cond_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
if args.image_column is None:
|
||||
image_column = column_names[0]
|
||||
logger.info(f"image column defaulting to {image_column}")
|
||||
@@ -783,7 +808,12 @@ class DreamBoothDataset(Dataset):
|
||||
raise ValueError(
|
||||
f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
|
||||
)
|
||||
instance_images = dataset["train"][image_column]
|
||||
instance_images = [dataset["train"][i][image_column] for i in range(len(dataset["train"]))]
|
||||
cond_images = None
|
||||
cond_image_column = args.cond_image_column
|
||||
if cond_image_column is not None:
|
||||
cond_images = [dataset["train"][i][cond_image_column] for i in range(len(dataset["train"]))]
|
||||
assert len(instance_images) == len(cond_images)
|
||||
|
||||
if args.caption_column is None:
|
||||
logger.info(
|
||||
@@ -811,14 +841,23 @@ class DreamBoothDataset(Dataset):
|
||||
self.custom_instance_prompts = None
|
||||
|
||||
self.instance_images = []
|
||||
for img in instance_images:
|
||||
self.cond_images = []
|
||||
for i, img in enumerate(instance_images):
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
if args.dataset_name is not None and cond_images is not None:
|
||||
self.cond_images.extend(itertools.repeat(cond_images[i], repeats))
|
||||
|
||||
self.pixel_values = []
|
||||
for image in self.instance_images:
|
||||
self.cond_pixel_values = []
|
||||
for i, image in enumerate(self.instance_images):
|
||||
image = exif_transpose(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
dest_image = None
|
||||
if self.cond_images:
|
||||
dest_image = exif_transpose(self.cond_images[i])
|
||||
if not dest_image.mode == "RGB":
|
||||
dest_image = dest_image.convert("RGB")
|
||||
|
||||
width, height = image.size
|
||||
|
||||
@@ -828,25 +867,16 @@ class DreamBoothDataset(Dataset):
|
||||
self.size = (target_height, target_width)
|
||||
|
||||
# based on the bucket assignment, define the transformations
|
||||
train_resize = transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
train_crop = transforms.CenterCrop(self.size) if center_crop else transforms.RandomCrop(self.size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
image, dest_image = self.paired_transform(
|
||||
image,
|
||||
dest_image=dest_image,
|
||||
size=self.size,
|
||||
center_crop=args.center_crop,
|
||||
random_flip=args.random_flip,
|
||||
)
|
||||
image = train_resize(image)
|
||||
if args.center_crop:
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, self.size)
|
||||
image = crop(image, y1, x1, h, w)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
image = train_flip(image)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values.append((image, bucket_idx))
|
||||
if dest_image is not None:
|
||||
self.cond_pixel_values.append((dest_image, bucket_idx))
|
||||
|
||||
self.num_instance_images = len(self.instance_images)
|
||||
self._length = self.num_instance_images
|
||||
@@ -880,6 +910,9 @@ class DreamBoothDataset(Dataset):
|
||||
instance_image, bucket_idx = self.pixel_values[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["bucket_idx"] = bucket_idx
|
||||
if self.cond_pixel_values:
|
||||
dest_image, _ = self.cond_pixel_values[index % self.num_instance_images]
|
||||
example["cond_images"] = dest_image
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
@@ -902,6 +935,43 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
return example
|
||||
|
||||
def paired_transform(self, image, dest_image=None, size=(224, 224), center_crop=False, random_flip=False):
|
||||
# 1. Resize (deterministic)
|
||||
resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
image = resize(image)
|
||||
if dest_image is not None:
|
||||
dest_image = resize(dest_image)
|
||||
|
||||
# 2. Crop: either center or SAME random crop
|
||||
if center_crop:
|
||||
crop = transforms.CenterCrop(size)
|
||||
image = crop(image)
|
||||
if dest_image is not None:
|
||||
dest_image = crop(dest_image)
|
||||
else:
|
||||
# get_params returns (i, j, h, w)
|
||||
i, j, h, w = transforms.RandomCrop.get_params(image, output_size=size)
|
||||
image = TF.crop(image, i, j, h, w)
|
||||
if dest_image is not None:
|
||||
dest_image = TF.crop(dest_image, i, j, h, w)
|
||||
|
||||
# 3. Random horizontal flip with the SAME coin flip
|
||||
if random_flip:
|
||||
do_flip = random.random() < 0.5
|
||||
if do_flip:
|
||||
image = TF.hflip(image)
|
||||
if dest_image is not None:
|
||||
dest_image = TF.hflip(dest_image)
|
||||
|
||||
# 4. ToTensor + Normalize (deterministic)
|
||||
to_tensor = transforms.ToTensor()
|
||||
normalize = transforms.Normalize([0.5], [0.5])
|
||||
image = normalize(to_tensor(image))
|
||||
if dest_image is not None:
|
||||
dest_image = normalize(to_tensor(dest_image))
|
||||
|
||||
return (image, dest_image) if dest_image is not None else (image, None)
|
||||
|
||||
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
@@ -917,6 +987,11 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
batch = {"pixel_values": pixel_values, "prompts": prompts}
|
||||
if any("cond_images" in example for example in examples):
|
||||
cond_pixel_values = [example["cond_images"] for example in examples]
|
||||
cond_pixel_values = torch.stack(cond_pixel_values)
|
||||
cond_pixel_values = cond_pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
batch.update({"cond_pixel_values": cond_pixel_values})
|
||||
return batch
|
||||
|
||||
|
||||
@@ -1318,6 +1393,7 @@ def main(args):
|
||||
"ff.net.2",
|
||||
"ff_context.net.0.proj",
|
||||
"ff_context.net.2",
|
||||
"proj_mlp",
|
||||
]
|
||||
|
||||
# now we will add new LoRA weights the transformer layers
|
||||
@@ -1534,7 +1610,10 @@ def main(args):
|
||||
buckets=buckets,
|
||||
repeats=args.repeats,
|
||||
center_crop=args.center_crop,
|
||||
args=args,
|
||||
)
|
||||
if args.cond_image_column is not None:
|
||||
logger.info("I2I fine-tuning enabled.")
|
||||
batch_sampler = BucketBatchSampler(train_dataset, batch_size=args.train_batch_size, drop_last=False)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
@@ -1574,6 +1653,7 @@ def main(args):
|
||||
|
||||
# Clear the memory here
|
||||
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
|
||||
text_encoder_one.cpu(), text_encoder_two.cpu()
|
||||
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
|
||||
free_memory()
|
||||
|
||||
@@ -1605,19 +1685,41 @@ def main(args):
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
elif train_dataset.custom_instance_prompts and not args.train_text_encoder:
|
||||
cached_text_embeddings = []
|
||||
for batch in tqdm(train_dataloader, desc="Embedding prompts"):
|
||||
batch_prompts = batch["prompts"]
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
|
||||
batch_prompts, text_encoders, tokenizers
|
||||
)
|
||||
cached_text_embeddings.append((prompt_embeds, pooled_prompt_embeds, text_ids))
|
||||
|
||||
if args.validation_prompt is None:
|
||||
text_encoder_one.cpu(), text_encoder_two.cpu()
|
||||
del text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two
|
||||
free_memory()
|
||||
|
||||
vae_config_shift_factor = vae.config.shift_factor
|
||||
vae_config_scaling_factor = vae.config.scaling_factor
|
||||
vae_config_block_out_channels = vae.config.block_out_channels
|
||||
has_image_input = args.cond_image_column is not None
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
cond_latents_cache = []
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
with torch.no_grad():
|
||||
batch["pixel_values"] = batch["pixel_values"].to(
|
||||
accelerator.device, non_blocking=True, dtype=weight_dtype
|
||||
)
|
||||
latents_cache.append(vae.encode(batch["pixel_values"]).latent_dist)
|
||||
if has_image_input:
|
||||
batch["cond_pixel_values"] = batch["cond_pixel_values"].to(
|
||||
accelerator.device, non_blocking=True, dtype=weight_dtype
|
||||
)
|
||||
cond_latents_cache.append(vae.encode(batch["cond_pixel_values"]).latent_dist)
|
||||
|
||||
if args.validation_prompt is None:
|
||||
vae.cpu()
|
||||
del vae
|
||||
free_memory()
|
||||
|
||||
@@ -1678,7 +1780,7 @@ def main(args):
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
tracker_name = "dreambooth-flux-dev-lora"
|
||||
tracker_name = "dreambooth-flux-kontext-lora"
|
||||
accelerator.init_trackers(tracker_name, config=vars(args))
|
||||
|
||||
# Train!
|
||||
@@ -1742,6 +1844,7 @@ def main(args):
|
||||
sigma = sigma.unsqueeze(-1)
|
||||
return sigma
|
||||
|
||||
has_guidance = unwrap_model(transformer).config.guidance_embeds
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
transformer.train()
|
||||
if args.train_text_encoder:
|
||||
@@ -1759,9 +1862,7 @@ def main(args):
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if not args.train_text_encoder:
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
|
||||
prompts, text_encoders, tokenizers
|
||||
)
|
||||
prompt_embeds, pooled_prompt_embeds, text_ids = cached_text_embeddings[step]
|
||||
else:
|
||||
tokens_one = tokenize_prompt(tokenizer_one, prompts, max_sequence_length=77)
|
||||
tokens_two = tokenize_prompt(
|
||||
@@ -1794,16 +1895,29 @@ def main(args):
|
||||
if args.cache_latents:
|
||||
if args.vae_encode_mode == "sample":
|
||||
model_input = latents_cache[step].sample()
|
||||
if has_image_input:
|
||||
cond_model_input = cond_latents_cache[step].sample()
|
||||
else:
|
||||
model_input = latents_cache[step].mode()
|
||||
if has_image_input:
|
||||
cond_model_input = cond_latents_cache[step].mode()
|
||||
else:
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
if has_image_input:
|
||||
cond_pixel_values = batch["cond_pixel_values"].to(dtype=vae.dtype)
|
||||
if args.vae_encode_mode == "sample":
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
if has_image_input:
|
||||
cond_model_input = vae.encode(cond_pixel_values).latent_dist.sample()
|
||||
else:
|
||||
model_input = vae.encode(pixel_values).latent_dist.mode()
|
||||
if has_image_input:
|
||||
cond_model_input = vae.encode(cond_pixel_values).latent_dist.mode()
|
||||
model_input = (model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
if has_image_input:
|
||||
cond_model_input = (cond_model_input - vae_config_shift_factor) * vae_config_scaling_factor
|
||||
cond_model_input = cond_model_input.to(dtype=weight_dtype)
|
||||
|
||||
vae_scale_factor = 2 ** (len(vae_config_block_out_channels) - 1)
|
||||
|
||||
@@ -1814,6 +1928,17 @@ def main(args):
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
)
|
||||
if has_image_input:
|
||||
cond_latents_ids = FluxKontextPipeline._prepare_latent_image_ids(
|
||||
cond_model_input.shape[0],
|
||||
cond_model_input.shape[2] // 2,
|
||||
cond_model_input.shape[3] // 2,
|
||||
accelerator.device,
|
||||
weight_dtype,
|
||||
)
|
||||
cond_latents_ids[..., 0] = 1
|
||||
latent_image_ids = torch.cat([latent_image_ids, cond_latents_ids], dim=0)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
bsz = model_input.shape[0]
|
||||
@@ -1834,7 +1959,6 @@ def main(args):
|
||||
# zt = (1 - texp) * x + texp * z1
|
||||
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
||||
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
packed_noisy_model_input = FluxKontextPipeline._pack_latents(
|
||||
noisy_model_input,
|
||||
batch_size=model_input.shape[0],
|
||||
@@ -1842,13 +1966,22 @@ def main(args):
|
||||
height=model_input.shape[2],
|
||||
width=model_input.shape[3],
|
||||
)
|
||||
orig_inp_shape = packed_noisy_model_input.shape
|
||||
if has_image_input:
|
||||
packed_cond_input = FluxKontextPipeline._pack_latents(
|
||||
cond_model_input,
|
||||
batch_size=cond_model_input.shape[0],
|
||||
num_channels_latents=cond_model_input.shape[1],
|
||||
height=cond_model_input.shape[2],
|
||||
width=cond_model_input.shape[3],
|
||||
)
|
||||
packed_noisy_model_input = torch.cat([packed_noisy_model_input, packed_cond_input], dim=1)
|
||||
|
||||
# handle guidance
|
||||
if unwrap_model(transformer).config.guidance_embeds:
|
||||
# Kontext always has guidance
|
||||
guidance = None
|
||||
if has_guidance:
|
||||
guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
|
||||
guidance = guidance.expand(model_input.shape[0])
|
||||
else:
|
||||
guidance = None
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = transformer(
|
||||
@@ -1862,6 +1995,8 @@ def main(args):
|
||||
img_ids=latent_image_ids,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
if has_image_input:
|
||||
model_pred = model_pred[:, : orig_inp_shape[1]]
|
||||
model_pred = FluxKontextPipeline._unpack_latents(
|
||||
model_pred,
|
||||
height=model_input.shape[2] * vae_scale_factor,
|
||||
@@ -1970,6 +2105,8 @@ def main(args):
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if has_image_input and args.validation_image:
|
||||
pipeline_args.update({"image": load_image(args.validation_image)})
|
||||
images = log_validation(
|
||||
pipeline=pipeline,
|
||||
args=args,
|
||||
@@ -2030,6 +2167,8 @@ def main(args):
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
if has_image_input and args.validation_image:
|
||||
pipeline_args.update({"image": load_image(args.validation_image)})
|
||||
images = log_validation(
|
||||
pipeline=pipeline,
|
||||
args=args,
|
||||
|
||||
@@ -837,11 +837,6 @@ def main(args):
|
||||
assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
|
||||
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
|
||||
|
||||
if args.train_norm_layers:
|
||||
for name, param in flux_transformer.named_parameters():
|
||||
if any(k in name for k in NORM_LAYER_PREFIXES):
|
||||
param.requires_grad = True
|
||||
|
||||
if args.lora_layers is not None:
|
||||
if args.lora_layers != "all-linear":
|
||||
target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
|
||||
@@ -879,6 +874,11 @@ def main(args):
|
||||
)
|
||||
flux_transformer.add_adapter(transformer_lora_config)
|
||||
|
||||
if args.train_norm_layers:
|
||||
for name, param in flux_transformer.named_parameters():
|
||||
if any(k in name for k in NORM_LAYER_PREFIXES):
|
||||
param.requires_grad = True
|
||||
|
||||
def unwrap_model(model):
|
||||
model = accelerator.unwrap_model(model)
|
||||
model = model._orig_mod if is_compiled_module(model) else model
|
||||
|
||||
@@ -95,7 +95,6 @@ TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
||||
"mlp.layer1": "ff.net.0.proj",
|
||||
"mlp.layer2": "ff.net.2",
|
||||
"x_embedder.proj.1": "patch_embed.proj",
|
||||
# "extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
|
||||
@@ -34,9 +34,11 @@ from .utils import (
|
||||
|
||||
_import_structure = {
|
||||
"configuration_utils": ["ConfigMixin"],
|
||||
"guiders": [],
|
||||
"hooks": [],
|
||||
"loaders": ["FromOriginalModelMixin"],
|
||||
"models": [],
|
||||
"modular_pipelines": [],
|
||||
"pipelines": [],
|
||||
"quantizers.quantization_config": [],
|
||||
"schedulers": [],
|
||||
@@ -130,12 +132,29 @@ except OptionalDependencyNotAvailable:
|
||||
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]
|
||||
|
||||
else:
|
||||
_import_structure["guiders"].extend(
|
||||
[
|
||||
"AdaptiveProjectedGuidance",
|
||||
"AutoGuidance",
|
||||
"ClassifierFreeGuidance",
|
||||
"ClassifierFreeZeroStarGuidance",
|
||||
"PerturbedAttentionGuidance",
|
||||
"SkipLayerGuidance",
|
||||
"SmoothedEnergyGuidance",
|
||||
"TangentialClassifierFreeGuidance",
|
||||
]
|
||||
)
|
||||
_import_structure["hooks"].extend(
|
||||
[
|
||||
"FasterCacheConfig",
|
||||
"FirstBlockCacheConfig",
|
||||
"HookRegistry",
|
||||
"LayerSkipConfig",
|
||||
"PyramidAttentionBroadcastConfig",
|
||||
"SmoothedEnergyGuidanceConfig",
|
||||
"apply_faster_cache",
|
||||
"apply_first_block_cache",
|
||||
"apply_layer_skip",
|
||||
"apply_pyramid_attention_broadcast",
|
||||
]
|
||||
)
|
||||
@@ -219,6 +238,14 @@ else:
|
||||
"WanVACETransformer3DModel",
|
||||
]
|
||||
)
|
||||
_import_structure["modular_pipelines"].extend(
|
||||
[
|
||||
"ComponentsManager",
|
||||
"ComponentSpec",
|
||||
"ModularPipeline",
|
||||
"ModularPipelineBlocks",
|
||||
]
|
||||
)
|
||||
_import_structure["optimization"] = [
|
||||
"get_constant_schedule",
|
||||
"get_constant_schedule_with_warmup",
|
||||
@@ -331,6 +358,12 @@ except OptionalDependencyNotAvailable:
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["modular_pipelines"].extend(
|
||||
[
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLModularPipeline",
|
||||
]
|
||||
)
|
||||
_import_structure["pipelines"].extend(
|
||||
[
|
||||
"AllegroPipeline",
|
||||
@@ -381,6 +414,7 @@ else:
|
||||
"FluxFillPipeline",
|
||||
"FluxImg2ImgPipeline",
|
||||
"FluxInpaintPipeline",
|
||||
"FluxKontextInpaintPipeline",
|
||||
"FluxKontextPipeline",
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
@@ -542,6 +576,7 @@ else:
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_opencv_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -748,11 +783,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_pt_objects import * # noqa F403
|
||||
else:
|
||||
from .guiders import (
|
||||
AdaptiveProjectedGuidance,
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
TangentialClassifierFreeGuidance,
|
||||
)
|
||||
from .hooks import (
|
||||
FasterCacheConfig,
|
||||
FirstBlockCacheConfig,
|
||||
HookRegistry,
|
||||
LayerSkipConfig,
|
||||
PyramidAttentionBroadcastConfig,
|
||||
SmoothedEnergyGuidanceConfig,
|
||||
apply_faster_cache,
|
||||
apply_first_block_cache,
|
||||
apply_layer_skip,
|
||||
apply_pyramid_attention_broadcast,
|
||||
)
|
||||
from .models import (
|
||||
@@ -832,6 +882,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
WanTransformer3DModel,
|
||||
WanVACETransformer3DModel,
|
||||
)
|
||||
from .modular_pipelines import (
|
||||
ComponentsManager,
|
||||
ComponentSpec,
|
||||
ModularPipeline,
|
||||
ModularPipelineBlocks,
|
||||
)
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
@@ -928,6 +984,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .modular_pipelines import (
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLModularPipeline,
|
||||
)
|
||||
from .pipelines import (
|
||||
AllegroPipeline,
|
||||
AltDiffusionImg2ImgPipeline,
|
||||
@@ -975,6 +1035,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxFillPipeline,
|
||||
FluxImg2ImgPipeline,
|
||||
FluxInpaintPipeline,
|
||||
FluxKontextInpaintPipeline,
|
||||
FluxKontextPipeline,
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
|
||||
134
src/diffusers/commands/custom_blocks.py
Normal file
134
src/diffusers/commands/custom_blocks.py
Normal file
@@ -0,0 +1,134 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Usage example:
|
||||
TODO
|
||||
"""
|
||||
|
||||
import ast
|
||||
import importlib.util
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseDiffusersCLICommand
|
||||
|
||||
|
||||
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
|
||||
CONFIG = "config.json"
|
||||
|
||||
|
||||
def conversion_command_factory(args: Namespace):
|
||||
return CustomBlocksCommand(args.block_module_name, args.block_class_name)
|
||||
|
||||
|
||||
class CustomBlocksCommand(BaseDiffusersCLICommand):
|
||||
@staticmethod
|
||||
def register_subcommand(parser: ArgumentParser):
|
||||
conversion_parser = parser.add_parser("custom_blocks")
|
||||
conversion_parser.add_argument(
|
||||
"--block_module_name",
|
||||
type=str,
|
||||
default="block.py",
|
||||
help="Module filename in which the custom block will be implemented.",
|
||||
)
|
||||
conversion_parser.add_argument(
|
||||
"--block_class_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Name of the custom block. If provided None, we will try to infer it.",
|
||||
)
|
||||
conversion_parser.set_defaults(func=conversion_command_factory)
|
||||
|
||||
def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
|
||||
self.logger = logging.get_logger("diffusers-cli/custom_blocks")
|
||||
self.block_module_name = Path(block_module_name)
|
||||
self.block_class_name = block_class_name
|
||||
|
||||
def run(self):
|
||||
# determine the block to be saved.
|
||||
out = self._get_class_names(self.block_module_name)
|
||||
classes_found = list({cls for cls, _ in out})
|
||||
|
||||
if self.block_class_name is not None:
|
||||
child_class, parent_class = self._choose_block(out, self.block_class_name)
|
||||
if child_class is None and parent_class is None:
|
||||
raise ValueError(
|
||||
"`block_class_name` could not be retrieved. Available classes from "
|
||||
f"{self.block_module_name}:\n{classes_found}"
|
||||
)
|
||||
else:
|
||||
self.logger.info(
|
||||
f"Found classes: {classes_found} will be using {classes_found[0]}. "
|
||||
"If this needs to be changed, re-run the command specifying `block_class_name`."
|
||||
)
|
||||
child_class, parent_class = out[0][0], out[0][1]
|
||||
|
||||
# dynamically get the custom block and initialize it to call `save_pretrained` in the current directory.
|
||||
# the user is responsible for running it, so I guess that is safe?
|
||||
module_name = f"__dynamic__{self.block_module_name.stem}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
getattr(module, child_class)().save_pretrained(os.getcwd())
|
||||
|
||||
# or, we could create it manually.
|
||||
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
|
||||
# with open(CONFIG, "w") as f:
|
||||
# json.dump(automap, f)
|
||||
with open("requirements.txt", "w") as f:
|
||||
f.write("")
|
||||
|
||||
def _choose_block(self, candidates, chosen=None):
|
||||
for cls, base in candidates:
|
||||
if cls == chosen:
|
||||
return cls, base
|
||||
return None, None
|
||||
|
||||
def _get_class_names(self, file_path):
|
||||
source = file_path.read_text(encoding="utf-8")
|
||||
try:
|
||||
tree = ast.parse(source, filename=file_path)
|
||||
except SyntaxError as e:
|
||||
raise ValueError(f"Could not parse {file_path!r}: {e}") from e
|
||||
|
||||
results: list[tuple[str, str]] = []
|
||||
for node in tree.body:
|
||||
if not isinstance(node, ast.ClassDef):
|
||||
continue
|
||||
|
||||
# extract all base names for this class
|
||||
base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
|
||||
|
||||
# for each allowed base that appears in the class's bases, emit a tuple
|
||||
for allowed in EXPECTED_PARENT_CLASSES:
|
||||
if allowed in base_names:
|
||||
results.append((node.name, allowed))
|
||||
|
||||
return results
|
||||
|
||||
def _get_base_name(self, node: ast.expr):
|
||||
if isinstance(node, ast.Name):
|
||||
return node.id
|
||||
elif isinstance(node, ast.Attribute):
|
||||
val = self._get_base_name(node.value)
|
||||
return f"{val}.{node.attr}" if val else node.attr
|
||||
return None
|
||||
|
||||
def _create_automap(self, parent_class, child_class):
|
||||
module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1]
|
||||
auto_map = {f"{parent_class}": f"{module}.{child_class}"}
|
||||
return {"auto_map": auto_map}
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from .custom_blocks import CustomBlocksCommand
|
||||
from .env import EnvironmentCommand
|
||||
from .fp16_safetensors import FP16SafetensorsCommand
|
||||
|
||||
@@ -26,6 +27,7 @@ def main():
|
||||
# Register commands
|
||||
EnvironmentCommand.register_subcommand(commands_parser)
|
||||
FP16SafetensorsCommand.register_subcommand(commands_parser)
|
||||
CustomBlocksCommand.register_subcommand(commands_parser)
|
||||
|
||||
# Let's go
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -176,6 +176,7 @@ class ConfigMixin:
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
@@ -183,6 +184,7 @@ class ConfigMixin:
|
||||
token=token,
|
||||
commit_message=commit_message,
|
||||
create_pr=create_pr,
|
||||
subfolder=subfolder,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -601,6 +603,10 @@ class ConfigMixin:
|
||||
value = value.tolist()
|
||||
elif isinstance(value, Path):
|
||||
value = value.as_posix()
|
||||
elif hasattr(value, "to_dict") and callable(value.to_dict):
|
||||
value = value.to_dict()
|
||||
elif isinstance(value, list):
|
||||
value = [to_json_saveable(v) for v in value]
|
||||
return value
|
||||
|
||||
if "quantization_config" in config_dict:
|
||||
|
||||
39
src/diffusers/guiders/__init__.py
Normal file
39
src/diffusers/guiders/__init__.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Union
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .adaptive_projected_guidance import AdaptiveProjectedGuidance
|
||||
from .auto_guidance import AutoGuidance
|
||||
from .classifier_free_guidance import ClassifierFreeGuidance
|
||||
from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
|
||||
from .perturbed_attention_guidance import PerturbedAttentionGuidance
|
||||
from .skip_layer_guidance import SkipLayerGuidance
|
||||
from .smoothed_energy_guidance import SmoothedEnergyGuidance
|
||||
from .tangential_classifier_free_guidance import TangentialClassifierFreeGuidance
|
||||
|
||||
GuiderType = Union[
|
||||
AdaptiveProjectedGuidance,
|
||||
AutoGuidance,
|
||||
ClassifierFreeGuidance,
|
||||
ClassifierFreeZeroStarGuidance,
|
||||
PerturbedAttentionGuidance,
|
||||
SkipLayerGuidance,
|
||||
SmoothedEnergyGuidance,
|
||||
TangentialClassifierFreeGuidance,
|
||||
]
|
||||
188
src/diffusers/guiders/adaptive_projected_guidance.py
Normal file
188
src/diffusers/guiders/adaptive_projected_guidance.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class AdaptiveProjectedGuidance(BaseGuidance):
|
||||
"""
|
||||
Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
adaptive_projected_guidance_momentum (`float`, defaults to `None`):
|
||||
The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
|
||||
adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
adaptive_projected_guidance_momentum: Optional[float] = None,
|
||||
adaptive_projected_guidance_rescale: float = 15.0,
|
||||
eta: float = 1.0,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
|
||||
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
|
||||
self.eta = eta
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
self.momentum_buffer = None
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
if self._step == 0:
|
||||
if self.adaptive_projected_guidance_momentum is not None:
|
||||
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_apg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(
|
||||
pred_cond,
|
||||
pred_uncond,
|
||||
self.guidance_scale,
|
||||
self.momentum_buffer,
|
||||
self.eta,
|
||||
self.adaptive_projected_guidance_rescale,
|
||||
self.use_original_formulation,
|
||||
)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_apg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_apg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
class MomentumBuffer:
|
||||
def __init__(self, momentum: float):
|
||||
self.momentum = momentum
|
||||
self.running_average = 0
|
||||
|
||||
def update(self, update_value: torch.Tensor):
|
||||
new_average = self.momentum * self.running_average
|
||||
self.running_average = update_value + new_average
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: torch.Tensor,
|
||||
guidance_scale: float,
|
||||
momentum_buffer: Optional[MomentumBuffer] = None,
|
||||
eta: float = 1.0,
|
||||
norm_threshold: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
):
|
||||
diff = pred_cond - pred_uncond
|
||||
dim = [-i for i in range(1, len(diff.shape))]
|
||||
|
||||
if momentum_buffer is not None:
|
||||
momentum_buffer.update(diff)
|
||||
diff = momentum_buffer.running_average
|
||||
|
||||
if norm_threshold > 0:
|
||||
ones = torch.ones_like(diff)
|
||||
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
|
||||
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
|
||||
diff = diff * scale_factor
|
||||
|
||||
v0, v1 = diff.double(), pred_cond.double()
|
||||
v1 = torch.nn.functional.normalize(v1, dim=dim)
|
||||
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
|
||||
v0_orthogonal = v0 - v0_parallel
|
||||
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
|
||||
normalized_update = diff_orthogonal + eta * diff_parallel
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
pred = pred + guidance_scale * normalized_update
|
||||
|
||||
return pred
|
||||
190
src/diffusers/guiders/auto_guidance.py
Normal file
190
src/diffusers/guiders/auto_guidance.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class AutoGuidance(BaseGuidance):
|
||||
"""
|
||||
AutoGuidance: https://huggingface.co/papers/2406.02507
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
auto_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided.
|
||||
auto_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
dropout (`float`, *optional*):
|
||||
The dropout probability for autoguidance on the enabled skip layers (either with `auto_guidance_layers` or
|
||||
`auto_guidance_config`). If not provided, the dropout probability will be set to 1.0.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
auto_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
auto_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
dropout: Optional[float] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.auto_guidance_layers = auto_guidance_layers
|
||||
self.auto_guidance_config = auto_guidance_config
|
||||
self.dropout = dropout
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if auto_guidance_layers is None and auto_guidance_config is None:
|
||||
raise ValueError(
|
||||
"Either `auto_guidance_layers` or `auto_guidance_config` must be provided to enable Skip Layer Guidance."
|
||||
)
|
||||
if auto_guidance_layers is not None and auto_guidance_config is not None:
|
||||
raise ValueError("Only one of `auto_guidance_layers` or `auto_guidance_config` can be provided.")
|
||||
if (dropout is None and auto_guidance_layers is not None) or (
|
||||
dropout is not None and auto_guidance_layers is None
|
||||
):
|
||||
raise ValueError("`dropout` must be provided if `auto_guidance_layers` is provided.")
|
||||
|
||||
if auto_guidance_layers is not None:
|
||||
if isinstance(auto_guidance_layers, int):
|
||||
auto_guidance_layers = [auto_guidance_layers]
|
||||
if not isinstance(auto_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `auto_guidance_layers` to be an int or a list of ints, but got {type(auto_guidance_layers)}."
|
||||
)
|
||||
auto_guidance_config = [
|
||||
LayerSkipConfig(layer, fqn="auto", dropout=dropout) for layer in auto_guidance_layers
|
||||
]
|
||||
|
||||
if isinstance(auto_guidance_config, dict):
|
||||
auto_guidance_config = LayerSkipConfig.from_dict(auto_guidance_config)
|
||||
|
||||
if isinstance(auto_guidance_config, LayerSkipConfig):
|
||||
auto_guidance_config = [auto_guidance_config]
|
||||
|
||||
if not isinstance(auto_guidance_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `auto_guidance_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(auto_guidance_config)}."
|
||||
)
|
||||
elif isinstance(next(iter(auto_guidance_config), None), dict):
|
||||
auto_guidance_config = [LayerSkipConfig.from_dict(config) for config in auto_guidance_config]
|
||||
|
||||
self.auto_guidance_config = auto_guidance_config
|
||||
self._auto_guidance_hook_names = [f"AutoGuidance_{i}" for i in range(len(self.auto_guidance_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
self._count_prepared += 1
|
||||
if self._is_ag_enabled() and self.is_unconditional:
|
||||
for name, config in zip(self._auto_guidance_hook_names, self.auto_guidance_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_ag_enabled() and self.is_unconditional:
|
||||
for name in self._auto_guidance_hook_names:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
registry.remove_hook(name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_ag_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_ag_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_ag_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
141
src/diffusers/guiders/classifier_free_guidance.py
Normal file
141
src/diffusers/guiders/classifier_free_guidance.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class ClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free guidance (CFG): https://huggingface.co/papers/2207.12598
|
||||
|
||||
CFG is a technique used to improve generation quality and condition-following in diffusion models. It works by
|
||||
jointly training a model on both conditional and unconditional data, and using a weighted sum of the two during
|
||||
inference. This allows the model to tradeoff between generation quality and sample diversity. The original paper
|
||||
proposes scaling and shifting the conditional distribution based on the difference between conditional and
|
||||
unconditional predictions. [x_pred = x_cond + scale * (x_cond - x_uncond)]
|
||||
|
||||
Diffusers implemented the scaling and shifting on the unconditional prediction instead based on the [Imagen
|
||||
paper](https://huggingface.co/papers/2205.11487), which is equivalent to what the original paper proposed in
|
||||
theory. [x_pred = x_uncond + scale * (x_cond - x_uncond)]
|
||||
|
||||
The intution behind the original formulation can be thought of as moving the conditional distribution estimates
|
||||
further away from the unconditional distribution estimates, while the diffusers-native implementation can be
|
||||
thought of as moving the unconditional distribution towards the conditional distribution estimates to get rid of
|
||||
the unconditional predictions (usually negative features like "bad quality, bad anotomy, watermarks", etc.)
|
||||
|
||||
The `use_original_formulation` argument can be set to `True` to use the original CFG formulation mentioned in the
|
||||
paper. By default, we use the diffusers-native implementation that has been in the codebase for a long time.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
152
src/diffusers/guiders/classifier_free_zero_star_guidance.py
Normal file
152
src/diffusers/guiders/classifier_free_zero_star_guidance.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class ClassifierFreeZeroStarGuidance(BaseGuidance):
|
||||
"""
|
||||
Classifier-free Zero* (CFG-Zero*): https://huggingface.co/papers/2503.18886
|
||||
|
||||
This is an implementation of the Classifier-Free Zero* guidance technique, which is a variant of classifier-free
|
||||
guidance. It proposes zero initialization of the noise predictions for the first few steps of the diffusion
|
||||
process, and also introduces an optimal rescaling factor for the noise predictions, which can help in improving the
|
||||
quality of generated images.
|
||||
|
||||
The authors of the paper suggest setting zero initialization in the first 4% of the inference steps.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
zero_init_steps (`int`, defaults to `1`):
|
||||
The number of inference steps for which the noise predictions are zeroed out (see Section 4.2).
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
zero_init_steps: int = 1,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.zero_init_steps = zero_init_steps
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if self._step < self.zero_init_steps:
|
||||
pred = torch.zeros_like(pred_cond)
|
||||
elif not self._is_cfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred_cond_flat = pred_cond.flatten(1)
|
||||
pred_uncond_flat = pred_uncond.flatten(1)
|
||||
alpha = cfg_zero_star_scale(pred_cond_flat, pred_uncond_flat)
|
||||
alpha = alpha.view(-1, *(1,) * (len(pred_cond.shape) - 1))
|
||||
pred_uncond = pred_uncond * alpha
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
def cfg_zero_star_scale(cond: torch.Tensor, uncond: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
|
||||
cond_dtype = cond.dtype
|
||||
cond = cond.float()
|
||||
uncond = uncond.float()
|
||||
dot_product = torch.sum(cond * uncond, dim=1, keepdim=True)
|
||||
squared_norm = torch.sum(uncond**2, dim=1, keepdim=True) + eps
|
||||
# st_star = v_cond^T * v_uncond / ||v_uncond||^2
|
||||
scale = dot_product / squared_norm
|
||||
return scale.to(dtype=cond_dtype)
|
||||
309
src/diffusers/guiders/guider_utils.py
Normal file
309
src/diffusers/guiders/guider_utils.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..configuration_utils import ConfigMixin
|
||||
from ..utils import PushToHubMixin, get_logger
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
GUIDER_CONFIG_NAME = "guider_config.json"
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
r"""Base class providing the skeleton for implementing guidance techniques."""
|
||||
|
||||
config_name = GUIDER_CONFIG_NAME
|
||||
_input_predictions = None
|
||||
_identifier_key = "__guidance_identifier__"
|
||||
|
||||
def __init__(self, start: float = 0.0, stop: float = 1.0):
|
||||
self._start = start
|
||||
self._stop = stop
|
||||
self._step: int = None
|
||||
self._num_inference_steps: int = None
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: Dict[str, Union[str, Tuple[str, str]]] = None
|
||||
self._enabled = True
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
if not (start <= stop <= 1.0):
|
||||
raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")
|
||||
|
||||
if self._input_predictions is None or not isinstance(self._input_predictions, list):
|
||||
raise ValueError(
|
||||
"`_input_predictions` must be a list of required prediction names for the guidance technique."
|
||||
)
|
||||
|
||||
def disable(self):
|
||||
self._enabled = False
|
||||
|
||||
def enable(self):
|
||||
self._enabled = True
|
||||
|
||||
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
|
||||
self._step = step
|
||||
self._num_inference_steps = num_inference_steps
|
||||
self._timestep = timestep
|
||||
self._count_prepared = 0
|
||||
|
||||
def set_input_fields(self, **kwargs: Dict[str, Union[str, Tuple[str, str]]]) -> None:
|
||||
"""
|
||||
Set the input fields for the guidance technique. The input fields are used to specify the names of the returned
|
||||
attributes containing the prepared data after `prepare_inputs` is called. The prepared data is obtained from
|
||||
the values of the provided keyword arguments to this method.
|
||||
|
||||
Args:
|
||||
**kwargs (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation.
|
||||
|
||||
If a string is provided, it will be used as the conditional data (or unconditional if used with a
|
||||
guidance method that requires it). If a tuple of length 2 is provided, the first element must be the
|
||||
conditional data identifier and the second element must be the unconditional data identifier or None.
|
||||
|
||||
Example:
|
||||
```
|
||||
data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>}
|
||||
|
||||
BaseGuidance.set_input_fields(
|
||||
latents="latents",
|
||||
prompt_embeds=("prompt_embeds", "negative_prompt_embeds"),
|
||||
)
|
||||
```
|
||||
"""
|
||||
for key, value in kwargs.items():
|
||||
is_string = isinstance(value, str)
|
||||
is_tuple_of_str_with_len_2 = (
|
||||
isinstance(value, tuple) and len(value) == 2 and all(isinstance(v, str) for v in value)
|
||||
)
|
||||
if not (is_string or is_tuple_of_str_with_len_2):
|
||||
raise ValueError(
|
||||
f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got {type(value)} for key {key}."
|
||||
)
|
||||
self._input_fields = kwargs
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
Prepares the models for the guidance technique on a given batch of data. This method should be overridden in
|
||||
subclasses to implement specific model preparation logic.
|
||||
"""
|
||||
self._count_prepared += 1
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
"""
|
||||
Cleans up the models for the guidance technique after a given batch of data. This method should be overridden
|
||||
in subclasses to implement specific model cleanup logic. It is useful for removing any hooks or other stateful
|
||||
modifications made during `prepare_models`.
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
|
||||
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
|
||||
|
||||
def __call__(self, data: List["BlockState"]) -> Any:
|
||||
if not all(hasattr(d, "noise_pred") for d in data):
|
||||
raise ValueError("Expected all data to have `noise_pred` attribute.")
|
||||
if len(data) != self.num_conditions:
|
||||
raise ValueError(
|
||||
f"Expected {self.num_conditions} data items, but got {len(data)}. Please check the input data."
|
||||
)
|
||||
forward_inputs = {getattr(d, self._identifier_key): d.noise_pred for d in data}
|
||||
return self.forward(**forward_inputs)
|
||||
|
||||
def forward(self, *args, **kwargs) -> Any:
|
||||
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
|
||||
|
||||
@property
|
||||
def is_unconditional(self) -> bool:
|
||||
return not self.is_conditional
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
|
||||
|
||||
@classmethod
|
||||
def _prepare_batch(
|
||||
cls,
|
||||
input_fields: Dict[str, Union[str, Tuple[str, str]]],
|
||||
data: "BlockState",
|
||||
tuple_index: int,
|
||||
identifier: str,
|
||||
) -> "BlockState":
|
||||
"""
|
||||
Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of the
|
||||
`BaseGuidance` class. It prepares the batch based on the provided tuple index.
|
||||
|
||||
Args:
|
||||
input_fields (`Dict[str, Union[str, Tuple[str, str]]]`):
|
||||
A dictionary where the keys are the names of the fields that will be used to store the data once it is
|
||||
prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, which is used
|
||||
to look up the required data provided for preparation. If a string is provided, it will be used as the
|
||||
conditional data (or unconditional if used with a guidance method that requires it). If a tuple of
|
||||
length 2 is provided, the first element must be the conditional data identifier and the second element
|
||||
must be the unconditional data identifier or None.
|
||||
data (`BlockState`):
|
||||
The input data to be prepared.
|
||||
tuple_index (`int`):
|
||||
The index to use when accessing input fields that are tuples.
|
||||
|
||||
Returns:
|
||||
`BlockState`: The prepared batch of data.
|
||||
"""
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
if input_fields is None:
|
||||
raise ValueError(
|
||||
"Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs."
|
||||
)
|
||||
data_batch = {}
|
||||
for key, value in input_fields.items():
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
data_batch[key] = getattr(data, value)
|
||||
elif isinstance(value, tuple):
|
||||
data_batch[key] = getattr(data, value[tuple_index])
|
||||
else:
|
||||
# We've already checked that value is a string or a tuple of strings with length 2
|
||||
pass
|
||||
except AttributeError:
|
||||
logger.debug(f"`data` does not have attribute(s) {value}, skipping.")
|
||||
data_batch[cls._identifier_key] = identifier
|
||||
return BlockState(**data_batch)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
|
||||
subfolder: Optional[str] = None,
|
||||
return_unused_kwargs=False,
|
||||
**kwargs,
|
||||
) -> Self:
|
||||
r"""
|
||||
Instantiate a guider from a pre-defined JSON configuration file in a local directory or Hub repository.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the guider configuration
|
||||
saved with [`~BaseGuidance.save_pretrained`].
|
||||
subfolder (`str`, *optional*):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
Whether kwargs that are not consumed by the Python class should be returned or not.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
|
||||
<Tip>
|
||||
|
||||
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
|
||||
`huggingface-cli login`. You can also activate the special
|
||||
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
|
||||
firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
config, kwargs, commit_hash = cls.load_config(
|
||||
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
return_unused_kwargs=True,
|
||||
return_commit_hash=True,
|
||||
**kwargs,
|
||||
)
|
||||
return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a guider configuration object to a directory so that it can be reloaded using the
|
||||
[`~BaseGuidance.from_pretrained`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
||||
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
||||
namespace).
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
|
||||
|
||||
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
|
||||
r"""
|
||||
Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
|
||||
Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://arxiv.org/pdf/2305.08891.pdf).
|
||||
|
||||
Args:
|
||||
noise_cfg (`torch.Tensor`):
|
||||
The predicted noise tensor for the guided diffusion process.
|
||||
noise_pred_text (`torch.Tensor`):
|
||||
The predicted noise tensor for the text-guided diffusion process.
|
||||
guidance_rescale (`float`, *optional*, defaults to 0.0):
|
||||
A rescale factor applied to the noise predictions.
|
||||
Returns:
|
||||
noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
|
||||
"""
|
||||
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
|
||||
std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
|
||||
# rescale the results from guidance (fixes overexposure)
|
||||
noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
|
||||
# mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
|
||||
noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
|
||||
return noise_cfg
|
||||
271
src/diffusers/guiders/perturbed_attention_guidance.py
Normal file
271
src/diffusers/guiders/perturbed_attention_guidance.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from ..utils import get_logger
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class PerturbedAttentionGuidance(BaseGuidance):
|
||||
"""
|
||||
Perturbed Attention Guidance (PAG): https://huggingface.co/papers/2403.17377
|
||||
|
||||
The intution behind PAG can be thought of as moving the CFG predicted distribution estimates further away from
|
||||
worse versions of the conditional distribution estimates. PAG was one of the first techniques to introduce the idea
|
||||
of using a worse version of the trained model for better guiding itself in the denoising process. It perturbs the
|
||||
attention scores of the latent stream by replacing the score matrix with an identity matrix for selectively chosen
|
||||
layers.
|
||||
|
||||
Additional reading:
|
||||
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
||||
|
||||
PAG is implemented with similar implementation to SkipLayerGuidance due to overlap in the configuration parameters
|
||||
and implementation details.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
perturbed_guidance_scale (`float`, defaults to `2.8`):
|
||||
The scale parameter for perturbed attention guidance.
|
||||
perturbed_guidance_start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which perturbed attention guidance starts.
|
||||
perturbed_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which perturbed attention guidance stops.
|
||||
perturbed_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply perturbed attention guidance to. Can be a single integer or a list of integers.
|
||||
If not provided, `perturbed_guidance_config` must be provided.
|
||||
perturbed_guidance_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the perturbed attention guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `perturbed_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
# NOTE: The current implementation does not account for joint latent conditioning (text + image/video tokens in
|
||||
# the same latent stream). It assumes the entire latent is a single stream of visual tokens. It would be very
|
||||
# complex to support joint latent conditioning in a model-agnostic manner without specializing the implementation
|
||||
# for each model architecture.
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
perturbed_guidance_scale: float = 2.8,
|
||||
perturbed_guidance_start: float = 0.01,
|
||||
perturbed_guidance_stop: float = 0.2,
|
||||
perturbed_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
perturbed_guidance_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = perturbed_guidance_scale
|
||||
self.skip_layer_guidance_start = perturbed_guidance_start
|
||||
self.skip_layer_guidance_stop = perturbed_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if perturbed_guidance_config is None:
|
||||
if perturbed_guidance_layers is None:
|
||||
raise ValueError(
|
||||
"`perturbed_guidance_layers` must be provided if `perturbed_guidance_config` is not specified."
|
||||
)
|
||||
perturbed_guidance_config = LayerSkipConfig(
|
||||
indices=perturbed_guidance_layers,
|
||||
fqn="auto",
|
||||
skip_attention=False,
|
||||
skip_attention_scores=True,
|
||||
skip_ff=False,
|
||||
)
|
||||
else:
|
||||
if perturbed_guidance_layers is not None:
|
||||
raise ValueError(
|
||||
"`perturbed_guidance_layers` should not be provided if `perturbed_guidance_config` is specified."
|
||||
)
|
||||
|
||||
if isinstance(perturbed_guidance_config, dict):
|
||||
perturbed_guidance_config = LayerSkipConfig.from_dict(perturbed_guidance_config)
|
||||
|
||||
if isinstance(perturbed_guidance_config, LayerSkipConfig):
|
||||
perturbed_guidance_config = [perturbed_guidance_config]
|
||||
|
||||
if not isinstance(perturbed_guidance_config, list):
|
||||
raise ValueError(
|
||||
"`perturbed_guidance_config` must be a `LayerSkipConfig`, a list of `LayerSkipConfig`, or a dict that can be converted to a `LayerSkipConfig`."
|
||||
)
|
||||
elif isinstance(next(iter(perturbed_guidance_config), None), dict):
|
||||
perturbed_guidance_config = [LayerSkipConfig.from_dict(config) for config in perturbed_guidance_config]
|
||||
|
||||
for config in perturbed_guidance_config:
|
||||
if config.skip_attention or not config.skip_attention_scores or config.skip_ff:
|
||||
logger.warning(
|
||||
"Perturbed Attention Guidance is designed to perturb attention scores, so `skip_attention` should be False, `skip_attention_scores` should be True, and `skip_ff` should be False. "
|
||||
"Please check your configuration. Modifying the config to match the expected values."
|
||||
)
|
||||
config.skip_attention = False
|
||||
config.skip_attention_scores = True
|
||||
config.skip_ff = False
|
||||
|
||||
self.skip_layer_config = perturbed_guidance_config
|
||||
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_models
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
self._count_prepared += 1
|
||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.cleanup_models
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.prepare_inputs
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.forward
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_skip: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift
|
||||
elif not self._is_slg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_skip = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.is_conditional
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1 or self._count_prepared == 3
|
||||
|
||||
@property
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance.num_conditions
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_slg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_cfg_enabled
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
# Copied from diffusers.guiders.skip_layer_guidance.SkipLayerGuidance._is_slg_enabled
|
||||
def _is_slg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
|
||||
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
||||
|
||||
return is_within_range and not is_zero
|
||||
262
src/diffusers/guiders/skip_layer_guidance.py
Normal file
262
src/diffusers/guiders/skip_layer_guidance.py
Normal file
@@ -0,0 +1,262 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from ..hooks import HookRegistry, LayerSkipConfig
|
||||
from ..hooks.layer_skip import _apply_layer_skip_hook
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class SkipLayerGuidance(BaseGuidance):
|
||||
"""
|
||||
Skip Layer Guidance (SLG): https://github.com/Stability-AI/sd3.5
|
||||
|
||||
Spatio-Temporal Guidance (STG): https://huggingface.co/papers/2411.18664
|
||||
|
||||
SLG was introduced by StabilityAI for improving structure and anotomy coherence in generated images. It works by
|
||||
skipping the forward pass of specified transformer blocks during the denoising process on an additional conditional
|
||||
batch of data, apart from the conditional and unconditional batches already used in CFG
|
||||
([~guiders.classifier_free_guidance.ClassifierFreeGuidance]), and then scaling and shifting the CFG predictions
|
||||
based on the difference between conditional without skipping and conditional with skipping predictions.
|
||||
|
||||
The intution behind SLG can be thought of as moving the CFG predicted distribution estimates further away from
|
||||
worse versions of the conditional distribution estimates (because skipping layers is equivalent to using a worse
|
||||
version of the model for the conditional prediction).
|
||||
|
||||
STG is an improvement and follow-up work combining ideas from SLG, PAG and similar techniques for improving
|
||||
generation quality in video diffusion models.
|
||||
|
||||
Additional reading:
|
||||
- [Guiding a Diffusion Model with a Bad Version of Itself](https://huggingface.co/papers/2406.02507)
|
||||
|
||||
The values for `skip_layer_guidance_scale`, `skip_layer_guidance_start`, and `skip_layer_guidance_stop` are
|
||||
defaulted to the recommendations by StabilityAI for Stable Diffusion 3.5 Medium.
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
skip_layer_guidance_scale (`float`, defaults to `2.8`):
|
||||
The scale parameter for skip layer guidance. Anatomy and structure coherence may improve with higher
|
||||
values, but it may also lead to overexposure and saturation.
|
||||
skip_layer_guidance_start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance starts.
|
||||
skip_layer_guidance_stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which skip layer guidance stops.
|
||||
skip_layer_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply skip layer guidance to. Can be a single integer or a list of integers. If not
|
||||
provided, `skip_layer_config` must be provided. The recommended values are `[7, 8, 9]` for Stable Diffusion
|
||||
3.5 Medium.
|
||||
skip_layer_config (`LayerSkipConfig` or `List[LayerSkipConfig]`, *optional*):
|
||||
The configuration for the skip layer guidance. Can be a single `LayerSkipConfig` or a list of
|
||||
`LayerSkipConfig`. If not provided, `skip_layer_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
skip_layer_guidance_scale: float = 2.8,
|
||||
skip_layer_guidance_start: float = 0.01,
|
||||
skip_layer_guidance_stop: float = 0.2,
|
||||
skip_layer_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
skip_layer_config: Union[LayerSkipConfig, List[LayerSkipConfig], Dict[str, Any]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.skip_layer_guidance_scale = skip_layer_guidance_scale
|
||||
self.skip_layer_guidance_start = skip_layer_guidance_start
|
||||
self.skip_layer_guidance_stop = skip_layer_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if not (0.0 <= skip_layer_guidance_start < 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_start` to be between 0.0 and 1.0, but got {skip_layer_guidance_start}."
|
||||
)
|
||||
if not (skip_layer_guidance_start <= skip_layer_guidance_stop <= 1.0):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_stop` to be between 0.0 and 1.0, but got {skip_layer_guidance_stop}."
|
||||
)
|
||||
|
||||
if skip_layer_guidance_layers is None and skip_layer_config is None:
|
||||
raise ValueError(
|
||||
"Either `skip_layer_guidance_layers` or `skip_layer_config` must be provided to enable Skip Layer Guidance."
|
||||
)
|
||||
if skip_layer_guidance_layers is not None and skip_layer_config is not None:
|
||||
raise ValueError("Only one of `skip_layer_guidance_layers` or `skip_layer_config` can be provided.")
|
||||
|
||||
if skip_layer_guidance_layers is not None:
|
||||
if isinstance(skip_layer_guidance_layers, int):
|
||||
skip_layer_guidance_layers = [skip_layer_guidance_layers]
|
||||
if not isinstance(skip_layer_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_guidance_layers` to be an int or a list of ints, but got {type(skip_layer_guidance_layers)}."
|
||||
)
|
||||
skip_layer_config = [LayerSkipConfig(layer, fqn="auto") for layer in skip_layer_guidance_layers]
|
||||
|
||||
if isinstance(skip_layer_config, dict):
|
||||
skip_layer_config = LayerSkipConfig.from_dict(skip_layer_config)
|
||||
|
||||
if isinstance(skip_layer_config, LayerSkipConfig):
|
||||
skip_layer_config = [skip_layer_config]
|
||||
|
||||
if not isinstance(skip_layer_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `skip_layer_config` to be a LayerSkipConfig or a list of LayerSkipConfig, but got {type(skip_layer_config)}."
|
||||
)
|
||||
elif isinstance(next(iter(skip_layer_config), None), dict):
|
||||
skip_layer_config = [LayerSkipConfig.from_dict(config) for config in skip_layer_config]
|
||||
|
||||
self.skip_layer_config = skip_layer_config
|
||||
self._skip_layer_hook_names = [f"SkipLayerGuidance_{i}" for i in range(len(self.skip_layer_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
self._count_prepared += 1
|
||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
for name, config in zip(self._skip_layer_hook_names, self.skip_layer_config):
|
||||
_apply_layer_skip_hook(denoiser, config, name=name)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_slg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._skip_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_skip"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_skip: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_slg_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_skip
|
||||
pred = pred + self.skip_layer_guidance_scale * shift
|
||||
elif not self._is_slg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_skip = pred_cond - pred_cond_skip
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.skip_layer_guidance_scale * shift_skip
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1 or self._count_prepared == 3
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_slg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_slg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.skip_layer_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.skip_layer_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
|
||||
is_zero = math.isclose(self.skip_layer_guidance_scale, 0.0)
|
||||
|
||||
return is_within_range and not is_zero
|
||||
251
src/diffusers/guiders/smoothed_energy_guidance.py
Normal file
251
src/diffusers/guiders/smoothed_energy_guidance.py
Normal file
@@ -0,0 +1,251 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from ..hooks import HookRegistry
|
||||
from ..hooks.smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig, _apply_smoothed_energy_guidance_hook
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class SmoothedEnergyGuidance(BaseGuidance):
|
||||
"""
|
||||
Smoothed Energy Guidance (SEG): https://huggingface.co/papers/2408.00760
|
||||
|
||||
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
|
||||
future without warning or guarantee of reproducibility. This implementation assumes:
|
||||
- Generated images are square (height == width)
|
||||
- The model does not combine different modalities together (e.g., text and image latent streams are not combined
|
||||
together such as Flux)
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
seg_guidance_scale (`float`, defaults to `3.0`):
|
||||
The scale parameter for smoothed energy guidance. Anatomy and structure coherence may improve with higher
|
||||
values, but it may also lead to overexposure and saturation.
|
||||
seg_blur_sigma (`float`, defaults to `9999999.0`):
|
||||
The amount by which we blur the attention weights. Setting this value greater than 9999.0 results in
|
||||
infinite blur, which means uniform queries. Controlling it exponentially is empirically effective.
|
||||
seg_blur_threshold_inf (`float`, defaults to `9999.0`):
|
||||
The threshold above which the blur is considered infinite.
|
||||
seg_guidance_start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance starts.
|
||||
seg_guidance_stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which smoothed energy guidance stops.
|
||||
seg_guidance_layers (`int` or `List[int]`, *optional*):
|
||||
The layer indices to apply smoothed energy guidance to. Can be a single integer or a list of integers. If
|
||||
not provided, `seg_guidance_config` must be provided. The recommended values are `[7, 8, 9]` for Stable
|
||||
Diffusion 3.5 Medium.
|
||||
seg_guidance_config (`SmoothedEnergyGuidanceConfig` or `List[SmoothedEnergyGuidanceConfig]`, *optional*):
|
||||
The configuration for the smoothed energy layer guidance. Can be a single `SmoothedEnergyGuidanceConfig` or
|
||||
a list of `SmoothedEnergyGuidanceConfig`. If not provided, `seg_guidance_layers` must be provided.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.01`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `0.2`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
seg_guidance_scale: float = 2.8,
|
||||
seg_blur_sigma: float = 9999999.0,
|
||||
seg_blur_threshold_inf: float = 9999.0,
|
||||
seg_guidance_start: float = 0.0,
|
||||
seg_guidance_stop: float = 1.0,
|
||||
seg_guidance_layers: Optional[Union[int, List[int]]] = None,
|
||||
seg_guidance_config: Union[SmoothedEnergyGuidanceConfig, List[SmoothedEnergyGuidanceConfig]] = None,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.seg_guidance_scale = seg_guidance_scale
|
||||
self.seg_blur_sigma = seg_blur_sigma
|
||||
self.seg_blur_threshold_inf = seg_blur_threshold_inf
|
||||
self.seg_guidance_start = seg_guidance_start
|
||||
self.seg_guidance_stop = seg_guidance_stop
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
if not (0.0 <= seg_guidance_start < 1.0):
|
||||
raise ValueError(f"Expected `seg_guidance_start` to be between 0.0 and 1.0, but got {seg_guidance_start}.")
|
||||
if not (seg_guidance_start <= seg_guidance_stop <= 1.0):
|
||||
raise ValueError(f"Expected `seg_guidance_stop` to be between 0.0 and 1.0, but got {seg_guidance_stop}.")
|
||||
|
||||
if seg_guidance_layers is None and seg_guidance_config is None:
|
||||
raise ValueError(
|
||||
"Either `seg_guidance_layers` or `seg_guidance_config` must be provided to enable Smoothed Energy Guidance."
|
||||
)
|
||||
if seg_guidance_layers is not None and seg_guidance_config is not None:
|
||||
raise ValueError("Only one of `seg_guidance_layers` or `seg_guidance_config` can be provided.")
|
||||
|
||||
if seg_guidance_layers is not None:
|
||||
if isinstance(seg_guidance_layers, int):
|
||||
seg_guidance_layers = [seg_guidance_layers]
|
||||
if not isinstance(seg_guidance_layers, list):
|
||||
raise ValueError(
|
||||
f"Expected `seg_guidance_layers` to be an int or a list of ints, but got {type(seg_guidance_layers)}."
|
||||
)
|
||||
seg_guidance_config = [SmoothedEnergyGuidanceConfig(layer, fqn="auto") for layer in seg_guidance_layers]
|
||||
|
||||
if isinstance(seg_guidance_config, dict):
|
||||
seg_guidance_config = SmoothedEnergyGuidanceConfig.from_dict(seg_guidance_config)
|
||||
|
||||
if isinstance(seg_guidance_config, SmoothedEnergyGuidanceConfig):
|
||||
seg_guidance_config = [seg_guidance_config]
|
||||
|
||||
if not isinstance(seg_guidance_config, list):
|
||||
raise ValueError(
|
||||
f"Expected `seg_guidance_config` to be a SmoothedEnergyGuidanceConfig or a list of SmoothedEnergyGuidanceConfig, but got {type(seg_guidance_config)}."
|
||||
)
|
||||
elif isinstance(next(iter(seg_guidance_config), None), dict):
|
||||
seg_guidance_config = [SmoothedEnergyGuidanceConfig.from_dict(config) for config in seg_guidance_config]
|
||||
|
||||
self.seg_guidance_config = seg_guidance_config
|
||||
self._seg_layer_hook_names = [f"SmoothedEnergyGuidance_{i}" for i in range(len(self.seg_guidance_config))]
|
||||
|
||||
def prepare_models(self, denoiser: torch.nn.Module) -> None:
|
||||
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
for name, config in zip(self._seg_layer_hook_names, self.seg_guidance_config):
|
||||
_apply_smoothed_energy_guidance_hook(denoiser, config, self.seg_blur_sigma, name=name)
|
||||
|
||||
def cleanup_models(self, denoiser: torch.nn.Module):
|
||||
if self._is_seg_enabled() and self.is_conditional and self._count_prepared > 1:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
|
||||
# Remove the hooks after inference
|
||||
for hook_name in self._seg_layer_hook_names:
|
||||
registry.remove_hook(hook_name, recurse=True)
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
if self.num_conditions == 1:
|
||||
tuple_indices = [0]
|
||||
input_predictions = ["pred_cond"]
|
||||
elif self.num_conditions == 2:
|
||||
tuple_indices = [0, 1]
|
||||
input_predictions = (
|
||||
["pred_cond", "pred_uncond"] if self._is_cfg_enabled() else ["pred_cond", "pred_cond_seg"]
|
||||
)
|
||||
else:
|
||||
tuple_indices = [0, 1, 0]
|
||||
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pred_cond: torch.Tensor,
|
||||
pred_uncond: Optional[torch.Tensor] = None,
|
||||
pred_cond_seg: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_cfg_enabled() and not self._is_seg_enabled():
|
||||
pred = pred_cond
|
||||
elif not self._is_cfg_enabled():
|
||||
shift = pred_cond - pred_cond_seg
|
||||
pred = pred_cond if self.use_original_formulation else pred_cond_seg
|
||||
pred = pred + self.seg_guidance_scale * shift
|
||||
elif not self._is_seg_enabled():
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift
|
||||
else:
|
||||
shift = pred_cond - pred_uncond
|
||||
shift_seg = pred_cond - pred_cond_seg
|
||||
pred = pred_cond if self.use_original_formulation else pred_uncond
|
||||
pred = pred + self.guidance_scale * shift + self.seg_guidance_scale * shift_seg
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._count_prepared == 1 or self._count_prepared == 3
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_cfg_enabled():
|
||||
num_conditions += 1
|
||||
if self._is_seg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_cfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
def _is_seg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self.seg_guidance_start * self._num_inference_steps)
|
||||
skip_stop_step = int(self.seg_guidance_stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step < self._step < skip_stop_step
|
||||
|
||||
is_zero = math.isclose(self.seg_guidance_scale, 0.0)
|
||||
|
||||
return is_within_range and not is_zero
|
||||
143
src/diffusers/guiders/tangential_classifier_free_guidance.py
Normal file
143
src/diffusers/guiders/tangential_classifier_free_guidance.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import register_to_config
|
||||
from .guider_utils import BaseGuidance, rescale_noise_cfg
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..modular_pipelines.modular_pipeline import BlockState
|
||||
|
||||
|
||||
class TangentialClassifierFreeGuidance(BaseGuidance):
|
||||
"""
|
||||
Tangential Classifier Free Guidance (TCFG): https://huggingface.co/papers/2503.18137
|
||||
|
||||
Args:
|
||||
guidance_scale (`float`, defaults to `7.5`):
|
||||
The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the text
|
||||
prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation and
|
||||
deterioration of image quality.
|
||||
guidance_rescale (`float`, defaults to `0.0`):
|
||||
The rescale factor applied to the noise predictions. This is used to improve image quality and fix
|
||||
overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
|
||||
Flawed](https://huggingface.co/papers/2305.08891).
|
||||
use_original_formulation (`bool`, defaults to `False`):
|
||||
Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
|
||||
we use the diffusers-native implementation that has been in the codebase for a long time. See
|
||||
[~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
|
||||
start (`float`, defaults to `0.0`):
|
||||
The fraction of the total number of denoising steps after which guidance starts.
|
||||
stop (`float`, defaults to `1.0`):
|
||||
The fraction of the total number of denoising steps after which guidance stops.
|
||||
"""
|
||||
|
||||
_input_predictions = ["pred_cond", "pred_uncond"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
guidance_scale: float = 7.5,
|
||||
guidance_rescale: float = 0.0,
|
||||
use_original_formulation: bool = False,
|
||||
start: float = 0.0,
|
||||
stop: float = 1.0,
|
||||
):
|
||||
super().__init__(start, stop)
|
||||
|
||||
self.guidance_scale = guidance_scale
|
||||
self.guidance_rescale = guidance_rescale
|
||||
self.use_original_formulation = use_original_formulation
|
||||
|
||||
def prepare_inputs(
|
||||
self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None
|
||||
) -> List["BlockState"]:
|
||||
if input_fields is None:
|
||||
input_fields = self._input_fields
|
||||
|
||||
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
|
||||
data_batches = []
|
||||
for i in range(self.num_conditions):
|
||||
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
|
||||
data_batches.append(data_batch)
|
||||
return data_batches
|
||||
|
||||
def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
pred = None
|
||||
|
||||
if not self._is_tcfg_enabled():
|
||||
pred = pred_cond
|
||||
else:
|
||||
pred = normalized_guidance(pred_cond, pred_uncond, self.guidance_scale, self.use_original_formulation)
|
||||
|
||||
if self.guidance_rescale > 0.0:
|
||||
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
|
||||
|
||||
return pred, {}
|
||||
|
||||
@property
|
||||
def is_conditional(self) -> bool:
|
||||
return self._num_outputs_prepared == 1
|
||||
|
||||
@property
|
||||
def num_conditions(self) -> int:
|
||||
num_conditions = 1
|
||||
if self._is_tcfg_enabled():
|
||||
num_conditions += 1
|
||||
return num_conditions
|
||||
|
||||
def _is_tcfg_enabled(self) -> bool:
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
is_within_range = True
|
||||
if self._num_inference_steps is not None:
|
||||
skip_start_step = int(self._start * self._num_inference_steps)
|
||||
skip_stop_step = int(self._stop * self._num_inference_steps)
|
||||
is_within_range = skip_start_step <= self._step < skip_stop_step
|
||||
|
||||
is_close = False
|
||||
if self.use_original_formulation:
|
||||
is_close = math.isclose(self.guidance_scale, 0.0)
|
||||
else:
|
||||
is_close = math.isclose(self.guidance_scale, 1.0)
|
||||
|
||||
return is_within_range and not is_close
|
||||
|
||||
|
||||
def normalized_guidance(
|
||||
pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float, use_original_formulation: bool = False
|
||||
) -> torch.Tensor:
|
||||
cond_dtype = pred_cond.dtype
|
||||
preds = torch.stack([pred_cond, pred_uncond], dim=1).float()
|
||||
preds = preds.flatten(2)
|
||||
U, S, Vh = torch.linalg.svd(preds, full_matrices=False)
|
||||
Vh_modified = Vh.clone()
|
||||
Vh_modified[:, 1] = 0
|
||||
|
||||
uncond_flat = pred_uncond.reshape(pred_uncond.size(0), 1, -1).float()
|
||||
x_Vh = torch.matmul(uncond_flat, Vh.transpose(-2, -1))
|
||||
x_Vh_V = torch.matmul(x_Vh, Vh_modified)
|
||||
pred_uncond = x_Vh_V.reshape(pred_uncond.shape).to(cond_dtype)
|
||||
|
||||
pred = pred_cond if use_original_formulation else pred_uncond
|
||||
shift = pred_cond - pred_uncond
|
||||
pred = pred + guidance_scale * shift
|
||||
|
||||
return pred
|
||||
@@ -1,9 +1,26 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ..utils import is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .faster_cache import FasterCacheConfig, apply_faster_cache
|
||||
from .first_block_cache import FirstBlockCacheConfig, apply_first_block_cache
|
||||
from .group_offloading import apply_group_offloading
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
from .layer_skip import LayerSkipConfig, apply_layer_skip
|
||||
from .layerwise_casting import apply_layerwise_casting, apply_layerwise_casting_hook
|
||||
from .pyramid_attention_broadcast import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast
|
||||
from .smoothed_energy_guidance_utils import SmoothedEnergyGuidanceConfig
|
||||
|
||||
43
src/diffusers/hooks/_common.py
Normal file
43
src/diffusers/hooks/_common.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..models.attention import FeedForward, LuminaFeedForward
|
||||
from ..models.attention_processor import Attention, MochiAttention
|
||||
|
||||
|
||||
_ATTENTION_CLASSES = (Attention, MochiAttention)
|
||||
_FEEDFORWARD_CLASSES = (FeedForward, LuminaFeedForward)
|
||||
|
||||
_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "single_transformer_blocks", "layers")
|
||||
_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS = ("temporal_transformer_blocks",)
|
||||
_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS = ("blocks", "transformer_blocks", "layers")
|
||||
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS = tuple(
|
||||
{
|
||||
*_SPATIAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_TEMPORAL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
*_CROSS_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_submodule_from_fqn(module: torch.nn.Module, fqn: str) -> Optional[torch.nn.Module]:
|
||||
for submodule_name, submodule in module.named_modules():
|
||||
if submodule_name == fqn:
|
||||
return submodule
|
||||
return None
|
||||
264
src/diffusers/hooks/_helpers.py
Normal file
264
src/diffusers/hooks/_helpers.py
Normal file
@@ -0,0 +1,264 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Type
|
||||
|
||||
|
||||
@dataclass
|
||||
class AttentionProcessorMetadata:
|
||||
skip_processor_output_fn: Callable[[Any], Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformerBlockMetadata:
|
||||
return_hidden_states_index: int = None
|
||||
return_encoder_hidden_states_index: int = None
|
||||
|
||||
_cls: Type = None
|
||||
_cached_parameter_indices: Dict[str, int] = None
|
||||
|
||||
def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
if identifier in kwargs:
|
||||
return kwargs[identifier]
|
||||
if self._cached_parameter_indices is not None:
|
||||
return args[self._cached_parameter_indices[identifier]]
|
||||
if self._cls is None:
|
||||
raise ValueError("Model class is not set for metadata.")
|
||||
parameters = list(inspect.signature(self._cls.forward).parameters.keys())
|
||||
parameters = parameters[1:] # skip `self`
|
||||
self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
|
||||
if identifier not in self._cached_parameter_indices:
|
||||
raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
|
||||
index = self._cached_parameter_indices[identifier]
|
||||
if index >= len(args):
|
||||
raise ValueError(f"Expected {index} arguments but got {len(args)}.")
|
||||
return args[index]
|
||||
|
||||
|
||||
class AttentionProcessorRegistry:
|
||||
_registry = {}
|
||||
# TODO(aryan): this is only required for the time being because we need to do the registrations
|
||||
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
|
||||
# import errors because of the models imported in this file.
|
||||
_is_registered = False
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
|
||||
cls._register()
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> AttentionProcessorMetadata:
|
||||
cls._register()
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
@classmethod
|
||||
def _register(cls):
|
||||
if cls._is_registered:
|
||||
return
|
||||
cls._is_registered = True
|
||||
_register_attention_processors_metadata()
|
||||
|
||||
|
||||
class TransformerBlockRegistry:
|
||||
_registry = {}
|
||||
# TODO(aryan): this is only required for the time being because we need to do the registrations
|
||||
# for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
|
||||
# import errors because of the models imported in this file.
|
||||
_is_registered = False
|
||||
|
||||
@classmethod
|
||||
def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
|
||||
cls._register()
|
||||
metadata._cls = model_class
|
||||
cls._registry[model_class] = metadata
|
||||
|
||||
@classmethod
|
||||
def get(cls, model_class: Type) -> TransformerBlockMetadata:
|
||||
cls._register()
|
||||
if model_class not in cls._registry:
|
||||
raise ValueError(f"Model class {model_class} not registered.")
|
||||
return cls._registry[model_class]
|
||||
|
||||
@classmethod
|
||||
def _register(cls):
|
||||
if cls._is_registered:
|
||||
return
|
||||
cls._is_registered = True
|
||||
_register_transformer_blocks_metadata()
|
||||
|
||||
|
||||
def _register_attention_processors_metadata():
|
||||
from ..models.attention_processor import AttnProcessor2_0
|
||||
from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
|
||||
|
||||
# AttnProcessor2_0
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=AttnProcessor2_0,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
|
||||
),
|
||||
)
|
||||
|
||||
# CogView4AttnProcessor
|
||||
AttentionProcessorRegistry.register(
|
||||
model_class=CogView4AttnProcessor,
|
||||
metadata=AttentionProcessorMetadata(
|
||||
skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _register_transformer_blocks_metadata():
|
||||
from ..models.attention import BasicTransformerBlock
|
||||
from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
|
||||
from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
|
||||
from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
|
||||
from ..models.transformers.transformer_hunyuan_video import (
|
||||
HunyuanVideoSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
HunyuanVideoTokenReplaceTransformerBlock,
|
||||
HunyuanVideoTransformerBlock,
|
||||
)
|
||||
from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
|
||||
from ..models.transformers.transformer_mochi import MochiTransformerBlock
|
||||
from ..models.transformers.transformer_wan import WanTransformerBlock
|
||||
|
||||
# BasicTransformerBlock
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=BasicTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
# CogVideoX
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogVideoXBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# CogView4
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=CogView4TransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Flux
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=FluxSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=1,
|
||||
return_encoder_hidden_states_index=0,
|
||||
),
|
||||
)
|
||||
|
||||
# HunyuanVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# LTXVideo
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=LTXVideoTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
# Mochi
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=MochiTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=1,
|
||||
),
|
||||
)
|
||||
|
||||
# Wan
|
||||
TransformerBlockRegistry.register(
|
||||
model_class=WanTransformerBlock,
|
||||
metadata=TransformerBlockMetadata(
|
||||
return_hidden_states_index=0,
|
||||
return_encoder_hidden_states_index=None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# fmt: off
|
||||
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
|
||||
hidden_states = kwargs.get("hidden_states", None)
|
||||
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
|
||||
if hidden_states is None and len(args) > 0:
|
||||
hidden_states = args[0]
|
||||
if encoder_hidden_states is None and len(args) > 1:
|
||||
encoder_hidden_states = args[1]
|
||||
return hidden_states, encoder_hidden_states
|
||||
|
||||
|
||||
_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
|
||||
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
|
||||
# fmt: on
|
||||
227
src/diffusers/hooks/first_block_cache.py
Normal file
227
src/diffusers/hooks/first_block_cache.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# Copyright 2024 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS
|
||||
from ._helpers import TransformerBlockRegistry
|
||||
from .hooks import BaseState, HookRegistry, ModelHook, StateManager
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_FBC_LEADER_BLOCK_HOOK = "fbc_leader_block_hook"
|
||||
_FBC_BLOCK_HOOK = "fbc_block_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FirstBlockCacheConfig:
|
||||
r"""
|
||||
Configuration for [First Block
|
||||
Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching).
|
||||
|
||||
Args:
|
||||
threshold (`float`, defaults to `0.05`):
|
||||
The threshold to determine whether or not a forward pass through all layers of the model is required. A
|
||||
higher threshold usually results in a forward pass through a lower number of layers and faster inference,
|
||||
but might lead to poorer generation quality. A lower threshold may not result in significant generation
|
||||
speedup. The threshold is compared against the absmean difference of the residuals between the current and
|
||||
cached outputs from the first transformer block. If the difference is below the threshold, the forward pass
|
||||
is skipped.
|
||||
"""
|
||||
|
||||
threshold: float = 0.05
|
||||
|
||||
|
||||
class FBCSharedBlockState(BaseState):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.head_block_output: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.head_block_residual: torch.Tensor = None
|
||||
self.tail_block_residuals: Union[torch.Tensor, Tuple[torch.Tensor, ...]] = None
|
||||
self.should_compute: bool = True
|
||||
|
||||
def reset(self):
|
||||
self.tail_block_residuals = None
|
||||
self.should_compute = True
|
||||
|
||||
|
||||
class FBCHeadBlockHook(ModelHook):
|
||||
_is_stateful = True
|
||||
|
||||
def __init__(self, state_manager: StateManager, threshold: float):
|
||||
self.state_manager = state_manager
|
||||
self.threshold = threshold
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
||||
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
is_output_tuple = isinstance(output, tuple)
|
||||
|
||||
if is_output_tuple:
|
||||
hidden_states_residual = output[self._metadata.return_hidden_states_index] - original_hidden_states
|
||||
else:
|
||||
hidden_states_residual = output - original_hidden_states
|
||||
|
||||
shared_state: FBCSharedBlockState = self.state_manager.get_state()
|
||||
hidden_states = encoder_hidden_states = None
|
||||
should_compute = self._should_compute_remaining_blocks(hidden_states_residual)
|
||||
shared_state.should_compute = should_compute
|
||||
|
||||
if not should_compute:
|
||||
# Apply caching
|
||||
if is_output_tuple:
|
||||
hidden_states = (
|
||||
shared_state.tail_block_residuals[0] + output[self._metadata.return_hidden_states_index]
|
||||
)
|
||||
else:
|
||||
hidden_states = shared_state.tail_block_residuals[0] + output
|
||||
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
assert is_output_tuple
|
||||
encoder_hidden_states = (
|
||||
shared_state.tail_block_residuals[1] + output[self._metadata.return_encoder_hidden_states_index]
|
||||
)
|
||||
|
||||
if is_output_tuple:
|
||||
return_output = [None] * len(output)
|
||||
return_output[self._metadata.return_hidden_states_index] = hidden_states
|
||||
return_output[self._metadata.return_encoder_hidden_states_index] = encoder_hidden_states
|
||||
return_output = tuple(return_output)
|
||||
else:
|
||||
return_output = hidden_states
|
||||
output = return_output
|
||||
else:
|
||||
if is_output_tuple:
|
||||
head_block_output = [None] * len(output)
|
||||
head_block_output[0] = output[self._metadata.return_hidden_states_index]
|
||||
head_block_output[1] = output[self._metadata.return_encoder_hidden_states_index]
|
||||
else:
|
||||
head_block_output = output
|
||||
shared_state.head_block_output = head_block_output
|
||||
shared_state.head_block_residual = hidden_states_residual
|
||||
|
||||
return output
|
||||
|
||||
def reset_state(self, module):
|
||||
self.state_manager.reset()
|
||||
return module
|
||||
|
||||
@torch.compiler.disable
|
||||
def _should_compute_remaining_blocks(self, hidden_states_residual: torch.Tensor) -> bool:
|
||||
shared_state = self.state_manager.get_state()
|
||||
if shared_state.head_block_residual is None:
|
||||
return True
|
||||
prev_hidden_states_residual = shared_state.head_block_residual
|
||||
absmean = (hidden_states_residual - prev_hidden_states_residual).abs().mean()
|
||||
prev_hidden_states_absmean = prev_hidden_states_residual.abs().mean()
|
||||
diff = (absmean / prev_hidden_states_absmean).item()
|
||||
return diff > self.threshold
|
||||
|
||||
|
||||
class FBCBlockHook(ModelHook):
|
||||
def __init__(self, state_manager: StateManager, is_tail: bool = False):
|
||||
super().__init__()
|
||||
self.state_manager = state_manager
|
||||
self.is_tail = is_tail
|
||||
self._metadata = None
|
||||
|
||||
def initialize_hook(self, module):
|
||||
unwrapped_module = unwrap_module(module)
|
||||
self._metadata = TransformerBlockRegistry.get(unwrapped_module.__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
||||
original_encoder_hidden_states = None
|
||||
if self._metadata.return_encoder_hidden_states_index is not None:
|
||||
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
|
||||
shared_state = self.state_manager.get_state()
|
||||
|
||||
if shared_state.should_compute:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
if self.is_tail:
|
||||
hidden_states_residual = encoder_hidden_states_residual = None
|
||||
if isinstance(output, tuple):
|
||||
hidden_states_residual = (
|
||||
output[self._metadata.return_hidden_states_index] - shared_state.head_block_output[0]
|
||||
)
|
||||
encoder_hidden_states_residual = (
|
||||
output[self._metadata.return_encoder_hidden_states_index] - shared_state.head_block_output[1]
|
||||
)
|
||||
else:
|
||||
hidden_states_residual = output - shared_state.head_block_output
|
||||
shared_state.tail_block_residuals = (hidden_states_residual, encoder_hidden_states_residual)
|
||||
return output
|
||||
|
||||
if original_encoder_hidden_states is None:
|
||||
return_output = original_hidden_states
|
||||
else:
|
||||
return_output = [None, None]
|
||||
return_output[self._metadata.return_hidden_states_index] = original_hidden_states
|
||||
return_output[self._metadata.return_encoder_hidden_states_index] = original_encoder_hidden_states
|
||||
return_output = tuple(return_output)
|
||||
return return_output
|
||||
|
||||
|
||||
def apply_first_block_cache(module: torch.nn.Module, config: FirstBlockCacheConfig) -> None:
|
||||
state_manager = StateManager(FBCSharedBlockState, (), {})
|
||||
remaining_blocks = []
|
||||
|
||||
for name, submodule in module.named_children():
|
||||
if name not in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS or not isinstance(submodule, torch.nn.ModuleList):
|
||||
continue
|
||||
for index, block in enumerate(submodule):
|
||||
remaining_blocks.append((f"{name}.{index}", block))
|
||||
|
||||
head_block_name, head_block = remaining_blocks.pop(0)
|
||||
tail_block_name, tail_block = remaining_blocks.pop(-1)
|
||||
|
||||
logger.debug(f"Applying FBCHeadBlockHook to '{head_block_name}'")
|
||||
_apply_fbc_head_block_hook(head_block, state_manager, config.threshold)
|
||||
|
||||
for name, block in remaining_blocks:
|
||||
logger.debug(f"Applying FBCBlockHook to '{name}'")
|
||||
_apply_fbc_block_hook(block, state_manager)
|
||||
|
||||
logger.debug(f"Applying FBCBlockHook to tail block '{tail_block_name}'")
|
||||
_apply_fbc_block_hook(tail_block, state_manager, is_tail=True)
|
||||
|
||||
|
||||
def _apply_fbc_head_block_hook(block: torch.nn.Module, state_manager: StateManager, threshold: float) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCHeadBlockHook(state_manager, threshold)
|
||||
registry.register_hook(hook, _FBC_LEADER_BLOCK_HOOK)
|
||||
|
||||
|
||||
def _apply_fbc_block_hook(block: torch.nn.Module, state_manager: StateManager, is_tail: bool = False) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = FBCBlockHook(state_manager, is_tail)
|
||||
registry.register_hook(hook, _FBC_BLOCK_HOOK)
|
||||
@@ -14,6 +14,8 @@
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Set, Tuple, Union
|
||||
|
||||
import safetensors.torch
|
||||
@@ -46,6 +48,24 @@ _SUPPORTED_PYTORCH_LAYERS = (
|
||||
# fmt: on
|
||||
|
||||
|
||||
class GroupOffloadingType(str, Enum):
|
||||
BLOCK_LEVEL = "block_level"
|
||||
LEAF_LEVEL = "leaf_level"
|
||||
|
||||
|
||||
@dataclass
|
||||
class GroupOffloadingConfig:
|
||||
onload_device: torch.device
|
||||
offload_device: torch.device
|
||||
offload_type: GroupOffloadingType
|
||||
non_blocking: bool
|
||||
record_stream: bool
|
||||
low_cpu_mem_usage: bool
|
||||
num_blocks_per_group: Optional[int] = None
|
||||
offload_to_disk_path: Optional[str] = None
|
||||
stream: Optional[Union[torch.cuda.Stream, torch.Stream]] = None
|
||||
|
||||
|
||||
class ModuleGroup:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -288,9 +308,12 @@ class GroupOffloadingHook(ModelHook):
|
||||
|
||||
_is_stateful = False
|
||||
|
||||
def __init__(self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None) -> None:
|
||||
def __init__(
|
||||
self, group: ModuleGroup, next_group: Optional[ModuleGroup] = None, *, config: GroupOffloadingConfig
|
||||
) -> None:
|
||||
self.group = group
|
||||
self.next_group = next_group
|
||||
self.config = config
|
||||
|
||||
def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
|
||||
if self.group.offload_leader == module:
|
||||
@@ -436,7 +459,7 @@ def apply_group_offloading(
|
||||
module: torch.nn.Module,
|
||||
onload_device: torch.device,
|
||||
offload_device: torch.device = torch.device("cpu"),
|
||||
offload_type: str = "block_level",
|
||||
offload_type: Union[str, GroupOffloadingType] = "block_level",
|
||||
num_blocks_per_group: Optional[int] = None,
|
||||
non_blocking: bool = False,
|
||||
use_stream: bool = False,
|
||||
@@ -478,7 +501,7 @@ def apply_group_offloading(
|
||||
The device to which the group of modules are onloaded.
|
||||
offload_device (`torch.device`, defaults to `torch.device("cpu")`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU. Default is CPU.
|
||||
offload_type (`str`, defaults to "block_level"):
|
||||
offload_type (`str` or `GroupOffloadingType`, defaults to "block_level"):
|
||||
The type of offloading to be applied. Can be one of "block_level" or "leaf_level". Default is
|
||||
"block_level".
|
||||
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
||||
@@ -521,6 +544,8 @@ def apply_group_offloading(
|
||||
```
|
||||
"""
|
||||
|
||||
offload_type = GroupOffloadingType(offload_type)
|
||||
|
||||
stream = None
|
||||
if use_stream:
|
||||
if torch.cuda.is_available():
|
||||
@@ -532,84 +557,45 @@ def apply_group_offloading(
|
||||
|
||||
if not use_stream and record_stream:
|
||||
raise ValueError("`record_stream` cannot be True when `use_stream=False`.")
|
||||
if offload_type == GroupOffloadingType.BLOCK_LEVEL and num_blocks_per_group is None:
|
||||
raise ValueError("`num_blocks_per_group` must be provided when using `offload_type='block_level'.")
|
||||
|
||||
_raise_error_if_accelerate_model_or_sequential_hook_present(module)
|
||||
|
||||
if offload_type == "block_level":
|
||||
if num_blocks_per_group is None:
|
||||
raise ValueError("num_blocks_per_group must be provided when using offload_type='block_level'.")
|
||||
config = GroupOffloadingConfig(
|
||||
onload_device=onload_device,
|
||||
offload_device=offload_device,
|
||||
offload_type=offload_type,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
)
|
||||
_apply_group_offloading(module, config)
|
||||
|
||||
_apply_group_offloading_block_level(
|
||||
module=module,
|
||||
num_blocks_per_group=num_blocks_per_group,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
elif offload_type == "leaf_level":
|
||||
_apply_group_offloading_leaf_level(
|
||||
module=module,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
)
|
||||
|
||||
def _apply_group_offloading(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
||||
if config.offload_type == GroupOffloadingType.BLOCK_LEVEL:
|
||||
_apply_group_offloading_block_level(module, config)
|
||||
elif config.offload_type == GroupOffloadingType.LEAF_LEVEL:
|
||||
_apply_group_offloading_leaf_level(module, config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported offload_type: {offload_type}")
|
||||
assert False
|
||||
|
||||
|
||||
def _apply_group_offloading_block_level(
|
||||
module: torch.nn.Module,
|
||||
num_blocks_per_group: int,
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
) -> None:
|
||||
def _apply_group_offloading_block_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of torch.nn.ModuleList or torch.nn.Sequential blocks. In comparison to
|
||||
the "leaf_level" offloading, which is more fine-grained, this offloading is done at the top-level blocks.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to which group offloading is applied.
|
||||
offload_device (`torch.device`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU.
|
||||
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
||||
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
|
||||
RAM environment settings where a reasonable speed-memory trade-off is desired.
|
||||
onload_device (`torch.device`):
|
||||
The device to which the group of modules are onloaded.
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream`or `torch.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
||||
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
||||
details.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
if stream is not None and num_blocks_per_group != 1:
|
||||
|
||||
if config.stream is not None and config.num_blocks_per_group != 1:
|
||||
logger.warning(
|
||||
f"Using streams is only supported for num_blocks_per_group=1. Got {num_blocks_per_group=}. Setting it to 1."
|
||||
f"Using streams is only supported for num_blocks_per_group=1. Got {config.num_blocks_per_group=}. Setting it to 1."
|
||||
)
|
||||
num_blocks_per_group = 1
|
||||
config.num_blocks_per_group = 1
|
||||
|
||||
# Create module groups for ModuleList and Sequential blocks
|
||||
modules_with_group_offloading = set()
|
||||
@@ -621,19 +607,19 @@ def _apply_group_offloading_block_level(
|
||||
modules_with_group_offloading.add(name)
|
||||
continue
|
||||
|
||||
for i in range(0, len(submodule), num_blocks_per_group):
|
||||
current_modules = submodule[i : i + num_blocks_per_group]
|
||||
for i in range(0, len(submodule), config.num_blocks_per_group):
|
||||
current_modules = submodule[i : i + config.num_blocks_per_group]
|
||||
group = ModuleGroup(
|
||||
modules=current_modules,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=current_modules[-1],
|
||||
onload_leader=current_modules[0],
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
non_blocking=config.non_blocking,
|
||||
stream=config.stream,
|
||||
record_stream=config.record_stream,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
matched_module_groups.append(group)
|
||||
@@ -643,7 +629,7 @@ def _apply_group_offloading_block_level(
|
||||
# Apply group offloading hooks to the module groups
|
||||
for i, group in enumerate(matched_module_groups):
|
||||
for group_module in group.modules:
|
||||
_apply_group_offloading_hook(group_module, group, None)
|
||||
_apply_group_offloading_hook(group_module, group, None, config=config)
|
||||
|
||||
# Parameters and Buffers of the top-level module need to be offloaded/onloaded separately
|
||||
# when the forward pass of this module is called. This is because the top-level module is not
|
||||
@@ -658,9 +644,9 @@ def _apply_group_offloading_block_level(
|
||||
unmatched_modules = [unmatched_module for _, unmatched_module in unmatched_modules]
|
||||
unmatched_group = ModuleGroup(
|
||||
modules=unmatched_modules,
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=module,
|
||||
onload_leader=module,
|
||||
parameters=parameters,
|
||||
@@ -670,54 +656,19 @@ def _apply_group_offloading_block_level(
|
||||
record_stream=False,
|
||||
onload_self=True,
|
||||
)
|
||||
if stream is None:
|
||||
_apply_group_offloading_hook(module, unmatched_group, None)
|
||||
if config.stream is None:
|
||||
_apply_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
else:
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
|
||||
|
||||
def _apply_group_offloading_leaf_level(
|
||||
module: torch.nn.Module,
|
||||
offload_device: torch.device,
|
||||
onload_device: torch.device,
|
||||
non_blocking: bool,
|
||||
stream: Union[torch.cuda.Stream, torch.Stream, None] = None,
|
||||
record_stream: Optional[bool] = False,
|
||||
low_cpu_mem_usage: bool = False,
|
||||
offload_to_disk_path: Optional[str] = None,
|
||||
) -> None:
|
||||
def _apply_group_offloading_leaf_level(module: torch.nn.Module, config: GroupOffloadingConfig) -> None:
|
||||
r"""
|
||||
This function applies offloading to groups of leaf modules in a torch.nn.Module. This method has minimal memory
|
||||
requirements. However, it can be slower compared to other offloading methods due to the excessive number of device
|
||||
synchronizations. When using devices that support streams to overlap data transfer and computation, this method can
|
||||
reduce memory usage without any performance degradation.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module to which group offloading is applied.
|
||||
offload_device (`torch.device`):
|
||||
The device to which the group of modules are offloaded. This should typically be the CPU.
|
||||
onload_device (`torch.device`):
|
||||
The device to which the group of modules are onloaded.
|
||||
offload_to_disk_path (`str`, *optional*, defaults to `None`):
|
||||
The path to the directory where parameters will be offloaded. Setting this option can be useful in limited
|
||||
RAM environment settings where a reasonable speed-memory trade-off is desired.
|
||||
non_blocking (`bool`):
|
||||
If True, offloading and onloading is done asynchronously. This can be useful for overlapping computation
|
||||
and data transfer.
|
||||
stream (`torch.cuda.Stream` or `torch.Stream`, *optional*):
|
||||
If provided, offloading and onloading is done asynchronously using the provided stream. This can be useful
|
||||
for overlapping computation and data transfer.
|
||||
record_stream (`bool`, defaults to `False`): When enabled with `use_stream`, it marks the current tensor
|
||||
as having been used by this stream. It is faster at the expense of slightly more memory usage. Refer to the
|
||||
[PyTorch official docs](https://pytorch.org/docs/stable/generated/torch.Tensor.record_stream.html) more
|
||||
details.
|
||||
low_cpu_mem_usage (`bool`, defaults to `False`):
|
||||
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
|
||||
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
|
||||
the CPU memory is a bottleneck but may counteract the benefits of using streams.
|
||||
"""
|
||||
|
||||
# Create module groups for leaf modules and apply group offloading hooks
|
||||
modules_with_group_offloading = set()
|
||||
for name, submodule in module.named_modules():
|
||||
@@ -725,18 +676,18 @@ def _apply_group_offloading_leaf_level(
|
||||
continue
|
||||
group = ModuleGroup(
|
||||
modules=[submodule],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=submodule,
|
||||
onload_leader=submodule,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
non_blocking=config.non_blocking,
|
||||
stream=config.stream,
|
||||
record_stream=config.record_stream,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(submodule, group, None)
|
||||
_apply_group_offloading_hook(submodule, group, None, config=config)
|
||||
modules_with_group_offloading.add(name)
|
||||
|
||||
# Parameters and Buffers at all non-leaf levels need to be offloaded/onloaded separately when the forward pass
|
||||
@@ -767,33 +718,32 @@ def _apply_group_offloading_leaf_level(
|
||||
parameters = parent_to_parameters.get(name, [])
|
||||
buffers = parent_to_buffers.get(name, [])
|
||||
parent_module = module_dict[name]
|
||||
assert getattr(parent_module, "_diffusers_hook", None) is None
|
||||
group = ModuleGroup(
|
||||
modules=[],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_leader=parent_module,
|
||||
onload_leader=parent_module,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
parameters=parameters,
|
||||
buffers=buffers,
|
||||
non_blocking=non_blocking,
|
||||
stream=stream,
|
||||
record_stream=record_stream,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
non_blocking=config.non_blocking,
|
||||
stream=config.stream,
|
||||
record_stream=config.record_stream,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_group_offloading_hook(parent_module, group, None)
|
||||
_apply_group_offloading_hook(parent_module, group, None, config=config)
|
||||
|
||||
if stream is not None:
|
||||
if config.stream is not None:
|
||||
# When using streams, we need to know the layer execution order for applying prefetching (to overlap data transfer
|
||||
# and computation). Since we don't know the order beforehand, we apply a lazy prefetching hook that will find the
|
||||
# execution order and apply prefetching in the correct order.
|
||||
unmatched_group = ModuleGroup(
|
||||
modules=[],
|
||||
offload_device=offload_device,
|
||||
onload_device=onload_device,
|
||||
offload_to_disk_path=offload_to_disk_path,
|
||||
offload_device=config.offload_device,
|
||||
onload_device=config.onload_device,
|
||||
offload_to_disk_path=config.offload_to_disk_path,
|
||||
offload_leader=module,
|
||||
onload_leader=module,
|
||||
parameters=None,
|
||||
@@ -801,23 +751,25 @@ def _apply_group_offloading_leaf_level(
|
||||
non_blocking=False,
|
||||
stream=None,
|
||||
record_stream=False,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
low_cpu_mem_usage=config.low_cpu_mem_usage,
|
||||
onload_self=True,
|
||||
)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None)
|
||||
_apply_lazy_group_offloading_hook(module, unmatched_group, None, config=config)
|
||||
|
||||
|
||||
def _apply_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
*,
|
||||
config: GroupOffloadingConfig,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group)
|
||||
hook = GroupOffloadingHook(group, next_group, config=config)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
|
||||
@@ -825,13 +777,15 @@ def _apply_lazy_group_offloading_hook(
|
||||
module: torch.nn.Module,
|
||||
group: ModuleGroup,
|
||||
next_group: Optional[ModuleGroup] = None,
|
||||
*,
|
||||
config: GroupOffloadingConfig,
|
||||
) -> None:
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
|
||||
# We may have already registered a group offloading hook if the module had a torch.nn.Parameter whose parent
|
||||
# is the current module. In such cases, we don't want to overwrite the existing group offloading hook.
|
||||
if registry.get_hook(_GROUP_OFFLOADING) is None:
|
||||
hook = GroupOffloadingHook(group, next_group)
|
||||
hook = GroupOffloadingHook(group, next_group, config=config)
|
||||
registry.register_hook(hook, _GROUP_OFFLOADING)
|
||||
|
||||
lazy_prefetch_hook = LazyPrefetchGroupOffloadingHook()
|
||||
@@ -898,15 +852,48 @@ def _raise_error_if_accelerate_model_or_sequential_hook_present(module: torch.nn
|
||||
)
|
||||
|
||||
|
||||
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
|
||||
def _get_top_level_group_offload_hook(module: torch.nn.Module) -> Optional[GroupOffloadingHook]:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
|
||||
return True
|
||||
return False
|
||||
if hasattr(submodule, "_diffusers_hook"):
|
||||
group_offloading_hook = submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING)
|
||||
if group_offloading_hook is not None:
|
||||
return group_offloading_hook
|
||||
return None
|
||||
|
||||
|
||||
def _is_group_offload_enabled(module: torch.nn.Module) -> bool:
|
||||
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
||||
return top_level_group_offload_hook is not None
|
||||
|
||||
|
||||
def _get_group_onload_device(module: torch.nn.Module) -> torch.device:
|
||||
for submodule in module.modules():
|
||||
if hasattr(submodule, "_diffusers_hook") and submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING) is not None:
|
||||
return submodule._diffusers_hook.get_hook(_GROUP_OFFLOADING).group.onload_device
|
||||
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
||||
if top_level_group_offload_hook is not None:
|
||||
return top_level_group_offload_hook.config.onload_device
|
||||
raise ValueError("Group offloading is not enabled for the provided module.")
|
||||
|
||||
|
||||
def _maybe_remove_and_reapply_group_offloading(module: torch.nn.Module) -> None:
|
||||
r"""
|
||||
Removes the group offloading hook from the module and re-applies it. This is useful when the module has been
|
||||
modified in-place and the group offloading hook references-to-tensors needs to be updated. The in-place
|
||||
modification can happen in a number of ways, for example, fusing QKV or unloading/loading LoRAs on-the-fly.
|
||||
|
||||
In this implementation, we make an assumption that group offloading has only been applied at the top-level module,
|
||||
and therefore all submodules have the same onload and offload devices. If this assumption is not true, say in the
|
||||
case where user has applied group offloading at multiple levels, this function will not work as expected.
|
||||
|
||||
There is some performance penalty associated with doing this when non-default streams are used, because we need to
|
||||
retrace the execution order of the layers with `LazyPrefetchGroupOffloadingHook`.
|
||||
"""
|
||||
top_level_group_offload_hook = _get_top_level_group_offload_hook(module)
|
||||
|
||||
if top_level_group_offload_hook is None:
|
||||
return
|
||||
|
||||
registry = HookRegistry.check_if_exists_or_initialize(module)
|
||||
registry.remove_hook(_GROUP_OFFLOADING, recurse=True)
|
||||
registry.remove_hook(_LAYER_EXECUTION_TRACKER, recurse=True)
|
||||
registry.remove_hook(_LAZY_PREFETCH_GROUP_OFFLOADING, recurse=True)
|
||||
|
||||
_apply_group_offloading(module, top_level_group_offload_hook.config)
|
||||
|
||||
@@ -18,11 +18,44 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class BaseState:
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
raise NotImplementedError(
|
||||
"BaseState::reset is not implemented. Please implement this method in the derived class."
|
||||
)
|
||||
|
||||
|
||||
class StateManager:
|
||||
def __init__(self, state_cls: BaseState, init_args=None, init_kwargs=None):
|
||||
self._state_cls = state_cls
|
||||
self._init_args = init_args if init_args is not None else ()
|
||||
self._init_kwargs = init_kwargs if init_kwargs is not None else {}
|
||||
self._state_cache = {}
|
||||
self._current_context = None
|
||||
|
||||
def get_state(self):
|
||||
if self._current_context is None:
|
||||
raise ValueError("No context is set. Please set a context before retrieving the state.")
|
||||
if self._current_context not in self._state_cache.keys():
|
||||
self._state_cache[self._current_context] = self._state_cls(*self._init_args, **self._init_kwargs)
|
||||
return self._state_cache[self._current_context]
|
||||
|
||||
def set_context(self, name: str) -> None:
|
||||
self._current_context = name
|
||||
|
||||
def reset(self, *args, **kwargs) -> None:
|
||||
for name, state in list(self._state_cache.items()):
|
||||
state.reset(*args, **kwargs)
|
||||
self._state_cache.pop(name)
|
||||
self._current_context = None
|
||||
|
||||
|
||||
class ModelHook:
|
||||
r"""
|
||||
A hook that contains callbacks to be executed just before and after the forward method of a model.
|
||||
@@ -99,6 +132,14 @@ class ModelHook:
|
||||
raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
|
||||
return module
|
||||
|
||||
def _set_context(self, module: torch.nn.Module, name: str) -> None:
|
||||
# Iterate over all attributes of the hook to see if any of them have the type `StateManager`. If so, call `set_context` on them.
|
||||
for attr_name in dir(self):
|
||||
attr = getattr(self, attr_name)
|
||||
if isinstance(attr, StateManager):
|
||||
attr.set_context(name)
|
||||
return module
|
||||
|
||||
|
||||
class HookFunctionReference:
|
||||
def __init__(self) -> None:
|
||||
@@ -211,9 +252,10 @@ class HookRegistry:
|
||||
hook.reset_state(self._module_ref)
|
||||
|
||||
if recurse:
|
||||
for module_name, module in self._module_ref.named_modules():
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook.reset_stateful_hooks(recurse=False)
|
||||
|
||||
@@ -223,6 +265,19 @@ class HookRegistry:
|
||||
module._diffusers_hook = cls(module)
|
||||
return module._diffusers_hook
|
||||
|
||||
def _set_context(self, name: Optional[str] = None) -> None:
|
||||
for hook_name in reversed(self._hook_order):
|
||||
hook = self.hooks[hook_name]
|
||||
if hook._is_stateful:
|
||||
hook._set_context(self._module_ref, name)
|
||||
|
||||
for module_name, module in unwrap_module(self._module_ref).named_modules():
|
||||
if module_name == "":
|
||||
continue
|
||||
module = unwrap_module(module)
|
||||
if hasattr(module, "_diffusers_hook"):
|
||||
module._diffusers_hook._set_context(name)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
registry_repr = ""
|
||||
for i, hook_name in enumerate(self._hook_order):
|
||||
|
||||
254
src/diffusers/hooks/layer_skip.py
Normal file
254
src/diffusers/hooks/layer_skip.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..utils import get_logger
|
||||
from ..utils.torch_utils import unwrap_module
|
||||
from ._common import (
|
||||
_ALL_TRANSFORMER_BLOCK_IDENTIFIERS,
|
||||
_ATTENTION_CLASSES,
|
||||
_FEEDFORWARD_CLASSES,
|
||||
_get_submodule_from_fqn,
|
||||
)
|
||||
from ._helpers import AttentionProcessorRegistry, TransformerBlockRegistry
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_LAYER_SKIP_HOOK = "layer_skip_hook"
|
||||
|
||||
|
||||
# Aryan/YiYi TODO: we need to make guider class a config mixin so I think this is not needed
|
||||
# either remove or make it serializable
|
||||
@dataclass
|
||||
class LayerSkipConfig:
|
||||
r"""
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
|
||||
provide the correct fqn.
|
||||
skip_attention (`bool`, defaults to `True`):
|
||||
Whether to skip attention blocks.
|
||||
skip_ff (`bool`, defaults to `True`):
|
||||
Whether to skip feed-forward blocks.
|
||||
skip_attention_scores (`bool`, defaults to `False`):
|
||||
Whether to skip attention score computation in the attention blocks. This is equivalent to using `value`
|
||||
projections as the output of scaled dot product attention.
|
||||
dropout (`float`, defaults to `1.0`):
|
||||
The dropout probability for dropping the outputs of the skipped layers. By default, this is set to `1.0`,
|
||||
meaning that the outputs of the skipped layers are completely ignored. If set to `0.0`, the outputs of the
|
||||
skipped layers are fully retained, which is equivalent to not skipping any layers.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
fqn: str = "auto"
|
||||
skip_attention: bool = True
|
||||
skip_attention_scores: bool = False
|
||||
skip_ff: bool = True
|
||||
dropout: float = 1.0
|
||||
|
||||
def __post_init__(self):
|
||||
if not (0 <= self.dropout <= 1):
|
||||
raise ValueError(f"Expected `dropout` to be between 0.0 and 1.0, but got {self.dropout}.")
|
||||
if not math.isclose(self.dropout, 1.0) and self.skip_attention_scores:
|
||||
raise ValueError(
|
||||
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: dict) -> "LayerSkipConfig":
|
||||
return LayerSkipConfig(**data)
|
||||
|
||||
|
||||
class AttentionScoreSkipFunctionMode(torch.overrides.TorchFunctionMode):
|
||||
def __torch_function__(self, func, types, args=(), kwargs=None):
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if func is torch.nn.functional.scaled_dot_product_attention:
|
||||
value = kwargs.get("value", None)
|
||||
if value is None:
|
||||
value = args[2]
|
||||
return value
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
class AttentionProcessorSkipHook(ModelHook):
|
||||
def __init__(self, skip_processor_output_fn: Callable, skip_attention_scores: bool = False, dropout: float = 1.0):
|
||||
self.skip_processor_output_fn = skip_processor_output_fn
|
||||
self.skip_attention_scores = skip_attention_scores
|
||||
self.dropout = dropout
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if self.skip_attention_scores:
|
||||
if not math.isclose(self.dropout, 1.0):
|
||||
raise ValueError(
|
||||
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
|
||||
)
|
||||
with AttentionScoreSkipFunctionMode():
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
else:
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
output = self.skip_processor_output_fn(module, *args, **kwargs)
|
||||
else:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
output = torch.nn.functional.dropout(output, p=self.dropout)
|
||||
return output
|
||||
|
||||
|
||||
class FeedForwardSkipHook(ModelHook):
|
||||
def __init__(self, dropout: float):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
output = kwargs.get("hidden_states", None)
|
||||
if output is None:
|
||||
output = kwargs.get("x", None)
|
||||
if output is None and len(args) > 0:
|
||||
output = args[0]
|
||||
else:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
output = torch.nn.functional.dropout(output, p=self.dropout)
|
||||
return output
|
||||
|
||||
|
||||
class TransformerBlockSkipHook(ModelHook):
|
||||
def __init__(self, dropout: float):
|
||||
super().__init__()
|
||||
self.dropout = dropout
|
||||
|
||||
def initialize_hook(self, module):
|
||||
self._metadata = TransformerBlockRegistry.get(unwrap_module(module).__class__)
|
||||
return module
|
||||
|
||||
def new_forward(self, module: torch.nn.Module, *args, **kwargs):
|
||||
if math.isclose(self.dropout, 1.0):
|
||||
original_hidden_states = self._metadata._get_parameter_from_args_kwargs("hidden_states", args, kwargs)
|
||||
if self._metadata.return_encoder_hidden_states_index is None:
|
||||
output = original_hidden_states
|
||||
else:
|
||||
original_encoder_hidden_states = self._metadata._get_parameter_from_args_kwargs(
|
||||
"encoder_hidden_states", args, kwargs
|
||||
)
|
||||
output = (original_hidden_states, original_encoder_hidden_states)
|
||||
else:
|
||||
output = self.fn_ref.original_forward(*args, **kwargs)
|
||||
output = torch.nn.functional.dropout(output, p=self.dropout)
|
||||
return output
|
||||
|
||||
|
||||
def apply_layer_skip(module: torch.nn.Module, config: LayerSkipConfig) -> None:
|
||||
r"""
|
||||
Apply layer skipping to internal layers of a transformer.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The transformer model to which the layer skip hook should be applied.
|
||||
config (`LayerSkipConfig`):
|
||||
The configuration for the layer skip hook.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from diffusers import apply_layer_skip_hook, CogVideoXTransformer3DModel, LayerSkipConfig
|
||||
|
||||
>>> transformer = CogVideoXTransformer3DModel.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
|
||||
>>> config = LayerSkipConfig(layer_index=[10, 20], fqn="transformer_blocks")
|
||||
>>> apply_layer_skip_hook(transformer, config)
|
||||
```
|
||||
"""
|
||||
_apply_layer_skip_hook(module, config)
|
||||
|
||||
|
||||
def _apply_layer_skip_hook(module: torch.nn.Module, config: LayerSkipConfig, name: Optional[str] = None) -> None:
|
||||
name = name or _LAYER_SKIP_HOOK
|
||||
|
||||
if config.skip_attention and config.skip_attention_scores:
|
||||
raise ValueError("Cannot set both `skip_attention` and `skip_attention_scores` to True. Please choose one.")
|
||||
if not math.isclose(config.dropout, 1.0) and config.skip_attention_scores:
|
||||
raise ValueError(
|
||||
"Cannot set `skip_attention_scores` to True when `dropout` is not 1.0. Please set `dropout` to 1.0."
|
||||
)
|
||||
|
||||
if config.fqn == "auto":
|
||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||
if hasattr(module, identifier):
|
||||
config.fqn = identifier
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||
)
|
||||
|
||||
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
|
||||
if transformer_blocks is None or not isinstance(transformer_blocks, torch.nn.ModuleList):
|
||||
raise ValueError(
|
||||
f"Could not find {config.fqn} in the provided module, or configured `fqn` (fully qualified name) does not identify "
|
||||
f"a `torch.nn.ModuleList`. Please provide a valid `fqn` that identifies a stack of transformer blocks."
|
||||
)
|
||||
if len(config.indices) == 0:
|
||||
raise ValueError("Layer index list is empty. Please provide a non-empty list of layer indices to skip.")
|
||||
|
||||
blocks_found = False
|
||||
for i, block in enumerate(transformer_blocks):
|
||||
if i not in config.indices:
|
||||
continue
|
||||
|
||||
blocks_found = True
|
||||
|
||||
if config.skip_attention and config.skip_ff:
|
||||
logger.debug(f"Applying TransformerBlockSkipHook to '{config.fqn}.{i}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(block)
|
||||
hook = TransformerBlockSkipHook(config.dropout)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
elif config.skip_attention or config.skip_attention_scores:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _ATTENTION_CLASSES) and not submodule.is_cross_attention:
|
||||
logger.debug(f"Applying AttentionProcessorSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
output_fn = AttentionProcessorRegistry.get(submodule.processor.__class__).skip_processor_output_fn
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = AttentionProcessorSkipHook(output_fn, config.skip_attention_scores, config.dropout)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
if config.skip_ff:
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if isinstance(submodule, _FEEDFORWARD_CLASSES):
|
||||
logger.debug(f"Applying FeedForwardSkipHook to '{config.fqn}.{i}.{submodule_name}'")
|
||||
registry = HookRegistry.check_if_exists_or_initialize(submodule)
|
||||
hook = FeedForwardSkipHook(config.dropout)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
if not blocks_found:
|
||||
raise ValueError(
|
||||
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
||||
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
|
||||
)
|
||||
167
src/diffusers/hooks/smoothed_energy_guidance_utils.py
Normal file
167
src/diffusers/hooks/smoothed_energy_guidance_utils.py
Normal file
@@ -0,0 +1,167 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..utils import get_logger
|
||||
from ._common import _ALL_TRANSFORMER_BLOCK_IDENTIFIERS, _ATTENTION_CLASSES, _get_submodule_from_fqn
|
||||
from .hooks import HookRegistry, ModelHook
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_SMOOTHED_ENERGY_GUIDANCE_HOOK = "smoothed_energy_guidance_hook"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SmoothedEnergyGuidanceConfig:
|
||||
r"""
|
||||
Configuration for skipping internal transformer blocks when executing a transformer model.
|
||||
|
||||
Args:
|
||||
indices (`List[int]`):
|
||||
The indices of the layer to skip. This is typically the first layer in the transformer block.
|
||||
fqn (`str`, defaults to `"auto"`):
|
||||
The fully qualified name identifying the stack of transformer blocks. Typically, this is
|
||||
`transformer_blocks`, `single_transformer_blocks`, `blocks`, `layers`, or `temporal_transformer_blocks`.
|
||||
For automatic detection, set this to `"auto"`. "auto" only works on DiT models. For UNet models, you must
|
||||
provide the correct fqn.
|
||||
_query_proj_identifiers (`List[str]`, defaults to `None`):
|
||||
The identifiers for the query projection layers. Typically, these are `to_q`, `query`, or `q_proj`. If
|
||||
`None`, `to_q` is used by default.
|
||||
"""
|
||||
|
||||
indices: List[int]
|
||||
fqn: str = "auto"
|
||||
_query_proj_identifiers: List[str] = None
|
||||
|
||||
def to_dict(self):
|
||||
return asdict(self)
|
||||
|
||||
@staticmethod
|
||||
def from_dict(data: dict) -> "SmoothedEnergyGuidanceConfig":
|
||||
return SmoothedEnergyGuidanceConfig(**data)
|
||||
|
||||
|
||||
class SmoothedEnergyGuidanceHook(ModelHook):
|
||||
def __init__(self, blur_sigma: float = 1.0, blur_threshold_inf: float = 9999.9) -> None:
|
||||
super().__init__()
|
||||
self.blur_sigma = blur_sigma
|
||||
self.blur_threshold_inf = blur_threshold_inf
|
||||
|
||||
def post_forward(self, module: torch.nn.Module, output: torch.Tensor) -> torch.Tensor:
|
||||
# Copied from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L172C31-L172C102
|
||||
kernel_size = math.ceil(6 * self.blur_sigma) + 1 - math.ceil(6 * self.blur_sigma) % 2
|
||||
smoothed_output = _gaussian_blur_2d(output, kernel_size, self.blur_sigma, self.blur_threshold_inf)
|
||||
return smoothed_output
|
||||
|
||||
|
||||
def _apply_smoothed_energy_guidance_hook(
|
||||
module: torch.nn.Module, config: SmoothedEnergyGuidanceConfig, blur_sigma: float, name: Optional[str] = None
|
||||
) -> None:
|
||||
name = name or _SMOOTHED_ENERGY_GUIDANCE_HOOK
|
||||
|
||||
if config.fqn == "auto":
|
||||
for identifier in _ALL_TRANSFORMER_BLOCK_IDENTIFIERS:
|
||||
if hasattr(module, identifier):
|
||||
config.fqn = identifier
|
||||
break
|
||||
else:
|
||||
raise ValueError(
|
||||
"Could not find a suitable identifier for the transformer blocks automatically. Please provide a valid "
|
||||
"`fqn` (fully qualified name) that identifies a stack of transformer blocks."
|
||||
)
|
||||
|
||||
if config._query_proj_identifiers is None:
|
||||
config._query_proj_identifiers = ["to_q"]
|
||||
|
||||
transformer_blocks = _get_submodule_from_fqn(module, config.fqn)
|
||||
blocks_found = False
|
||||
for i, block in enumerate(transformer_blocks):
|
||||
if i not in config.indices:
|
||||
continue
|
||||
|
||||
blocks_found = True
|
||||
|
||||
for submodule_name, submodule in block.named_modules():
|
||||
if not isinstance(submodule, _ATTENTION_CLASSES) or submodule.is_cross_attention:
|
||||
continue
|
||||
for identifier in config._query_proj_identifiers:
|
||||
query_proj = getattr(submodule, identifier, None)
|
||||
if query_proj is None or not isinstance(query_proj, torch.nn.Linear):
|
||||
continue
|
||||
logger.debug(
|
||||
f"Registering smoothed energy guidance hook on {config.fqn}.{i}.{submodule_name}.{identifier}"
|
||||
)
|
||||
registry = HookRegistry.check_if_exists_or_initialize(query_proj)
|
||||
hook = SmoothedEnergyGuidanceHook(blur_sigma)
|
||||
registry.register_hook(hook, name)
|
||||
|
||||
if not blocks_found:
|
||||
raise ValueError(
|
||||
f"Could not find any transformer blocks matching the provided indices {config.indices} and "
|
||||
f"fully qualified name '{config.fqn}'. Please check the indices and fqn for correctness."
|
||||
)
|
||||
|
||||
|
||||
# Modified from https://github.com/SusungHong/SEG-SDXL/blob/cf8256d640d5373541cfea3b3b6caf93272cf986/pipeline_seg.py#L71
|
||||
def _gaussian_blur_2d(query: torch.Tensor, kernel_size: int, sigma: float, sigma_threshold_inf: float) -> torch.Tensor:
|
||||
"""
|
||||
This implementation assumes that the input query is for visual (image/videos) tokens to apply the 2D gaussian blur.
|
||||
However, some models use joint text-visual token attention for which this may not be suitable. Additionally, this
|
||||
implementation also assumes that the visual tokens come from a square image/video. In practice, despite these
|
||||
assumptions, applying the 2D square gaussian blur on the query projections generates reasonable results for
|
||||
Smoothed Energy Guidance.
|
||||
|
||||
SEG is only supported as an experimental prototype feature for now, so the implementation may be modified in the
|
||||
future without warning or guarantee of reproducibility.
|
||||
"""
|
||||
assert query.ndim == 3
|
||||
|
||||
is_inf = sigma > sigma_threshold_inf
|
||||
batch_size, seq_len, embed_dim = query.shape
|
||||
|
||||
seq_len_sqrt = int(math.sqrt(seq_len))
|
||||
num_square_tokens = seq_len_sqrt * seq_len_sqrt
|
||||
query_slice = query[:, :num_square_tokens, :]
|
||||
query_slice = query_slice.permute(0, 2, 1)
|
||||
query_slice = query_slice.reshape(batch_size, embed_dim, seq_len_sqrt, seq_len_sqrt)
|
||||
|
||||
if is_inf:
|
||||
kernel_size = min(kernel_size, seq_len_sqrt - (seq_len_sqrt % 2 - 1))
|
||||
kernel_size_half = (kernel_size - 1) / 2
|
||||
|
||||
x = torch.linspace(-kernel_size_half, kernel_size_half, steps=kernel_size)
|
||||
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
|
||||
kernel1d = pdf / pdf.sum()
|
||||
kernel1d = kernel1d.to(query)
|
||||
kernel2d = torch.matmul(kernel1d[:, None], kernel1d[None, :])
|
||||
kernel2d = kernel2d.expand(embed_dim, 1, kernel2d.shape[0], kernel2d.shape[1])
|
||||
|
||||
padding = [kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2]
|
||||
query_slice = F.pad(query_slice, padding, mode="reflect")
|
||||
query_slice = F.conv2d(query_slice, kernel2d, groups=embed_dim)
|
||||
else:
|
||||
query_slice[:] = query_slice.mean(dim=(-2, -1), keepdim=True)
|
||||
|
||||
query_slice = query_slice.reshape(batch_size, embed_dim, num_square_tokens)
|
||||
query_slice = query_slice.permute(0, 2, 1)
|
||||
query[:, :num_square_tokens, :] = query_slice.clone()
|
||||
|
||||
return query
|
||||
@@ -84,6 +84,7 @@ if is_torch_available():
|
||||
"IPAdapterMixin",
|
||||
"FluxIPAdapterMixin",
|
||||
"SD3IPAdapterMixin",
|
||||
"ModularIPAdapterMixin",
|
||||
]
|
||||
|
||||
_import_structure["peft"] = ["PeftAdapterMixin"]
|
||||
@@ -101,6 +102,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .ip_adapter import (
|
||||
FluxIPAdapterMixin,
|
||||
IPAdapterMixin,
|
||||
ModularIPAdapterMixin,
|
||||
SD3IPAdapterMixin,
|
||||
)
|
||||
from .lora_pipeline import (
|
||||
|
||||
@@ -354,6 +354,256 @@ class IPAdapterMixin:
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class ModularIPAdapterMixin:
|
||||
"""Mixin for handling IP Adapters."""
|
||||
|
||||
@validate_hf_hub_args
|
||||
def load_ip_adapter(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: Union[str, List[str], Dict[str, torch.Tensor]],
|
||||
subfolder: Union[str, List[str]],
|
||||
weight_name: Union[str, List[str]],
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str` or `List[str]` or `os.PathLike` or `List[os.PathLike]` or `dict` or `List[dict]`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
|
||||
the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
|
||||
with [`ModelMixin.save_pretrained`].
|
||||
- A [torch state
|
||||
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
|
||||
subfolder (`str` or `List[str]`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally. If a
|
||||
list is passed, it should have the same length as `weight_name`.
|
||||
weight_name (`str` or `List[str]`):
|
||||
The name of the weight file to load. If a list is passed, it should have the same length as
|
||||
`subfolder`.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
local_files_only (`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model weights and configuration files or not. If set to `True`, the model
|
||||
won't be downloaded from the Hub.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
|
||||
`diffusers-cli login` (stored in `~/.huggingface`) is used.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
|
||||
allowed by Git.
|
||||
low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
|
||||
Speed up model loading only loading the pretrained weights and not initializing the weights. This also
|
||||
tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
|
||||
Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
|
||||
argument to `True` will raise an error.
|
||||
"""
|
||||
|
||||
# handle the list inputs for multiple IP Adapters
|
||||
if not isinstance(weight_name, list):
|
||||
weight_name = [weight_name]
|
||||
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, list):
|
||||
pretrained_model_name_or_path_or_dict = [pretrained_model_name_or_path_or_dict]
|
||||
if len(pretrained_model_name_or_path_or_dict) == 1:
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict * len(weight_name)
|
||||
|
||||
if not isinstance(subfolder, list):
|
||||
subfolder = [subfolder]
|
||||
if len(subfolder) == 1:
|
||||
subfolder = subfolder * len(weight_name)
|
||||
|
||||
if len(weight_name) != len(pretrained_model_name_or_path_or_dict):
|
||||
raise ValueError("`weight_name` and `pretrained_model_name_or_path_or_dict` must have the same length.")
|
||||
|
||||
if len(weight_name) != len(subfolder):
|
||||
raise ValueError("`weight_name` and `subfolder` must have the same length.")
|
||||
|
||||
# Load the main state dict first.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
"Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
|
||||
" environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
|
||||
" `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
|
||||
" install accelerate\n```\n."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
||||
raise NotImplementedError(
|
||||
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
||||
" `low_cpu_mem_usage=False`."
|
||||
)
|
||||
|
||||
user_agent = {
|
||||
"file_type": "attn_procs_weights",
|
||||
"framework": "pytorch",
|
||||
}
|
||||
state_dicts = []
|
||||
for pretrained_model_name_or_path_or_dict, weight_name, subfolder in zip(
|
||||
pretrained_model_name_or_path_or_dict, weight_name, subfolder
|
||||
):
|
||||
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
model_file = _get_model_file(
|
||||
pretrained_model_name_or_path_or_dict,
|
||||
weights_name=weight_name,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
if weight_name.endswith(".safetensors"):
|
||||
state_dict = {"image_proj": {}, "ip_adapter": {}}
|
||||
with safe_open(model_file, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
if key.startswith("image_proj."):
|
||||
state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
|
||||
elif key.startswith("ip_adapter."):
|
||||
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
|
||||
else:
|
||||
state_dict = load_state_dict(model_file)
|
||||
else:
|
||||
state_dict = pretrained_model_name_or_path_or_dict
|
||||
|
||||
keys = list(state_dict.keys())
|
||||
if "image_proj" not in keys and "ip_adapter" not in keys:
|
||||
raise ValueError("Required keys are (`image_proj` and `ip_adapter`) missing from the state dict.")
|
||||
|
||||
state_dicts.append(state_dict)
|
||||
|
||||
unet_name = getattr(self, "unet_name", "unet")
|
||||
unet = getattr(self, unet_name)
|
||||
unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage)
|
||||
|
||||
extra_loras = unet._load_ip_adapter_loras(state_dicts)
|
||||
if extra_loras != {}:
|
||||
if not USE_PEFT_BACKEND:
|
||||
logger.warning("PEFT backend is required to load these weights.")
|
||||
else:
|
||||
# apply the IP Adapter Face ID LoRA weights
|
||||
peft_config = getattr(unet, "peft_config", {})
|
||||
for k, lora in extra_loras.items():
|
||||
if f"faceid_{k}" not in peft_config:
|
||||
self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
|
||||
self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
|
||||
|
||||
def set_ip_adapter_scale(self, scale):
|
||||
"""
|
||||
Set IP-Adapter scales per-transformer block. Input `scale` could be a single config or a list of configs for
|
||||
granular control over each IP-Adapter behavior. A config can be a float or a dictionary.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
# To use original IP-Adapter
|
||||
scale = 1.0
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style block only
|
||||
scale = {
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style+layout blocks
|
||||
scale = {
|
||||
"down": {"block_2": [0.0, 1.0]},
|
||||
"up": {"block_0": [0.0, 1.0, 0.0]},
|
||||
}
|
||||
pipeline.set_ip_adapter_scale(scale)
|
||||
|
||||
# To use style and layout from 2 reference images
|
||||
scales = [{"down": {"block_2": [0.0, 1.0]}}, {"up": {"block_0": [0.0, 1.0, 0.0]}}]
|
||||
pipeline.set_ip_adapter_scale(scales)
|
||||
```
|
||||
"""
|
||||
unet_name = getattr(self, "unet_name", "unet")
|
||||
unet = getattr(self, unet_name)
|
||||
if not isinstance(scale, list):
|
||||
scale = [scale]
|
||||
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)
|
||||
|
||||
for attn_name, attn_processor in unet.attn_processors.items():
|
||||
if isinstance(
|
||||
attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
):
|
||||
if len(scale_configs) != len(attn_processor.scale):
|
||||
raise ValueError(
|
||||
f"Cannot assign {len(scale_configs)} scale_configs to {len(attn_processor.scale)} IP-Adapter."
|
||||
)
|
||||
elif len(scale_configs) == 1:
|
||||
scale_configs = scale_configs * len(attn_processor.scale)
|
||||
for i, scale_config in enumerate(scale_configs):
|
||||
if isinstance(scale_config, dict):
|
||||
for k, s in scale_config.items():
|
||||
if attn_name.startswith(k):
|
||||
attn_processor.scale[i] = s
|
||||
else:
|
||||
attn_processor.scale[i] = scale_config
|
||||
|
||||
def unload_ip_adapter(self):
|
||||
"""
|
||||
Unloads the IP Adapter weights
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
>>> # Assuming `pipeline` is already loaded with the IP Adapter weights.
|
||||
>>> pipeline.unload_ip_adapter()
|
||||
>>> ...
|
||||
```
|
||||
"""
|
||||
|
||||
# remove hidden encoder
|
||||
if self.unet is None:
|
||||
return
|
||||
|
||||
self.unet.encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = None
|
||||
|
||||
# Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
|
||||
if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
|
||||
self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
|
||||
self.unet.text_encoder_hid_proj = None
|
||||
self.unet.config.encoder_hid_dim_type = "text_proj"
|
||||
|
||||
# restore original Unet attention processors layers
|
||||
attn_procs = {}
|
||||
for name, value in self.unet.attn_processors.items():
|
||||
attn_processor_class = (
|
||||
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnProcessor()
|
||||
)
|
||||
attn_procs[name] = (
|
||||
attn_processor_class
|
||||
if isinstance(
|
||||
value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)
|
||||
)
|
||||
else value.__class__()
|
||||
)
|
||||
self.unet.set_attn_processor(attn_procs)
|
||||
|
||||
|
||||
class FluxIPAdapterMixin:
|
||||
"""Mixin for handling Flux IP Adapters."""
|
||||
|
||||
|
||||
@@ -330,6 +330,8 @@ def _load_lora_into_text_encoder(
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
||||
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
@@ -391,7 +393,9 @@ def _load_lora_into_text_encoder(
|
||||
adapter_name = get_adapter_name(text_encoder)
|
||||
|
||||
# <Unsafe code
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = _func_optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = _func_optionally_disable_offloading(
|
||||
_pipeline
|
||||
)
|
||||
# inject LoRA layers and load the state dict
|
||||
# in transformers we automatically check whether the adapter name is already in use or not
|
||||
text_encoder.load_adapter(
|
||||
@@ -410,6 +414,10 @@ def _load_lora_into_text_encoder(
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
elif is_group_offload:
|
||||
for component in _pipeline.components.values():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
_maybe_remove_and_reapply_group_offloading(component)
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
@@ -433,30 +441,38 @@ def _func_optionally_disable_offloading(_pipeline):
|
||||
|
||||
Returns:
|
||||
tuple:
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
|
||||
A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` or `is_group_offload` is True.
|
||||
"""
|
||||
from ..hooks.group_offloading import _is_group_offload_enabled
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
|
||||
if _pipeline is not None and _pipeline.hf_device_map is None:
|
||||
for _, component in _pipeline.components.items():
|
||||
if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
|
||||
if not is_model_cpu_offload:
|
||||
is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
|
||||
if not is_sequential_cpu_offload:
|
||||
is_sequential_cpu_offload = (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
if not isinstance(component, nn.Module):
|
||||
continue
|
||||
is_group_offload = is_group_offload or _is_group_offload_enabled(component)
|
||||
if not hasattr(component, "_hf_hook"):
|
||||
continue
|
||||
is_model_cpu_offload = is_model_cpu_offload or isinstance(component._hf_hook, CpuOffload)
|
||||
is_sequential_cpu_offload = is_sequential_cpu_offload or (
|
||||
isinstance(component._hf_hook, AlignDevicesHook)
|
||||
or hasattr(component._hf_hook, "hooks")
|
||||
and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
if is_sequential_cpu_offload or is_model_cpu_offload:
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
if is_sequential_cpu_offload or is_model_cpu_offload:
|
||||
logger.info(
|
||||
"Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
|
||||
)
|
||||
for _, component in _pipeline.components.items():
|
||||
if not isinstance(component, nn.Module) or not hasattr(component, "_hf_hook"):
|
||||
continue
|
||||
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
||||
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload)
|
||||
return (is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload)
|
||||
|
||||
|
||||
class LoraBaseMixin:
|
||||
@@ -921,6 +937,27 @@ class LoraBaseMixin:
|
||||
Moves the LoRAs listed in `adapter_names` to a target device. Useful for offloading the LoRA to the CPU in case
|
||||
you want to load multiple adapters and free some GPU memory.
|
||||
|
||||
After offloading the LoRA adapters to CPU, as long as the rest of the model is still on GPU, the LoRA adapters
|
||||
can no longer be used for inference, as that would cause a device mismatch. Remember to set the device back to
|
||||
GPU before using those LoRA adapters for inference.
|
||||
|
||||
```python
|
||||
>>> pipe.load_lora_weights(path_1, adapter_name="adapter-1")
|
||||
>>> pipe.load_lora_weights(path_2, adapter_name="adapter-2")
|
||||
>>> pipe.set_adapters("adapter-1")
|
||||
>>> image_1 = pipe(**kwargs)
|
||||
>>> # switch to adapter-2, offload adapter-1
|
||||
>>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cpu")
|
||||
>>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cuda:0")
|
||||
>>> pipe.set_adapters("adapter-2")
|
||||
>>> image_2 = pipe(**kwargs)
|
||||
>>> # switch back to adapter-1, offload adapter-2
|
||||
>>> pipeline.set_lora_device(adapter_names=["adapter-2"], device="cpu")
|
||||
>>> pipeline.set_lora_device(adapter_names=["adapter-1"], device="cuda:0")
|
||||
>>> pipe.set_adapters("adapter-1")
|
||||
>>> ...
|
||||
```
|
||||
|
||||
Args:
|
||||
adapter_names (`List[str]`):
|
||||
List of adapters to send device to.
|
||||
@@ -936,6 +973,10 @@ class LoraBaseMixin:
|
||||
for module in model.modules():
|
||||
if isinstance(module, BaseTunerLayer):
|
||||
for adapter_name in adapter_names:
|
||||
if adapter_name not in module.lora_A:
|
||||
# it is sufficient to check lora_A
|
||||
continue
|
||||
|
||||
module.lora_A[adapter_name].to(device)
|
||||
module.lora_B[adapter_name].to(device)
|
||||
# this is a param, not a module, so device placement is not in-place -> re-assign
|
||||
|
||||
@@ -1346,6 +1346,228 @@ def _convert_bfl_flux_control_lora_to_diffusers(original_state_dict):
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_fal_kontext_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {}
|
||||
original_state_dict_keys = list(original_state_dict.keys())
|
||||
num_layers = 19
|
||||
num_single_layers = 38
|
||||
inner_dim = 3072
|
||||
mlp_ratio = 4.0
|
||||
|
||||
# double transformer blocks
|
||||
for i in range(num_layers):
|
||||
block_prefix = f"transformer_blocks.{i}."
|
||||
original_block_prefix = "base_model.model."
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
# norms
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.weight"
|
||||
)
|
||||
if f"double_blocks.{i}.img_mod.lin.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}norm1.linear.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_mod.lin.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}norm1_context.linear.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_mod.lin.{lora_key}.weight"
|
||||
)
|
||||
|
||||
# Q, K, V
|
||||
if lora_key == "lora_A":
|
||||
sample_lora_weight = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_lora_weight])
|
||||
|
||||
context_lora_weight = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat(
|
||||
[context_lora_weight]
|
||||
)
|
||||
else:
|
||||
sample_q, sample_k, sample_v = torch.chunk(
|
||||
original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.weight"
|
||||
),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([sample_q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([sample_k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([sample_v])
|
||||
|
||||
context_q, context_k, context_v = torch.chunk(
|
||||
original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.weight"
|
||||
),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.weight"] = torch.cat([context_q])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.weight"] = torch.cat([context_k])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.weight"] = torch.cat([context_v])
|
||||
|
||||
if f"double_blocks.{i}.img_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
|
||||
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.img_attn.qkv.{lora_key}.bias"),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([sample_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([sample_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([sample_v_bias])
|
||||
|
||||
if f"double_blocks.{i}.txt_attn.qkv.{lora_key}.bias" in original_state_dict_keys:
|
||||
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
||||
original_state_dict.pop(f"{original_block_prefix}double_blocks.{i}.txt_attn.qkv.{lora_key}.bias"),
|
||||
3,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.add_q_proj.{lora_key}.bias"] = torch.cat([context_q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_k_proj.{lora_key}.bias"] = torch.cat([context_k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.add_v_proj.{lora_key}.bias"] = torch.cat([context_v_bias])
|
||||
|
||||
# ff img_mlp
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_mlp.0.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff.net.2.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_mlp.2.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_mlp.0.{lora_key}.bias"
|
||||
)
|
||||
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}ff_context.net.2.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_mlp.2.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}attn.to_out.0.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.img_attn.proj.{lora_key}.bias"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}attn.to_add_out.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}double_blocks.{i}.txt_attn.proj.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# single transformer blocks
|
||||
for i in range(num_single_layers):
|
||||
block_prefix = f"single_transformer_blocks.{i}."
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
# norm.linear <- single_blocks.0.modulation.lin
|
||||
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}norm.linear.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}single_blocks.{i}.modulation.lin.{lora_key}.bias"
|
||||
)
|
||||
|
||||
# Q, K, V, mlp
|
||||
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
||||
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
||||
|
||||
if lora_key == "lora_A":
|
||||
lora_weight = original_state_dict.pop(
|
||||
f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([lora_weight])
|
||||
|
||||
if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
|
||||
lora_bias = original_state_dict.pop(f"single_blocks.{i}.linear1.{lora_key}.bias")
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([lora_bias])
|
||||
else:
|
||||
q, k, v, mlp = torch.split(
|
||||
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.weight"),
|
||||
split_size,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.weight"] = torch.cat([q])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.weight"] = torch.cat([k])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.weight"] = torch.cat([v])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.weight"] = torch.cat([mlp])
|
||||
|
||||
if f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias" in original_state_dict_keys:
|
||||
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
||||
original_state_dict.pop(f"{original_block_prefix}single_blocks.{i}.linear1.{lora_key}.bias"),
|
||||
split_size,
|
||||
dim=0,
|
||||
)
|
||||
converted_state_dict[f"{block_prefix}attn.to_q.{lora_key}.bias"] = torch.cat([q_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_k.{lora_key}.bias"] = torch.cat([k_bias])
|
||||
converted_state_dict[f"{block_prefix}attn.to_v.{lora_key}.bias"] = torch.cat([v_bias])
|
||||
converted_state_dict[f"{block_prefix}proj_mlp.{lora_key}.bias"] = torch.cat([mlp_bias])
|
||||
|
||||
# output projections.
|
||||
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"{block_prefix}proj_out.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}single_blocks.{i}.linear2.{lora_key}.bias"
|
||||
)
|
||||
|
||||
for lora_key in ["lora_A", "lora_B"]:
|
||||
converted_state_dict[f"proj_out.{lora_key}.weight"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}final_layer.linear.{lora_key}.weight"
|
||||
)
|
||||
if f"{original_block_prefix}final_layer.linear.{lora_key}.bias" in original_state_dict_keys:
|
||||
converted_state_dict[f"proj_out.{lora_key}.bias"] = original_state_dict.pop(
|
||||
f"{original_block_prefix}final_layer.linear.{lora_key}.bias"
|
||||
)
|
||||
|
||||
if len(original_state_dict) > 0:
|
||||
raise ValueError(f"`original_state_dict` should be empty at this point but has {original_state_dict.keys()=}.")
|
||||
|
||||
for key in list(converted_state_dict.keys()):
|
||||
converted_state_dict[f"transformer.{key}"] = converted_state_dict.pop(key)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def _convert_hunyuan_video_lora_to_diffusers(original_state_dict):
|
||||
converted_state_dict = {k: original_state_dict.pop(k) for k in list(original_state_dict.keys())}
|
||||
|
||||
@@ -1603,24 +1825,22 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
|
||||
is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)
|
||||
lora_down_key = "lora_A" if any("lora_A" in k for k in original_state_dict) else "lora_down"
|
||||
lora_up_key = "lora_B" if any("lora_B" in k for k in original_state_dict) else "lora_up"
|
||||
has_time_projection_weight = any(
|
||||
k.startswith("time_projection") and k.endswith(".weight") for k in original_state_dict
|
||||
)
|
||||
|
||||
diff_keys = [k for k in original_state_dict if k.endswith((".diff_b", ".diff"))]
|
||||
if diff_keys:
|
||||
for diff_k in diff_keys:
|
||||
param = original_state_dict[diff_k]
|
||||
# The magnitudes of the .diff-ending weights are very low (most are below 1e-4, some are upto 1e-3,
|
||||
# and 2 of them are about 1.6e-2 [the case with AccVideo lora]). The low magnitudes mostly correspond
|
||||
# to norm layers. Ignoring them is the best option at the moment until a better solution is found. It
|
||||
# is okay to ignore because they do not affect the model output in a significant manner.
|
||||
threshold = 1.6e-2
|
||||
absdiff = param.abs().max() - param.abs().min()
|
||||
all_zero = torch.all(param == 0).item()
|
||||
all_absdiff_lower_than_threshold = absdiff < threshold
|
||||
if all_zero or all_absdiff_lower_than_threshold:
|
||||
logger.debug(
|
||||
f"Removed {diff_k} key from the state dict as it's all zeros, or values lower than hardcoded threshold."
|
||||
)
|
||||
original_state_dict.pop(diff_k)
|
||||
for key in list(original_state_dict.keys()):
|
||||
if key.endswith((".diff", ".diff_b")) and "norm" in key:
|
||||
# NOTE: we don't support this because norm layer diff keys are just zeroed values. We can support it
|
||||
# in future if needed and they are not zeroed.
|
||||
original_state_dict.pop(key)
|
||||
logger.debug(f"Removing {key} key from the state dict as it is a norm diff key. This is unsupported.")
|
||||
|
||||
if "time_projection" in key and not has_time_projection_weight:
|
||||
# AccVideo lora has diff bias keys but not the weight keys. This causes a weird problem where
|
||||
# our lora config adds the time proj lora layers, but we don't have the weights for them.
|
||||
# CausVid lora has the weight keys and the bias keys.
|
||||
original_state_dict.pop(key)
|
||||
|
||||
# For the `diff_b` keys, we treat them as lora_bias.
|
||||
# https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraConfig.lora_bias
|
||||
|
||||
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
|
||||
)
|
||||
from .lora_conversion_utils import (
|
||||
_convert_bfl_flux_control_lora_to_diffusers,
|
||||
_convert_fal_kontext_lora_to_diffusers,
|
||||
_convert_hunyuan_video_lora_to_diffusers,
|
||||
_convert_kohya_flux_lora_to_diffusers,
|
||||
_convert_musubi_wan_lora_to_diffusers,
|
||||
@@ -2062,6 +2063,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
|
||||
return_metadata=return_lora_metadata,
|
||||
)
|
||||
|
||||
is_fal_kontext = any("base_model" in k for k in state_dict)
|
||||
if is_fal_kontext:
|
||||
state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
|
||||
return cls._prepare_outputs(
|
||||
state_dict,
|
||||
metadata=metadata,
|
||||
alphas=None,
|
||||
return_alphas=return_alphas,
|
||||
return_metadata=return_lora_metadata,
|
||||
)
|
||||
|
||||
# For state dicts like
|
||||
# https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
@@ -163,6 +163,8 @@ class PeftAdapterMixin:
|
||||
from peft import inject_adapter_in_model, set_peft_model_state_dict
|
||||
from peft.tuners.tuners_utils import BaseTunerLayer
|
||||
|
||||
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -243,20 +245,29 @@ class PeftAdapterMixin:
|
||||
k.removeprefix(f"{prefix}."): v for k, v in network_alphas.items() if k in alpha_keys
|
||||
}
|
||||
|
||||
# create LoraConfig
|
||||
lora_config = _create_lora_config(state_dict, network_alphas, metadata, rank)
|
||||
|
||||
# adapter_name
|
||||
if adapter_name is None:
|
||||
adapter_name = get_adapter_name(self)
|
||||
|
||||
# create LoraConfig
|
||||
lora_config = _create_lora_config(
|
||||
state_dict,
|
||||
network_alphas,
|
||||
metadata,
|
||||
rank,
|
||||
model_state_dict=self.state_dict(),
|
||||
adapter_name=adapter_name,
|
||||
)
|
||||
|
||||
# <Unsafe code
|
||||
# We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
|
||||
# Now we remove any existing hooks to `_pipeline`.
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error.
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline
|
||||
)
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
@@ -347,6 +358,10 @@ class PeftAdapterMixin:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
elif is_group_offload:
|
||||
for component in _pipeline.components.values():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
_maybe_remove_and_reapply_group_offloading(component)
|
||||
# Unsafe code />
|
||||
|
||||
if prefix is not None and not state_dict:
|
||||
@@ -681,11 +696,16 @@ class PeftAdapterMixin:
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for `unload_lora()`.")
|
||||
|
||||
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
||||
from ..utils import recurse_remove_peft_layers
|
||||
|
||||
recurse_remove_peft_layers(self)
|
||||
if hasattr(self, "peft_config"):
|
||||
del self.peft_config
|
||||
if hasattr(self, "_hf_peft_config_loaded"):
|
||||
self._hf_peft_config_loaded = None
|
||||
|
||||
_maybe_remove_and_reapply_group_offloading(self)
|
||||
|
||||
def disable_lora(self):
|
||||
"""
|
||||
|
||||
@@ -31,6 +31,7 @@ from .single_file_utils import (
|
||||
convert_autoencoder_dc_checkpoint_to_diffusers,
|
||||
convert_chroma_transformer_checkpoint_to_diffusers,
|
||||
convert_controlnet_checkpoint,
|
||||
convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
convert_flux_transformer_checkpoint_to_diffusers,
|
||||
convert_hidream_transformer_to_diffusers,
|
||||
convert_hunyuan_video_transformer_to_diffusers,
|
||||
@@ -135,6 +136,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"WanVACETransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"AutoencoderKLWan": {
|
||||
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
|
||||
"default_subfolder": "vae",
|
||||
@@ -143,6 +148,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
|
||||
"checkpoint_mapping_fn": convert_hidream_transformer_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
"CosmosTransformer3DModel": {
|
||||
"checkpoint_mapping_fn": convert_cosmos_transformer_checkpoint_to_diffusers,
|
||||
"default_subfolder": "transformer",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -126,7 +126,18 @@ CHECKPOINT_KEY_NAMES = {
|
||||
],
|
||||
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
||||
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
||||
"wan_vace": "vace_blocks.0.after_proj.bias",
|
||||
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
|
||||
"cosmos-1.0": [
|
||||
"net.x_embedder.proj.1.weight",
|
||||
"net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
|
||||
"net.extra_pos_embedder.pos_emb_h",
|
||||
],
|
||||
"cosmos-2.0": [
|
||||
"net.x_embedder.proj.1.weight",
|
||||
"net.blocks.0.self_attn.q_proj.weight",
|
||||
"net.pos_embedder.dim_spatial_range",
|
||||
],
|
||||
}
|
||||
|
||||
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
@@ -192,7 +203,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
||||
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
||||
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
||||
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
||||
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
|
||||
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
|
||||
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
|
||||
"cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
|
||||
"cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
|
||||
"cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
|
||||
"cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
|
||||
"cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
|
||||
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
|
||||
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
||||
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
||||
}
|
||||
|
||||
# Use to configure model sample size when original config is provided
|
||||
@@ -698,17 +719,44 @@ def infer_diffusers_model_type(checkpoint):
|
||||
else:
|
||||
target_key = "patch_embedding.weight"
|
||||
|
||||
if checkpoint[target_key].shape[0] == 1536:
|
||||
if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
|
||||
if checkpoint[target_key].shape[0] == 1536:
|
||||
model_type = "wan-vace-1.3B"
|
||||
elif checkpoint[target_key].shape[0] == 5120:
|
||||
model_type = "wan-vace-14B"
|
||||
|
||||
elif checkpoint[target_key].shape[0] == 1536:
|
||||
model_type = "wan-t2v-1.3B"
|
||||
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
|
||||
model_type = "wan-t2v-14B"
|
||||
else:
|
||||
model_type = "wan-i2v-14B"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
||||
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
||||
model_type = "wan-t2v-14B"
|
||||
|
||||
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
|
||||
model_type = "hidream"
|
||||
|
||||
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
|
||||
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
|
||||
if x_embedder_shape[1] == 68:
|
||||
model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
|
||||
elif x_embedder_shape[1] == 72:
|
||||
model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
|
||||
else:
|
||||
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
|
||||
|
||||
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
|
||||
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
|
||||
if x_embedder_shape[1] == 68:
|
||||
model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
|
||||
elif x_embedder_shape[1] == 72:
|
||||
model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
|
||||
else:
|
||||
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
||||
|
||||
else:
|
||||
model_type = "v1"
|
||||
|
||||
@@ -3093,6 +3141,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
|
||||
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
||||
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
||||
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
||||
# For the VACE model
|
||||
"before_proj": "proj_in",
|
||||
"after_proj": "proj_out",
|
||||
}
|
||||
|
||||
for key in list(checkpoint.keys()):
|
||||
@@ -3479,3 +3530,116 @@ def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
||||
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
||||
|
||||
def remove_keys_(key: str, state_dict):
|
||||
state_dict.pop(key)
|
||||
|
||||
def rename_transformer_blocks_(key: str, state_dict):
|
||||
block_index = int(key.split(".")[1].removeprefix("block"))
|
||||
new_key = key
|
||||
old_prefix = f"blocks.block{block_index}"
|
||||
new_prefix = f"transformer_blocks.{block_index}"
|
||||
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
||||
state_dict[new_key] = state_dict.pop(key)
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"affline_norm": "time_embed.norm",
|
||||
".blocks.0.block.attn": ".attn1",
|
||||
".blocks.1.block.attn": ".attn2",
|
||||
".blocks.2.block": ".ff",
|
||||
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
||||
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
||||
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
||||
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
||||
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
||||
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
||||
"to_q.0": "to_q",
|
||||
"to_q.1": "norm_q",
|
||||
"to_k.0": "to_k",
|
||||
"to_k.1": "norm_k",
|
||||
"to_v.0": "to_v",
|
||||
"layer1": "net.0.proj",
|
||||
"layer2": "net.2",
|
||||
"proj.1": "proj",
|
||||
"x_embedder": "patch_embed",
|
||||
"extra_pos_embedder": "learnable_pos_embed",
|
||||
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
|
||||
"blocks.block": rename_transformer_blocks_,
|
||||
"logvar.0.freqs": remove_keys_,
|
||||
"logvar.0.phases": remove_keys_,
|
||||
"logvar.1.weight": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
}
|
||||
|
||||
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
||||
"t_embedder.1": "time_embed.t_embedder",
|
||||
"t_embedding_norm": "time_embed.norm",
|
||||
"blocks": "transformer_blocks",
|
||||
"adaln_modulation_self_attn.1": "norm1.linear_1",
|
||||
"adaln_modulation_self_attn.2": "norm1.linear_2",
|
||||
"adaln_modulation_cross_attn.1": "norm2.linear_1",
|
||||
"adaln_modulation_cross_attn.2": "norm2.linear_2",
|
||||
"adaln_modulation_mlp.1": "norm3.linear_1",
|
||||
"adaln_modulation_mlp.2": "norm3.linear_2",
|
||||
"self_attn": "attn1",
|
||||
"cross_attn": "attn2",
|
||||
"q_proj": "to_q",
|
||||
"k_proj": "to_k",
|
||||
"v_proj": "to_v",
|
||||
"output_proj": "to_out.0",
|
||||
"q_norm": "norm_q",
|
||||
"k_norm": "norm_k",
|
||||
"mlp.layer1": "ff.net.0.proj",
|
||||
"mlp.layer2": "ff.net.2",
|
||||
"x_embedder.proj.1": "patch_embed.proj",
|
||||
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
||||
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
||||
"final_layer.linear": "proj_out",
|
||||
}
|
||||
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
|
||||
"accum_video_sample_counter": remove_keys_,
|
||||
"accum_image_sample_counter": remove_keys_,
|
||||
"accum_iteration": remove_keys_,
|
||||
"accum_train_in_hours": remove_keys_,
|
||||
"pos_embedder.seq": remove_keys_,
|
||||
"pos_embedder.dim_spatial_range": remove_keys_,
|
||||
"pos_embedder.dim_temporal_range": remove_keys_,
|
||||
"_extra_state": remove_keys_,
|
||||
}
|
||||
|
||||
PREFIX_KEY = "net."
|
||||
if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
|
||||
else:
|
||||
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
||||
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
||||
|
||||
state_dict_keys = list(converted_state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
new_key = key[:]
|
||||
if new_key.startswith(PREFIX_KEY):
|
||||
new_key = new_key.removeprefix(PREFIX_KEY)
|
||||
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
||||
new_key = new_key.replace(replace_key, rename_key)
|
||||
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
||||
|
||||
state_dict_keys = list(converted_state_dict.keys())
|
||||
for key in state_dict_keys:
|
||||
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
||||
if special_key not in key:
|
||||
continue
|
||||
handler_fn_inplace(key, converted_state_dict)
|
||||
|
||||
return converted_state_dict
|
||||
|
||||
@@ -131,6 +131,8 @@ class UNet2DConditionLoadersMixin:
|
||||
)
|
||||
```
|
||||
"""
|
||||
from ..hooks.group_offloading import _maybe_remove_and_reapply_group_offloading
|
||||
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
@@ -203,6 +205,7 @@ class UNet2DConditionLoadersMixin:
|
||||
is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
|
||||
if is_lora:
|
||||
deprecation_message = "Using the `load_attn_procs()` method has been deprecated and will be removed in a future version. Please use `load_lora_adapter()`."
|
||||
@@ -211,7 +214,7 @@ class UNet2DConditionLoadersMixin:
|
||||
if is_custom_diffusion:
|
||||
attn_processors = self._process_custom_diffusion(state_dict=state_dict)
|
||||
elif is_lora:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._process_lora(
|
||||
state_dict=state_dict,
|
||||
unet_identifier_key=self.unet_name,
|
||||
network_alphas=network_alphas,
|
||||
@@ -230,7 +233,9 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
|
||||
if is_custom_diffusion and _pipeline is not None:
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline=_pipeline
|
||||
)
|
||||
|
||||
# only custom diffusion needs to set attn processors
|
||||
self.set_attn_processor(attn_processors)
|
||||
@@ -241,6 +246,10 @@ class UNet2DConditionLoadersMixin:
|
||||
_pipeline.enable_model_cpu_offload()
|
||||
elif is_sequential_cpu_offload:
|
||||
_pipeline.enable_sequential_cpu_offload()
|
||||
elif is_group_offload:
|
||||
for component in _pipeline.components.values():
|
||||
if isinstance(component, torch.nn.Module):
|
||||
_maybe_remove_and_reapply_group_offloading(component)
|
||||
# Unsafe code />
|
||||
|
||||
def _process_custom_diffusion(self, state_dict):
|
||||
@@ -307,6 +316,7 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
is_model_cpu_offload = False
|
||||
is_sequential_cpu_offload = False
|
||||
is_group_offload = False
|
||||
state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict
|
||||
|
||||
if len(state_dict_to_be_used) > 0:
|
||||
@@ -356,7 +366,9 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
# In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
|
||||
# otherwise loading LoRA weights will lead to an error
|
||||
is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
|
||||
is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload = self._optionally_disable_offloading(
|
||||
_pipeline
|
||||
)
|
||||
peft_kwargs = {}
|
||||
if is_peft_version(">=", "0.13.1"):
|
||||
peft_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
|
||||
@@ -389,7 +401,7 @@ class UNet2DConditionLoadersMixin:
|
||||
if warn_msg:
|
||||
logger.warning(warn_msg)
|
||||
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload
|
||||
return is_model_cpu_offload, is_sequential_cpu_offload, is_group_offload
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_base.LoraBaseMixin._optionally_disable_offloading
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user