mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-06 08:41:40 +08:00
Compare commits
117 Commits
revert-MT5
...
Ando233-ra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
403d3f20f7 | ||
|
|
441224ac00 | ||
|
|
20364fe5a2 | ||
|
|
3902145b38 | ||
|
|
af0bed007a | ||
|
|
5570f817da | ||
|
|
33f785b444 | ||
|
|
06ccde9490 | ||
|
|
ed9bcfd7a9 | ||
|
|
05d3edca66 | ||
|
|
f4ec0f1443 | ||
|
|
fa016b196c | ||
|
|
33d98a85da | ||
|
|
88798242bc | ||
|
|
14d918ee88 | ||
|
|
bc59324a2f | ||
|
|
b9a5266cec | ||
|
|
876e930780 | ||
|
|
df1af7d907 | ||
|
|
af75d8b9e2 | ||
|
|
e805be989e | ||
|
|
3958fda3bf | ||
|
|
196f8a36c7 | ||
|
|
4a2833c1c2 | ||
|
|
1fe688a651 | ||
|
|
9c0f96b303 | ||
|
|
bc71889852 | ||
|
|
3a6689518f | ||
|
|
bbbcdd87bd | ||
|
|
47e8faf3b9 | ||
|
|
c2fdd2d048 | ||
|
|
84ff061b1d | ||
|
|
3fd14f1acf | ||
|
|
e7fe4ce92f | ||
|
|
3d9085565b | ||
|
|
5b54496131 | ||
|
|
fcdd759e39 | ||
|
|
5817416a19 | ||
|
|
e834e498b2 | ||
|
|
f15873af72 | ||
|
|
bff48d317e | ||
|
|
cd86873ea6 | ||
|
|
34787e5b9b | ||
|
|
9ada5768e5 | ||
|
|
8861a8082a | ||
|
|
03e757ca73 | ||
|
|
c717498fa3 | ||
|
|
39188248a7 | ||
|
|
9b97932424 | ||
|
|
680076fcc0 | ||
|
|
1b4a43f59d | ||
|
|
6a78767864 | ||
|
|
5910a1cc6c | ||
|
|
40e96454f1 | ||
|
|
47455bd133 | ||
|
|
663b580418 | ||
|
|
d965cabe79 | ||
|
|
5c85781519 | ||
|
|
c71cb44299 | ||
|
|
dca59233f6 | ||
|
|
b3ffd6344a | ||
|
|
7debd07541 | ||
|
|
97c2c6e397 | ||
|
|
212db7b999 | ||
|
|
31058485f1 | ||
|
|
b297868201 | ||
|
|
aac94befce | ||
|
|
1f6ac1c3d1 | ||
|
|
5e94d62eb4 | ||
|
|
7ab2011759 | ||
|
|
4890e9bf70 | ||
|
|
f1e5914120 | ||
|
|
28a02eb226 | ||
|
|
61885f37e3 | ||
|
|
c68b812cb0 | ||
|
|
a80b19218b | ||
|
|
01de02e8b4 | ||
|
|
db2d7e7bc4 | ||
|
|
f8d3db9ca7 | ||
|
|
99daaa802d | ||
|
|
fe78a7b7c6 | ||
|
|
53e1d0e458 | ||
|
|
a577ec36df | ||
|
|
6875490c3b | ||
|
|
64734b2115 | ||
|
|
f81e653197 | ||
|
|
d8b2983b9e | ||
|
|
bcbbded7c3 | ||
|
|
d06b501850 | ||
|
|
a4fc9f64b2 | ||
|
|
fc5295951a | ||
|
|
96520c4ff1 | ||
|
|
35086ac06a | ||
|
|
e390646f25 | ||
|
|
59e7a46928 | ||
|
|
d3cbd5a60b | ||
|
|
906d79a432 | ||
|
|
9522e68a5b | ||
|
|
6a9bde6964 | ||
|
|
e6d449933d | ||
|
|
7cbbf271f3 | ||
|
|
202b14f6a4 | ||
|
|
0d59b22732 | ||
|
|
d7cb12470b | ||
|
|
f06ea7a901 | ||
|
|
25bc9e334c | ||
|
|
24acab0bcc | ||
|
|
0850c8cdc9 | ||
|
|
3ecf89d044 | ||
|
|
b0dc51da31 | ||
|
|
c919ec0611 | ||
|
|
3c7506b294 | ||
|
|
19ab0ecb9e | ||
|
|
5b00a18374 | ||
|
|
a3926d77d7 | ||
|
|
f82cecc298 | ||
|
|
382aad0a6c |
14
.github/workflows/benchmark.yml
vendored
14
.github/workflows/benchmark.yml
vendored
@@ -62,20 +62,6 @@ jobs:
|
||||
with:
|
||||
name: benchmark_test_reports
|
||||
path: benchmarks/${{ env.BASE_PATH }}
|
||||
|
||||
# TODO: enable this once the connection problem has been resolved.
|
||||
- name: Update benchmarking results to DB
|
||||
env:
|
||||
PGDATABASE: metrics
|
||||
PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
|
||||
PGUSER: transformers_benchmarks
|
||||
PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
|
||||
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
|
||||
run: |
|
||||
git config --global --add safe.directory /__w/diffusers/diffusers
|
||||
commit_id=$GITHUB_SHA
|
||||
commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
|
||||
cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
|
||||
|
||||
- name: Report success status
|
||||
if: ${{ success() }}
|
||||
|
||||
13
.github/workflows/pr_tests.yml
vendored
13
.github/workflows/pr_tests.yml
vendored
@@ -92,7 +92,6 @@ jobs:
|
||||
runner: aws-general-8-plus
|
||||
image: diffusers/diffusers-pytorch-cpu
|
||||
report: torch_example_cpu
|
||||
|
||||
name: ${{ matrix.config.name }}
|
||||
|
||||
runs-on:
|
||||
@@ -115,8 +114,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
|
||||
- name: Environment
|
||||
@@ -218,8 +216,6 @@ jobs:
|
||||
|
||||
run_lora_tests:
|
||||
needs: [check_code_quality, check_repository_consistency]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
|
||||
name: LoRA tests with PEFT main
|
||||
|
||||
@@ -247,9 +243,8 @@ jobs:
|
||||
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
|
||||
uv pip install -U tokenizers
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -275,6 +270,6 @@ jobs:
|
||||
if: ${{ always() }}
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: pr_main_test_reports
|
||||
name: pr_lora_test_reports
|
||||
path: reports
|
||||
|
||||
|
||||
14
.github/workflows/pr_tests_gpu.yml
vendored
14
.github/workflows/pr_tests_gpu.yml
vendored
@@ -131,8 +131,7 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -199,16 +198,10 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
|
||||
uv pip install pip==25.2 setuptools==80.10.2
|
||||
uv pip install --no-build-isolation k-diffusion==0.0.12
|
||||
uv pip install --upgrade pip setuptools
|
||||
# Install the rest as normal
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -269,8 +262,7 @@ jobs:
|
||||
nvidia-smi
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip install -e ".[quality,training]"
|
||||
|
||||
- name: Environment
|
||||
|
||||
14
.github/workflows/push_tests.yml
vendored
14
.github/workflows/push_tests.yml
vendored
@@ -76,8 +76,7 @@ jobs:
|
||||
run: |
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
@@ -126,16 +125,10 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
|
||||
uv pip install pip==25.2 setuptools==80.10.2
|
||||
uv pip install --no-build-isolation k-diffusion==0.0.12
|
||||
uv pip install --upgrade pip setuptools
|
||||
# Install the rest as normal
|
||||
uv pip install -e ".[quality]"
|
||||
uv pip install peft@git+https://github.com/huggingface/peft.git
|
||||
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -187,8 +180,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv pip install -e ".[quality,training]"
|
||||
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
|
||||
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
|
||||
- name: Environment
|
||||
run: |
|
||||
python utils/print_env.py
|
||||
|
||||
2
.github/workflows/push_tests_mps.yml
vendored
2
.github/workflows/push_tests_mps.yml
vendored
@@ -41,7 +41,7 @@ jobs:
|
||||
shell: arch -arch arm64 bash {0}
|
||||
run: |
|
||||
${CONDA_RUN} python -m pip install --upgrade pip uv
|
||||
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
|
||||
${CONDA_RUN} python -m uv pip install -e ".[quality]"
|
||||
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
|
||||
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
|
||||
${CONDA_RUN} python -m uv pip install transformers --upgrade
|
||||
|
||||
3
.github/workflows/pypi_publish.yaml
vendored
3
.github/workflows/pypi_publish.yaml
vendored
@@ -54,7 +54,6 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install -U setuptools wheel twine
|
||||
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install -U transformers
|
||||
|
||||
- name: Build the dist files
|
||||
run: python setup.py bdist_wheel && python setup.py sdist
|
||||
@@ -69,6 +68,8 @@ jobs:
|
||||
run: |
|
||||
pip install diffusers && pip uninstall diffusers -y
|
||||
pip install -i https://test.pypi.org/simple/ diffusers
|
||||
pip install -U transformers
|
||||
python utils/print_env.py
|
||||
python -c "from diffusers import __version__; print(__version__)"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
|
||||
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"
|
||||
|
||||
@@ -1,166 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
import gpustat
|
||||
import pandas as pd
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
from psycopg2.extensions import register_adapter
|
||||
from psycopg2.extras import Json
|
||||
|
||||
|
||||
register_adapter(dict, Json)
|
||||
|
||||
FINAL_CSV_FILENAME = "collated_results.csv"
|
||||
# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
|
||||
BENCHMARKS_TABLE_NAME = "benchmarks"
|
||||
MEASUREMENTS_TABLE_NAME = "model_measurements"
|
||||
|
||||
|
||||
def _init_benchmark(conn, branch, commit_id, commit_msg):
|
||||
gpu_stats = gpustat.GPUStatCollection.new_query()
|
||||
metadata = {"gpu_name": gpu_stats[0]["name"]}
|
||||
repository = "huggingface/diffusers"
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
|
||||
(repository, branch, commit_id, commit_msg, metadata),
|
||||
)
|
||||
benchmark_id = cur.fetchone()[0]
|
||||
print(f"Initialised benchmark #{benchmark_id}")
|
||||
return benchmark_id
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"branch",
|
||||
type=str,
|
||||
help="The branch name on which the benchmarking is performed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"commit_id",
|
||||
type=str,
|
||||
help="The commit hash on which the benchmarking is performed.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"commit_msg",
|
||||
type=str,
|
||||
help="The commit message associated with the commit, truncated to 70 characters.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
try:
|
||||
conn = psycopg2.connect(
|
||||
host=os.getenv("PGHOST"),
|
||||
database=os.getenv("PGDATABASE"),
|
||||
user=os.getenv("PGUSER"),
|
||||
password=os.getenv("PGPASSWORD"),
|
||||
)
|
||||
print("DB connection established successfully.")
|
||||
except Exception as e:
|
||||
print(f"Problem during DB init: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
try:
|
||||
benchmark_id = _init_benchmark(
|
||||
conn=conn,
|
||||
branch=args.branch,
|
||||
commit_id=args.commit_id,
|
||||
commit_msg=args.commit_msg,
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Problem during initializing benchmark: {e}")
|
||||
sys.exit(1)
|
||||
|
||||
cur = conn.cursor()
|
||||
|
||||
df = pd.read_csv(FINAL_CSV_FILENAME)
|
||||
|
||||
# Helper to cast values (or None) given a dtype
|
||||
def _cast_value(val, dtype: str):
|
||||
if pd.isna(val):
|
||||
return None
|
||||
|
||||
if dtype == "text":
|
||||
return str(val).strip()
|
||||
|
||||
if dtype == "float":
|
||||
try:
|
||||
return float(val)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
if dtype == "bool":
|
||||
s = str(val).strip().lower()
|
||||
if s in ("true", "t", "yes", "1"):
|
||||
return True
|
||||
if s in ("false", "f", "no", "0"):
|
||||
return False
|
||||
if val in (1, 1.0):
|
||||
return True
|
||||
if val in (0, 0.0):
|
||||
return False
|
||||
return None
|
||||
|
||||
return val
|
||||
|
||||
try:
|
||||
rows_to_insert = []
|
||||
for _, row in df.iterrows():
|
||||
scenario = _cast_value(row.get("scenario"), "text")
|
||||
model_cls = _cast_value(row.get("model_cls"), "text")
|
||||
num_params_B = _cast_value(row.get("num_params_B"), "float")
|
||||
flops_G = _cast_value(row.get("flops_G"), "float")
|
||||
time_plain_s = _cast_value(row.get("time_plain_s"), "float")
|
||||
mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
|
||||
time_compile_s = _cast_value(row.get("time_compile_s"), "float")
|
||||
mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
|
||||
fullgraph = _cast_value(row.get("fullgraph"), "bool")
|
||||
mode = _cast_value(row.get("mode"), "text")
|
||||
|
||||
# If "github_sha" column exists in the CSV, cast it; else default to None
|
||||
if "github_sha" in df.columns:
|
||||
github_sha = _cast_value(row.get("github_sha"), "text")
|
||||
else:
|
||||
github_sha = None
|
||||
|
||||
measurements = {
|
||||
"scenario": scenario,
|
||||
"model_cls": model_cls,
|
||||
"num_params_B": num_params_B,
|
||||
"flops_G": flops_G,
|
||||
"time_plain_s": time_plain_s,
|
||||
"mem_plain_GB": mem_plain_GB,
|
||||
"time_compile_s": time_compile_s,
|
||||
"mem_compile_GB": mem_compile_GB,
|
||||
"fullgraph": fullgraph,
|
||||
"mode": mode,
|
||||
"github_sha": github_sha,
|
||||
}
|
||||
rows_to_insert.append((benchmark_id, measurements))
|
||||
|
||||
# Batch-insert all rows
|
||||
insert_sql = f"""
|
||||
INSERT INTO {MEASUREMENTS_TABLE_NAME} (
|
||||
benchmark_id,
|
||||
measurements
|
||||
)
|
||||
VALUES (%s, %s);
|
||||
"""
|
||||
|
||||
psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
|
||||
conn.commit()
|
||||
|
||||
cur.close()
|
||||
conn.close()
|
||||
except Exception as e:
|
||||
print(f"Exception: {e}")
|
||||
sys.exit(1)
|
||||
@@ -194,6 +194,8 @@
|
||||
title: Model accelerators and hardware
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/helios
|
||||
title: Helios
|
||||
- local: using-diffusers/consisid
|
||||
title: ConsisID
|
||||
- local: using-diffusers/sdxl
|
||||
@@ -350,6 +352,8 @@
|
||||
title: FluxTransformer2DModel
|
||||
- local: api/models/glm_image_transformer2d
|
||||
title: GlmImageTransformer2DModel
|
||||
- local: api/models/helios_transformer3d
|
||||
title: HeliosTransformer3DModel
|
||||
- local: api/models/hidream_image_transformer
|
||||
title: HiDreamImageTransformer2DModel
|
||||
- local: api/models/hunyuan_transformer2d
|
||||
@@ -456,6 +460,8 @@
|
||||
title: AutoencoderKLQwenImage
|
||||
- local: api/models/autoencoder_kl_wan
|
||||
title: AutoencoderKLWan
|
||||
- local: api/models/autoencoder_rae
|
||||
title: AutoencoderRAE
|
||||
- local: api/models/consistency_decoder_vae
|
||||
title: ConsistencyDecoderVAE
|
||||
- local: api/models/autoencoder_oobleck
|
||||
@@ -625,8 +631,6 @@
|
||||
title: Image-to-image
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
@@ -675,6 +679,8 @@
|
||||
title: ConsisID
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/helios
|
||||
title: Helios
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/hunyuan_video15
|
||||
@@ -746,6 +752,10 @@
|
||||
title: FlowMatchEulerDiscreteScheduler
|
||||
- local: api/schedulers/flow_match_heun_discrete
|
||||
title: FlowMatchHeunDiscreteScheduler
|
||||
- local: api/schedulers/helios_dmd
|
||||
title: HeliosDMDScheduler
|
||||
- local: api/schedulers/helios
|
||||
title: HeliosScheduler
|
||||
- local: api/schedulers/heun
|
||||
title: HeunDiscreteScheduler
|
||||
- local: api/schedulers/ipndm
|
||||
|
||||
@@ -23,6 +23,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
|
||||
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
|
||||
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
|
||||
- [`HeliosLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios).
|
||||
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
|
||||
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
|
||||
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
|
||||
@@ -86,6 +87,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
|
||||
|
||||
## HeliosLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HeliosLoraLoaderMixin
|
||||
|
||||
## HunyuanVideoLoraLoaderMixin
|
||||
|
||||
[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin
|
||||
|
||||
89
docs/source/en/api/models/autoencoder_rae.md
Normal file
89
docs/source/en/api/models/autoencoder_rae.md
Normal file
@@ -0,0 +1,89 @@
|
||||
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# AutoencoderRAE
|
||||
|
||||
The Representation Autoencoder (RAE) model introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie from NYU VISIONx.
|
||||
|
||||
RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).
|
||||
|
||||
The following RAE models are released and supported in Diffusers:
|
||||
|
||||
| Model | Encoder | Latent shape (224px input) |
|
||||
|:------|:--------|:---------------------------|
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
|
||||
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |
|
||||
|
||||
## Loading a pretrained model
|
||||
|
||||
```python
|
||||
from diffusers import AutoencoderRAE
|
||||
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
```
|
||||
|
||||
## Encoding and decoding a real image
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderRAE
|
||||
from diffusers.utils import load_image
|
||||
from torchvision.transforms.functional import to_tensor, to_pil_image
|
||||
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
|
||||
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
|
||||
image = image.convert("RGB").resize((224, 224))
|
||||
x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1]
|
||||
|
||||
with torch.no_grad():
|
||||
latents = model.encode(x).latent # (1, 768, 16, 16)
|
||||
recon = model.decode(latents).sample # (1, 3, 256, 256)
|
||||
|
||||
recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
|
||||
recon_image.save("recon.png")
|
||||
```
|
||||
|
||||
## Latent normalization
|
||||
|
||||
Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.
|
||||
|
||||
```python
|
||||
model = AutoencoderRAE.from_pretrained(
|
||||
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
|
||||
).to("cuda").eval()
|
||||
|
||||
# Latent normalization is handled automatically inside encode/decode
|
||||
# when the checkpoint config includes latents_mean/latents_std.
|
||||
with torch.no_grad():
|
||||
latents = model.encode(x).latent # normalized latents
|
||||
recon = model.decode(latents).sample
|
||||
```
|
||||
|
||||
## AutoencoderRAE
|
||||
|
||||
[[autodoc]] AutoencoderRAE
|
||||
- encode
|
||||
- decode
|
||||
- all
|
||||
|
||||
## DecoderOutput
|
||||
|
||||
[[autodoc]] models.autoencoders.vae.DecoderOutput
|
||||
35
docs/source/en/api/models/helios_transformer3d.md
Normal file
35
docs/source/en/api/models/helios_transformer3d.md
Normal file
@@ -0,0 +1,35 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License. -->
|
||||
|
||||
# HeliosTransformer3DModel
|
||||
|
||||
A 14B Real-Time Autogressive Diffusion Transformer model (support T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University & ByteDance & etc.
|
||||
|
||||
The model can be loaded with the following code snippet.
|
||||
|
||||
```python
|
||||
from diffusers import HeliosTransformer3DModel
|
||||
|
||||
# Best Quality
|
||||
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
# Intermediate Weight
|
||||
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
# Best Efficiency
|
||||
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="transformer", torch_dtype=torch.bfloat16)
|
||||
```
|
||||
|
||||
## HeliosTransformer3DModel
|
||||
|
||||
[[autodoc]] HeliosTransformer3DModel
|
||||
|
||||
## Transformer2DModelOutput
|
||||
|
||||
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput
|
||||
@@ -14,4 +14,8 @@
|
||||
|
||||
## AutoPipelineBlocks
|
||||
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
|
||||
|
||||
## ConditionalPipelineBlocks
|
||||
|
||||
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks
|
||||
@@ -46,6 +46,20 @@ output = pipe(
|
||||
output.save("output.png")
|
||||
```
|
||||
|
||||
## Cosmos2_5_TransferPipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_TransferPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## Cosmos2_5_PredictBasePipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
|
||||
## CosmosTextToWorldPipeline
|
||||
|
||||
[[autodoc]] CosmosTextToWorldPipeline
|
||||
@@ -70,12 +84,6 @@ output.save("output.png")
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## Cosmos2_5_PredictBasePipeline
|
||||
|
||||
[[autodoc]] Cosmos2_5_PredictBasePipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## CosmosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput
|
||||
|
||||
465
docs/source/en/api/pipelines/helios.md
Normal file
465
docs/source/en/api/pipelines/helios.md
Normal file
@@ -0,0 +1,465 @@
|
||||
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License. -->
|
||||
|
||||
<div style="float: right;">
|
||||
<div class="flex flex-wrap space-x-1">
|
||||
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
|
||||
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
# Helios
|
||||
|
||||
[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University & ByteDance & etc, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan.
|
||||
|
||||
* <u>We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality.</u> We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page).
|
||||
|
||||
The following Helios models are supported in Diffusers:
|
||||
|
||||
- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler.
|
||||
- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler.
|
||||
- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosDMDScheduler.
|
||||
|
||||
> [!TIP]
|
||||
> Click on the Helios models in the right sidebar for more examples of video generation.
|
||||
|
||||
### Optimizing Memory and Inference Speed
|
||||
|
||||
The example below demonstrates how to generate a video from text optimized for memory or inference speed.
|
||||
|
||||
<hfoptions id="optimization">
|
||||
<hfoption id="memory">
|
||||
|
||||
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
|
||||
|
||||
The Helios model below requires ~19GB of VRAM.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPipeline
|
||||
from diffusers.hooks.group_offloading import apply_group_offloading
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
# group-offloading
|
||||
pipeline = HeliosPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Base",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.enable_group_offload(
|
||||
onload_device=torch.device("cuda"),
|
||||
offload_device=torch.device("cpu"),
|
||||
offload_type="block_level",
|
||||
num_blocks_per_group=1,
|
||||
use_stream=True,
|
||||
record_stream=True,
|
||||
)
|
||||
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="inference speed">
|
||||
|
||||
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPipeline
|
||||
from diffusers.utils import export_to_video
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Base",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
# attention backend
|
||||
# pipeline.transformer.set_attention_backend("flash")
|
||||
pipeline.transformer.set_attention_backend("_flash_3_hub") # For Hopper GPUs
|
||||
|
||||
# torch.compile
|
||||
torch.backends.cudnn.benchmark = True
|
||||
pipeline.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
|
||||
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Generation with Helios-Base
|
||||
|
||||
The example below demonstrates how to use Helios-Base to generate video based on text, image or video.
|
||||
|
||||
<hfoptions id="Helios-Base usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPipeline
|
||||
from diffusers.utils import export_to_video, load_video, load_image
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Base",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
# For Text-to-Video
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
|
||||
|
||||
# For Image-to-Video
|
||||
prompt = """
|
||||
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
|
||||
illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
|
||||
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
|
||||
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
|
||||
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
|
||||
respect for nature’s might.
|
||||
"""
|
||||
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=load_image(image_path).resize((640, 384)),
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_i2v_output.mp4", fps=24)
|
||||
|
||||
# For Video-to-Video
|
||||
prompt = """
|
||||
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
|
||||
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
|
||||
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
|
||||
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
|
||||
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
|
||||
"""
|
||||
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
video=load_video(video_path),
|
||||
num_frames=99,
|
||||
num_inference_steps=50,
|
||||
guidance_scale=5.0,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_base_v2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Generation with Helios-Mid
|
||||
|
||||
The example below demonstrates how to use Helios-Mid to generate video based on text, image or video.
|
||||
|
||||
<hfoptions id="Helios-Mid usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPyramidPipeline
|
||||
from diffusers.utils import export_to_video, load_video, load_image
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPyramidPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Mid",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
# For Text-to-Video
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=99,
|
||||
pyramid_num_inference_steps_list=[20, 20, 20],
|
||||
guidance_scale=5.0,
|
||||
use_zero_init=True,
|
||||
zero_steps=1,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24)
|
||||
|
||||
# For Image-to-Video
|
||||
prompt = """
|
||||
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
|
||||
illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
|
||||
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
|
||||
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
|
||||
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
|
||||
respect for nature’s might.
|
||||
"""
|
||||
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=load_image(image_path).resize((640, 384)),
|
||||
num_frames=99,
|
||||
pyramid_num_inference_steps_list=[20, 20, 20],
|
||||
guidance_scale=5.0,
|
||||
use_zero_init=True,
|
||||
zero_steps=1,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_pyramid_i2v_output.mp4", fps=24)
|
||||
|
||||
# For Video-to-Video
|
||||
prompt = """
|
||||
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
|
||||
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
|
||||
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
|
||||
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
|
||||
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
|
||||
"""
|
||||
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
video=load_video(video_path),
|
||||
num_frames=99,
|
||||
pyramid_num_inference_steps_list=[20, 20, 20],
|
||||
guidance_scale=5.0,
|
||||
use_zero_init=True,
|
||||
zero_steps=1,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
### Generation with Helios-Distilled
|
||||
|
||||
The example below demonstrates how to use Helios-Distilled to generate video based on text, image or video.
|
||||
|
||||
<hfoptions id="Helios-Distilled usage">
|
||||
<hfoption id="usage">
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoModel, HeliosPyramidPipeline
|
||||
from diffusers.utils import export_to_video, load_video, load_image
|
||||
|
||||
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32)
|
||||
|
||||
pipeline = HeliosPyramidPipeline.from_pretrained(
|
||||
"BestWishYsh/Helios-Distilled",
|
||||
vae=vae,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipeline.to("cuda")
|
||||
|
||||
negative_prompt = """
|
||||
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
|
||||
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
|
||||
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
|
||||
"""
|
||||
|
||||
# For Text-to-Video
|
||||
prompt = """
|
||||
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
|
||||
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
|
||||
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
|
||||
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
|
||||
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
|
||||
the vivid colors of its surroundings. A close-up shot with dynamic movement.
|
||||
"""
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=240,
|
||||
pyramid_num_inference_steps_list=[2, 2, 2],
|
||||
guidance_scale=1.0,
|
||||
is_amplify_first_chunk=True,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24)
|
||||
|
||||
# For Image-to-Video
|
||||
prompt = """
|
||||
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
|
||||
illuminating the intricate textures and deep green hues within the wave’s body. A thick spray erupts from the breaking crest,
|
||||
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
|
||||
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the ocean’s untamed beauty and
|
||||
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
|
||||
respect for nature’s might.
|
||||
"""
|
||||
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
image=load_image(image_path).resize((640, 384)),
|
||||
num_frames=240,
|
||||
pyramid_num_inference_steps_list=[2, 2, 2],
|
||||
guidance_scale=1.0,
|
||||
is_amplify_first_chunk=True,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24)
|
||||
|
||||
# For Video-to-Video
|
||||
prompt = """
|
||||
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
|
||||
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
|
||||
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
|
||||
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
|
||||
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
|
||||
"""
|
||||
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
|
||||
|
||||
output = pipeline(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
video=load_video(video_path),
|
||||
num_frames=240,
|
||||
pyramid_num_inference_steps_list=[2, 2, 2],
|
||||
guidance_scale=1.0,
|
||||
is_amplify_first_chunk=True,
|
||||
generator=torch.Generator("cuda").manual_seed(42),
|
||||
).frames[0]
|
||||
export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
|
||||
## HeliosPipeline
|
||||
|
||||
[[autodoc]] HeliosPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HeliosPyramidPipeline
|
||||
|
||||
[[autodoc]] HeliosPyramidPipeline
|
||||
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## HeliosPipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput
|
||||
@@ -29,7 +29,7 @@ Qwen-Image comes in the following variants:
|
||||
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
|
||||
|
||||
> [!TIP]
|
||||
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
|
||||
> See the [Caching](../../optimization/cache) guide to speed up inference by storing and reusing intermediate outputs.
|
||||
|
||||
## LoRA for faster inference
|
||||
|
||||
@@ -190,6 +190,12 @@ For detailed benchmark scripts and results, see [this gist](https://gist.github.
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImageLayeredPipeline
|
||||
|
||||
[[autodoc]] QwenImageLayeredPipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## QwenImagePipelineOutput
|
||||
|
||||
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput
|
||||
@@ -1,30 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
> [!WARNING]
|
||||
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
|
||||
|
||||
# K-Diffusion
|
||||
|
||||
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable DIffusion with samplers from k-diffusion.
|
||||
|
||||
Note that most the samplers from k-diffusion are implemented in Diffusers and we recommend using existing schedulers. You can find a mapping between k-diffusion samplers and schedulers in Diffusers [here](https://huggingface.co/docs/diffusers/api/schedulers/overview)
|
||||
|
||||
|
||||
## StableDiffusionKDiffusionPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionKDiffusionPipeline
|
||||
|
||||
|
||||
## StableDiffusionXLKDiffusionPipeline
|
||||
|
||||
[[autodoc]] StableDiffusionXLKDiffusionPipeline
|
||||
20
docs/source/en/api/schedulers/helios.md
Normal file
20
docs/source/en/api/schedulers/helios.md
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# HeliosScheduler
|
||||
|
||||
`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
|
||||
|
||||
## HeliosScheduler
|
||||
[[autodoc]] HeliosScheduler
|
||||
|
||||
scheduling_helios
|
||||
20
docs/source/en/api/schedulers/helios_dmd.md
Normal file
20
docs/source/en/api/schedulers/helios_dmd.md
Normal file
@@ -0,0 +1,20 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# HeliosDMDScheduler
|
||||
|
||||
`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
|
||||
|
||||
## HeliosDMDScheduler
|
||||
[[autodoc]] HeliosDMDScheduler
|
||||
|
||||
scheduling_helios_dmd
|
||||
@@ -121,7 +121,7 @@ from diffusers.modular_pipelines import AutoPipelineBlocks
|
||||
|
||||
class AutoImageBlocks(AutoPipelineBlocks):
|
||||
# List of sub-block classes to choose from
|
||||
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
# Names for each block in the same order
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
# Trigger inputs that determine which block to run
|
||||
@@ -129,8 +129,8 @@ class AutoImageBlocks(AutoPipelineBlocks):
|
||||
# - "image" triggers img2img workflow (but only if mask is not provided)
|
||||
# - if none of above, runs the text2img workflow (default)
|
||||
block_trigger_inputs = ["mask", "image", None]
|
||||
# Description is extremely important for AutoPipelineBlocks
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pipeline generates images given different types of conditions!\n"
|
||||
@@ -141,7 +141,7 @@ class AutoImageBlocks(AutoPipelineBlocks):
|
||||
)
|
||||
```
|
||||
|
||||
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, it's conditional logic may be difficult to figure out if it isn't properly explained.
|
||||
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained.
|
||||
|
||||
Create an instance of `AutoImageBlocks`.
|
||||
|
||||
@@ -152,5 +152,74 @@ auto_blocks = AutoImageBlocks()
|
||||
For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the a block that is actually run based on your input.
|
||||
|
||||
```py
|
||||
auto_blocks.get_execution_blocks("mask")
|
||||
auto_blocks.get_execution_blocks(mask=True)
|
||||
```
|
||||
|
||||
## ConditionalPipelineBlocks
|
||||
|
||||
[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method.
|
||||
|
||||
Here is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly:
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import ConditionalPipelineBlocks
|
||||
|
||||
class AutoImageBlocks(ConditionalPipelineBlocks):
|
||||
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask", "image"]
|
||||
default_block_name = "text2img"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Pipeline generates images given different types of conditions!\n"
|
||||
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
|
||||
+ " - inpaint workflow is run when `mask` is provided.\n"
|
||||
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
|
||||
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
|
||||
)
|
||||
|
||||
def select_block(self, mask=None, image=None) -> str | None:
|
||||
if mask is not None:
|
||||
return "inpaint"
|
||||
if image is not None:
|
||||
return "img2img"
|
||||
return None # falls back to default_block_name ("text2img")
|
||||
```
|
||||
|
||||
The inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, it falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped — this is useful for optional processing steps that should only run when specific inputs are provided.
|
||||
|
||||
## Workflows
|
||||
|
||||
Pipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks]`) can support multiple workflows — for example, our SDXL modular pipeline supports a dozen workflows all in one pipeline. But this also means it can be confusing for users to know what workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow.
|
||||
|
||||
We recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
|
||||
class MyPipelineBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]
|
||||
block_names = ["text_encoder", "auto_image", "decode"]
|
||||
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"image": True, "prompt": True},
|
||||
"inpaint": {"mask": True, "image": True, "prompt": True},
|
||||
}
|
||||
```
|
||||
|
||||
All of our built-in modular pipelines come with pre-defined workflows. The `available_workflows` property lists all supported workflows:
|
||||
|
||||
```py
|
||||
pipeline_blocks = MyPipelineBlocks()
|
||||
pipeline_blocks.available_workflows
|
||||
# ['text2image', 'image2image', 'inpaint']
|
||||
```
|
||||
|
||||
Retrieve a specific workflow with `get_workflow` to inspect and debug a specific block that executes the workflow.
|
||||
|
||||
```py
|
||||
pipeline_blocks.get_workflow("inpaint")
|
||||
```
|
||||
@@ -332,4 +332,49 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
|
||||
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
</hfoptions>
|
||||
|
||||
## Dependencies
|
||||
|
||||
Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and returns a warning if a package is missing or incompatible.
|
||||
|
||||
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import PipelineBlock
|
||||
|
||||
class MyCustomBlock(PipelineBlock):
|
||||
_requirements = {
|
||||
"transformers": ">=4.44.0",
|
||||
"sentencepiece": ">=0.2.0"
|
||||
}
|
||||
```
|
||||
|
||||
When there are blocks with different requirements, Diffusers merges their requirements.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
|
||||
class BlockA(PipelineBlock):
|
||||
_requirements = {"transformers": ">=4.44.0"}
|
||||
# ...
|
||||
|
||||
class BlockB(PipelineBlock):
|
||||
_requirements = {"sentencepiece": ">=0.2.0"}
|
||||
# ...
|
||||
|
||||
pipe = SequentialPipelineBlocks.from_blocks_dict({
|
||||
"block_a": BlockA,
|
||||
"block_b": BlockB,
|
||||
})
|
||||
```
|
||||
|
||||
When this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers returns the following warning.
|
||||
|
||||
```md
|
||||
# missing package
|
||||
xyz-package was specified in the requirements but wasn't found in the current environment.
|
||||
|
||||
# version mismatch
|
||||
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.
|
||||
```
|
||||
|
||||
@@ -89,10 +89,8 @@ t2i_pipeline.guider
|
||||
|
||||
## Changing guider parameters
|
||||
|
||||
The guider parameters can be adjusted with either the [`~ComponentSpec.create`] method or with [`~ModularPipeline.update_components`]. The example below changes the `guidance_scale` value.
|
||||
The guider parameters can be adjusted with the [`~ComponentSpec.create`] method and [`~ModularPipeline.update_components`]. The example below changes the `guidance_scale` value.
|
||||
|
||||
<hfoptions id="switch">
|
||||
<hfoption id="create">
|
||||
|
||||
```py
|
||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||
@@ -100,18 +98,6 @@ guider = guider_spec.create(guidance_scale=10)
|
||||
t2i_pipeline.update_components(guider=guider)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="update_components">
|
||||
|
||||
```py
|
||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||
guider_spec.config["guidance_scale"] = 10
|
||||
t2i_pipeline.update_components(guider=guider_spec)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Uploading custom guiders
|
||||
|
||||
Call the [`~utils.PushToHubMixin.push_to_hub`] method on a custom guider to share it to the Hub.
|
||||
|
||||
@@ -25,9 +25,7 @@ This guide explains how states work and how they connect blocks.
|
||||
|
||||
The [`~modular_pipelines.PipelineState`] is a global state container for all blocks. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
|
||||
|
||||
There are two dict's in [`~modular_pipelines.PipelineState`] for structuring data.
|
||||
|
||||
- The `values` dict is a **mutable** state containing a copy of user provided input values and intermediate output values generated by blocks. If a block modifies an `input`, it will be reflected in the `values` dict after calling `set_block_state`.
|
||||
[`~modular_pipelines.PipelineState`] stores all data in a `values` dict, which is a **mutable** state containing user provided input values and intermediate output values generated by blocks. If a block modifies an `input`, it will be reflected in the `values` dict after calling `set_block_state`.
|
||||
|
||||
```py
|
||||
PipelineState(
|
||||
|
||||
@@ -12,27 +12,28 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# ModularPipeline
|
||||
|
||||
[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`]'s into an executable pipeline that loads models and performs the computation steps defined in the block. It is the main interface for running a pipeline and it is very similar to the [`DiffusionPipeline`] API.
|
||||
[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`] into an executable pipeline that loads models and performs the computation steps defined in the blocks. It is the main interface for running a pipeline and the API is very similar to [`DiffusionPipeline`] but with a few key differences.
|
||||
|
||||
The main difference is to include an expected `output` argument in the pipeline.
|
||||
- **Loading is lazy.** With [`DiffusionPipeline`], [`~DiffusionPipeline.from_pretrained`] creates the pipeline and loads all models at the same time. With [`ModularPipeline`], creating and loading are two separate steps: [`~ModularPipeline.from_pretrained`] reads the configuration and knows where to load each component from, but doesn't actually load the model weights. You load the models later with [`~ModularPipeline.load_components`], which is where you pass loading arguments like `torch_dtype` and `quantization_config`.
|
||||
|
||||
- **Two ways to create a pipeline.** You can use [`~ModularPipeline.from_pretrained`] with an existing diffusers model repository — it automatically maps to the default pipeline blocks and then converts to a [`ModularPipeline`] with no extra setup. You can check the [modular_pipelines_directory](https://github.com/huggingface/diffusers/tree/main/src/diffusers/modular_pipelines) to see which models are currently supported. You can also assemble your own pipeline from [`ModularPipelineBlocks`] and convert it with the [`~ModularPipelineBlocks.init_pipeline`] method (see [Creating a pipeline](#creating-a-pipeline) for more details).
|
||||
|
||||
- **Running the pipeline is the same.** Once loaded, you call the pipeline with the same arguments you're used to. A single [`ModularPipeline`] can support multiple workflows (text-to-image, image-to-image, inpainting, etc.) when the pipeline blocks use [`AutoPipelineBlocks`](./auto_pipeline_blocks) to automatically select the workflow based on your inputs.
|
||||
|
||||
Below are complete examples for text-to-image, image-to-image, and inpainting with SDXL.
|
||||
|
||||
<hfoptions id="example">
|
||||
<hfoption id="text-to-image">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
|
||||
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
from diffusers import ModularPipeline
|
||||
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
|
||||
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
|
||||
image.save("modular_t2i_out.png")
|
||||
```
|
||||
|
||||
@@ -41,21 +42,17 @@ image.save("modular_t2i_out.png")
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
|
||||
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
from diffusers import ModularPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
|
||||
init_image = load_image(url)
|
||||
prompt = "a dog catching a frisbee in the jungle"
|
||||
image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0]
|
||||
image = pipeline(prompt=prompt, image=init_image, strength=0.8).images[0]
|
||||
image.save("modular_i2i_out.png")
|
||||
```
|
||||
|
||||
@@ -64,15 +61,10 @@ image.save("modular_i2i_out.png")
|
||||
|
||||
```py
|
||||
import torch
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
|
||||
from diffusers import ModularPipeline
|
||||
from diffusers.utils import load_image
|
||||
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
pipeline = blocks.init_pipeline(modular_repo_id)
|
||||
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
pipeline.to("cuda")
|
||||
|
||||
@@ -83,276 +75,353 @@ init_image = load_image(img_url)
|
||||
mask_image = load_image(mask_url)
|
||||
|
||||
prompt = "A deep sea diver floating"
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0]
|
||||
image.save("moduar_inpaint_out.png")
|
||||
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85).images[0]
|
||||
image.save("modular_inpaint_out.png")
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
This guide will show you how to create a [`ModularPipeline`] and manage the components in it.
|
||||
|
||||
## Adding blocks
|
||||
|
||||
Blocks are [`InsertableDict`] objects that can be inserted at specific positions, providing a flexible way to mix-and-match blocks.
|
||||
|
||||
Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`] on either the block class or `sub_blocks` attribute to add a block.
|
||||
|
||||
```py
|
||||
# BLOCKS is dict of block classes, you need to add class to it
|
||||
BLOCKS.insert("block_name", BlockClass, index)
|
||||
# sub_blocks attribute contains instance, add a block instance to the attribute
|
||||
t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
|
||||
```
|
||||
|
||||
Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`] on either the block class or `sub_blocks` attribute to remove a block.
|
||||
|
||||
```py
|
||||
# remove a block class from preset
|
||||
BLOCKS.pop("text_encoder")
|
||||
# split out a block instance on its own
|
||||
text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
|
||||
```
|
||||
|
||||
Swap blocks by setting the existing block to the new block.
|
||||
|
||||
```py
|
||||
# Replace block class in preset
|
||||
BLOCKS["prepare_latents"] = CustomPrepareLatents
|
||||
# Replace in sub_blocks attribute using an block instance
|
||||
t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
|
||||
```
|
||||
This guide will show you how to create a [`ModularPipeline`], manage its components, and run the pipeline.
|
||||
|
||||
## Creating a pipeline
|
||||
|
||||
There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
|
||||
There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] with [`~ModularPipelineBlocks.init_pipeline`], or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
|
||||
|
||||
You should also initialize a [`ComponentsManager`] to handle device placement and memory and component management.
|
||||
You can also initialize a [`ComponentsManager`](./components_manager) to handle device placement and memory management. If you don't need automatic offloading, you can skip this and move the pipeline to your device manually with `pipeline.to("cuda")`.
|
||||
|
||||
> [!TIP]
|
||||
> Refer to the [ComponentsManager](./components_manager) doc for more details about how it can help manage components across different workflows.
|
||||
|
||||
<hfoptions id="create">
|
||||
<hfoption id="ModularPipelineBlocks">
|
||||
### init_pipeline
|
||||
|
||||
Use the [`~ModularPipelineBlocks.init_pipeline`] method to create a [`ModularPipeline`] from the component and configuration specifications. This method loads the *specifications* from a `modular_model_index.json` file, but it doesn't load the *models* yet.
|
||||
[`~ModularPipelineBlocks.init_pipeline`] converts any [`ModularPipelineBlocks`] into a [`ModularPipeline`].
|
||||
|
||||
Let's define a minimal block to see how it works:
|
||||
|
||||
```py
|
||||
from diffusers import ComponentsManager
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks
|
||||
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
|
||||
from transformers import CLIPTextModel
|
||||
from diffusers.modular_pipelines import (
|
||||
ComponentSpec,
|
||||
ModularPipelineBlocks,
|
||||
PipelineState,
|
||||
)
|
||||
|
||||
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
|
||||
class MyBlock(ModularPipelineBlocks):
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(
|
||||
name="text_encoder",
|
||||
type_hint=CLIPTextModel,
|
||||
pretrained_model_name_or_path="openai/clip-vit-large-patch14",
|
||||
),
|
||||
]
|
||||
|
||||
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
|
||||
components = ComponentsManager()
|
||||
t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
|
||||
def __call__(self, components, state: PipelineState) -> PipelineState:
|
||||
return components, state
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="from_pretrained">
|
||||
Call [`~ModularPipelineBlocks.init_pipeline`] to convert it into a pipeline. The `blocks` attribute on the pipeline is the blocks it was created from — it determines the expected inputs, outputs, and computation logic.
|
||||
|
||||
The [`~ModularPipeline.from_pretrained`] method creates a [`ModularPipeline`] from a modular repository on the Hub.
|
||||
```py
|
||||
block = MyBlock()
|
||||
pipe = block.init_pipeline()
|
||||
pipe.blocks
|
||||
```
|
||||
|
||||
```
|
||||
MyBlock {
|
||||
"_class_name": "MyBlock",
|
||||
"_diffusers_version": "0.37.0.dev0"
|
||||
}
|
||||
```
|
||||
|
||||
> [!WARNING]
|
||||
> Blocks are mutable — you can freely add, remove, or swap blocks before creating a pipeline. However, once a pipeline is created, modifying `pipeline.blocks` won't affect the pipeline because it returns a copy. If you want a different block structure, create a new pipeline after modifying the blocks.
|
||||
|
||||
When you call [`~ModularPipelineBlocks.init_pipeline`] without a repository, it uses the `pretrained_model_name_or_path` defined in the block's [`ComponentSpec`] to determine where to load each component from. Printing the pipeline shows the component loading configuration.
|
||||
|
||||
```py
|
||||
pipe
|
||||
ModularPipeline {
|
||||
"_blocks_class_name": "MyBlock",
|
||||
"_class_name": "ModularPipeline",
|
||||
"_diffusers_version": "0.37.0.dev0",
|
||||
"text_encoder": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"pretrained_model_name_or_path": "openai/clip-vit-large-patch14",
|
||||
"revision": null,
|
||||
"subfolder": "",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If you pass a repository to [`~ModularPipelineBlocks.init_pipeline`], it overrides the loading path by matching your block's components against the pipeline config in that repository (`model_index.json` or `modular_model_index.json`).
|
||||
|
||||
In the example below, the `pretrained_model_name_or_path` will be updated to `"stabilityai/stable-diffusion-xl-base-1.0"`.
|
||||
|
||||
```py
|
||||
pipe = block.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
pipe
|
||||
ModularPipeline {
|
||||
"_blocks_class_name": "MyBlock",
|
||||
"_class_name": "ModularPipeline",
|
||||
"_diffusers_version": "0.37.0.dev0",
|
||||
"text_encoder": [
|
||||
null,
|
||||
null,
|
||||
{
|
||||
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"revision": null,
|
||||
"subfolder": "text_encoder",
|
||||
"type_hint": [
|
||||
"transformers",
|
||||
"CLIPTextModel"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
If a component in your block doesn't exist in the repository, it remains `null` and is skipped during [`~ModularPipeline.load_components`].
|
||||
|
||||
### from_pretrained
|
||||
|
||||
[`~ModularPipeline.from_pretrained`] is a convenient way to create a [`ModularPipeline`] without defining blocks yourself.
|
||||
|
||||
It works with three types of repositories.
|
||||
|
||||
**A regular diffusers repository.** Pass any supported model repository and it automatically maps to the default pipeline blocks. Currently supported models include SDXL, Wan, Qwen, Z-Image, Flux, and Flux2.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
|
||||
components = ComponentsManager()
|
||||
pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
|
||||
pipeline = ModularPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0", components_manager=components
|
||||
)
|
||||
```
|
||||
|
||||
Add the `trust_remote_code` argument to load a custom [`ModularPipeline`].
|
||||
**A modular repository.** These repositories contain a `modular_model_index.json` that specifies where to load each component from — the components can come from different repositories and the modular repository itself may not contain any model weights. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from one repository and the remaining components from another. See [Modular repository](#modular-repository) for more details on the format.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
|
||||
components = ComponentsManager()
|
||||
modular_repo_id = "YiYiXu/modular-diffdiff-0704"
|
||||
diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
|
||||
pipeline = ModularPipeline.from_pretrained(
|
||||
"diffusers/flux2-bnb-4bit-modular", components_manager=components
|
||||
)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
**A modular repository with custom code.** Some repositories include custom pipeline blocks alongside the loading configuration. Add `trust_remote_code=True` to load them. See [Custom blocks](./custom_blocks) for how to create your own.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
|
||||
components = ComponentsManager()
|
||||
pipeline = ModularPipeline.from_pretrained(
|
||||
"diffusers/Florence2-image-Annotator", trust_remote_code=True, components_manager=components
|
||||
)
|
||||
```
|
||||
|
||||
## Loading components
|
||||
|
||||
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load all components with [`~ModularPipeline.load_components`] or only load specific components with [`~ModularPipeline.load_components`].
|
||||
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load components with [`~ModularPipeline.load_components`].
|
||||
|
||||
<hfoptions id="load">
|
||||
<hfoption id="load_components">
|
||||
This will load all the components that have a valid loading spec.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
t2i_pipeline.load_components(torch_dtype=torch.float16)
|
||||
t2i_pipeline.to("cuda")
|
||||
pipeline.load_components(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="load_components">
|
||||
|
||||
The example below only loads the UNet and VAE.
|
||||
You can also load specific components by name. The example below only loads the `text_encoder`.
|
||||
|
||||
```py
|
||||
import torch
|
||||
|
||||
t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
|
||||
pipeline.load_components(names=["text_encoder"], torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Print the pipeline to inspect the loaded pretrained components.
|
||||
After loading, printing the pipeline shows which components are loaded — the first two fields change from `null` to the component's library and class.
|
||||
|
||||
```py
|
||||
t2i_pipeline
|
||||
pipeline
|
||||
```
|
||||
|
||||
This should match the `modular_model_index.json` file from the modular repository a pipeline is initialized from. If a pipeline doesn't need a component, it won't be included even if it exists in the modular repository.
|
||||
|
||||
To modify where components are loaded from, edit the `modular_model_index.json` file in the repository and change it to your desired loading path. The example below loads a UNet from a different repository.
|
||||
|
||||
```json
|
||||
# original
|
||||
"unet": [
|
||||
null, null,
|
||||
{
|
||||
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
|
||||
"subfolder": "unet",
|
||||
"variant": "fp16"
|
||||
}
|
||||
```
|
||||
# text_encoder is loaded - shows library and class
|
||||
"text_encoder": [
|
||||
"transformers",
|
||||
"CLIPTextModel",
|
||||
{ ... }
|
||||
]
|
||||
|
||||
# modified
|
||||
# unet is not loaded yet - still null
|
||||
"unet": [
|
||||
null, null,
|
||||
{
|
||||
"repo": "RunDiffusion/Juggernaut-XL-v9",
|
||||
"subfolder": "unet",
|
||||
"variant": "fp16"
|
||||
}
|
||||
null,
|
||||
null,
|
||||
{ ... }
|
||||
]
|
||||
```
|
||||
|
||||
### Component loading status
|
||||
|
||||
The pipeline properties below provide more information about which components are loaded.
|
||||
|
||||
Use `component_names` to return all expected components.
|
||||
Loading keyword arguments like `torch_dtype`, `variant`, `revision`, and `quantization_config` are passed through to `from_pretrained()` for each component. You can pass a single value to apply to all components, or a dict to set per-component values.
|
||||
|
||||
```py
|
||||
t2i_pipeline.component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
|
||||
# apply bfloat16 to all components
|
||||
pipeline.load_components(torch_dtype=torch.bfloat16)
|
||||
|
||||
# different dtypes per component
|
||||
pipeline.load_components(torch_dtype={"transformer": torch.bfloat16, "default": torch.float32})
|
||||
```
|
||||
|
||||
Use `null_component_names` to return components that aren't loaded yet. Load these components with [`~ModularPipeline.from_pretrained`].
|
||||
|
||||
```py
|
||||
t2i_pipeline.null_component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
|
||||
```
|
||||
|
||||
Use `pretrained_component_names` to return components that will be loaded from pretrained models.
|
||||
|
||||
```py
|
||||
t2i_pipeline.pretrained_component_names
|
||||
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
|
||||
```
|
||||
|
||||
Use `config_component_names` to return components that are created with the default config (not loaded from a modular repository). Components from a config aren't included because they are already initialized during pipeline creation. This is why they aren't listed in `null_component_names`.
|
||||
|
||||
```py
|
||||
t2i_pipeline.config_component_names
|
||||
['guider', 'image_processor']
|
||||
```
|
||||
[`~ModularPipeline.load_components`] only loads components that haven't been loaded yet and have a valid loading spec. This means if you've already set a component on the pipeline, calling [`~ModularPipeline.load_components`] again won't reload it.
|
||||
|
||||
## Updating components
|
||||
|
||||
Components may be updated depending on whether it is a *pretrained component* or a *config component*.
|
||||
[`~ModularPipeline.update_components`] replaces a component on the pipeline with a new one. When a component is updated, the loading specifications are also updated in the pipeline config and [`~ModularPipeline.load_components`] will skip it on subsequent calls.
|
||||
|
||||
> [!WARNING]
|
||||
> A component may change from pretrained to config when updating a component. The component type is initially defined in a block's `expected_components` field.
|
||||
### From AutoModel
|
||||
|
||||
A pretrained component is updated with [`ComponentSpec`] whereas a config component is updated by eihter passing the object directly or with [`ComponentSpec`].
|
||||
|
||||
The [`ComponentSpec`] shows `default_creation_method="from_pretrained"` for a pretrained component shows `default_creation_method="from_config` for a config component.
|
||||
|
||||
To update a pretrained component, create a [`ComponentSpec`] with the name of the component and where to load it from. Use the [`~ComponentSpec.load`] method to load the component.
|
||||
You can pass a model object loaded with `AutoModel.from_pretrained()`. Models loaded this way are automatically tagged with their loading information.
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, UNet2DConditionModel
|
||||
from diffusers import AutoModel
|
||||
|
||||
unet_spec = ComponentSpec(name="unet",type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
|
||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
The [`~ModularPipeline.update_components`] method replaces the component with a new one.
|
||||
|
||||
```py
|
||||
t2i_pipeline.update_components(unet=unet2)
|
||||
```
|
||||
|
||||
When a component is updated, the loading specifications are also updated in the pipeline config.
|
||||
|
||||
### Component extraction and modification
|
||||
|
||||
When you use [`~ComponentSpec.load`], the new component maintains its loading specifications. This makes it possible to extract the specification and recreate the component.
|
||||
|
||||
```py
|
||||
spec = ComponentSpec.from_component("unet", unet2)
|
||||
spec
|
||||
ComponentSpec(name='unet', type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
|
||||
unet2_recreated = spec.load(torch_dtype=torch.float16)
|
||||
```
|
||||
|
||||
The [`~ModularPipeline.get_component_spec`] method gets a copy of the current component specification to modify or update.
|
||||
|
||||
```py
|
||||
unet_spec = t2i_pipeline.get_component_spec("unet")
|
||||
unet_spec
|
||||
ComponentSpec(
|
||||
name='unet',
|
||||
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
|
||||
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
|
||||
subfolder='unet',
|
||||
variant='fp16',
|
||||
default_creation_method='from_pretrained'
|
||||
unet = AutoModel.from_pretrained(
|
||||
"RunDiffusion/Juggernaut-XL-v9", subfolder="unet", variant="fp16", torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.update_components(unet=unet)
|
||||
```
|
||||
|
||||
### From ComponentSpec
|
||||
|
||||
Use [`~ModularPipeline.get_component_spec`] to get a copy of the current component specification, modify it, and load a new component.
|
||||
|
||||
```py
|
||||
unet_spec = pipeline.get_component_spec("unet")
|
||||
|
||||
# modify to load from a different repository
|
||||
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
|
||||
unet_spec.pretrained_model_name_or_path = "RunDiffusion/Juggernaut-XL-v9"
|
||||
|
||||
# load component with modified spec
|
||||
# load and update
|
||||
unet = unet_spec.load(torch_dtype=torch.float16)
|
||||
pipeline.update_components(unet=unet)
|
||||
```
|
||||
|
||||
You can also create a [`ComponentSpec`] from scratch.
|
||||
|
||||
Not all components are loaded from pretrained weights — some are created from a config (listed under `pipeline.config_component_names`). For these, use [`~ComponentSpec.create`] instead of [`~ComponentSpec.load`].
|
||||
|
||||
```py
|
||||
guider_spec = pipeline.get_component_spec("guider")
|
||||
guider_spec.config = {"guidance_scale": 5.0}
|
||||
guider = guider_spec.create()
|
||||
pipeline.update_components(guider=guider)
|
||||
```
|
||||
|
||||
Or simply pass the object directly.
|
||||
|
||||
```py
|
||||
from diffusers.guiders import ClassifierFreeGuidance
|
||||
|
||||
guider = ClassifierFreeGuidance(guidance_scale=5.0)
|
||||
pipeline.update_components(guider=guider)
|
||||
```
|
||||
|
||||
See the [Guiders](./guiders) guide for more details on available guiders and how to configure them.
|
||||
|
||||
## Splitting a pipeline into stages
|
||||
|
||||
Since blocks are composable, you can take a pipeline apart and reconstruct it into separate pipelines for each stage. The example below shows how we can separate the text encoder block from the rest of the pipeline, so you can encode the prompt independently and pass the embeddings to the main pipeline.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline, ComponentsManager
|
||||
import torch
|
||||
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
repo_id = "black-forest-labs/FLUX.2-klein-4B"
|
||||
|
||||
# get the blocks and separate out the text encoder
|
||||
blocks = ModularPipeline.from_pretrained(repo_id).blocks
|
||||
text_block = blocks.sub_blocks.pop("text_encoder")
|
||||
|
||||
# use ComponentsManager to handle offloading across multiple pipelines
|
||||
manager = ComponentsManager()
|
||||
manager.enable_auto_cpu_offload(device=device)
|
||||
|
||||
# create separate pipelines for each stage
|
||||
text_encoder_pipeline = text_block.init_pipeline(repo_id, components_manager=manager)
|
||||
pipeline = blocks.init_pipeline(repo_id, components_manager=manager)
|
||||
|
||||
# encode text
|
||||
text_encoder_pipeline.load_components(torch_dtype=dtype)
|
||||
text_embeddings = text_encoder_pipeline(prompt="a cat").get_by_kwargs("denoiser_input_fields")
|
||||
|
||||
# denoise and decode
|
||||
pipeline.load_components(torch_dtype=dtype)
|
||||
output = pipeline(
|
||||
**text_embeddings,
|
||||
num_inference_steps=4,
|
||||
).images[0]
|
||||
```
|
||||
|
||||
[`ComponentsManager`] handles memory across multiple pipelines. Unlike the offloading strategies in [`DiffusionPipeline`] that follow a fixed order, [`ComponentsManager`] makes offloading decisions dynamically each time a model forward pass runs, based on the current memory situation. This means it works regardless of how many pipelines you create or what order you run them in. See the [ComponentsManager](./components_manager) guide for more details.
|
||||
|
||||
If pipeline stages share components (e.g., the same VAE used for encoding and decoding), you can use [`~ModularPipeline.update_components`] to pass an already-loaded component to another pipeline instead of loading it again.
|
||||
|
||||
## Modular repository
|
||||
|
||||
A repository is required if the pipeline blocks use *pretrained components*. The repository supplies loading specifications and metadata.
|
||||
|
||||
[`ModularPipeline`] specifically requires *modular repositories* (see [example repository](https://huggingface.co/YiYiXu/modular-diffdiff)) which are more flexible than a typical repository. It contains a `modular_model_index.json` file containing the following 3 elements.
|
||||
[`ModularPipeline`] works with regular diffusers repositories out of the box. However, you can also create a *modular repository* for more flexibility. A modular repository contains a `modular_model_index.json` file containing the following 3 elements.
|
||||
|
||||
- `library` and `class` shows which library the component was loaded from and it's class. If `null`, the component hasn't been loaded yet.
|
||||
- `library` and `class` shows which library the component was loaded from and its class. If `null`, the component hasn't been loaded yet.
|
||||
- `loading_specs_dict` contains the information required to load the component such as the repository and subfolder it is loaded from.
|
||||
|
||||
Unlike standard repositories, a modular repository can fetch components from different repositories based on the `loading_specs_dict`. Components don't need to exist in the same repository.
|
||||
The key advantage of a modular repository is that components can be loaded from different repositories. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from `diffusers/FLUX.2-dev-bnb-4bit` while loading the remaining components from `black-forest-labs/FLUX.2-dev`.
|
||||
|
||||
A modular repository may contain custom code for loading a [`ModularPipeline`]. This allows you to use specialized blocks that aren't native to Diffusers.
|
||||
To convert a regular diffusers repository into a modular one, create the pipeline using the regular repository, and then push to the Hub. The saved repository will contain a `modular_model_index.json` with all the loading specifications.
|
||||
|
||||
```py
|
||||
from diffusers import ModularPipeline
|
||||
|
||||
# load from a regular repo
|
||||
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
|
||||
|
||||
# push as a modular repository
|
||||
pipeline.save_pretrained("local/path", repo_id="my-username/sdxl-modular", push_to_hub=True)
|
||||
```
|
||||
|
||||
A modular repository can also include custom pipeline blocks as Python code. This allows you to share specialized blocks that aren't native to Diffusers. For example, [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator) contains custom blocks alongside the loading configuration:
|
||||
|
||||
```
|
||||
modular-diffdiff-0704/
|
||||
Florence2-image-Annotator/
|
||||
├── block.py # Custom pipeline blocks implementation
|
||||
├── config.json # Pipeline configuration and auto_map
|
||||
├── mellon_config.json # UI configuration for Mellon
|
||||
└── modular_model_index.json # Component loading specifications
|
||||
```
|
||||
|
||||
The [config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file contains an `auto_map` key that points to where a custom block is defined in `block.py`.
|
||||
The `config.json` file contains an `auto_map` key that tells [`ModularPipeline`] where to find the custom blocks:
|
||||
|
||||
```json
|
||||
{
|
||||
"_class_name": "DiffDiffBlocks",
|
||||
"_class_name": "Florence2AnnotatorBlocks",
|
||||
"auto_map": {
|
||||
"ModularPipelineBlocks": "block.DiffDiffBlocks"
|
||||
"ModularPipelineBlocks": "block.Florence2AnnotatorBlocks"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Load custom code repositories with `trust_remote_code=True` as shown in [from_pretrained](#from_pretrained). See [Custom blocks](./custom_blocks) for how to create and share your own.
|
||||
@@ -25,56 +25,42 @@ This guide will show you how to create a [`~modular_pipelines.ModularPipelineBlo
|
||||
|
||||
A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs`, and `intermediate_outputs`.
|
||||
|
||||
- `inputs` are values provided by a user and retrieved from the [`~modular_pipelines.PipelineState`]. This is useful because some workflows resize an image, but the original image is still required. The [`~modular_pipelines.PipelineState`] maintains the original image.
|
||||
- `inputs` are values a block reads from the [`~modular_pipelines.PipelineState`] to perform its computation. These can be values provided by a user (like a prompt or image) or values produced by a previous block (like encoded `image_latents`).
|
||||
|
||||
Use `InputParam` to define `inputs`.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import InputParam
|
||||
```py
|
||||
class ImageEncodeStep(ModularPipelineBlocks):
|
||||
...
|
||||
|
||||
user_inputs = [
|
||||
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
|
||||
]
|
||||
```
|
||||
@property
|
||||
def inputs(self):
|
||||
return [
|
||||
InputParam(name="image", type_hint="PIL.Image", required=True, description="raw input image to process"),
|
||||
]
|
||||
...
|
||||
```
|
||||
|
||||
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
|
||||
|
||||
Use `OutputParam` to define `intermediate_outputs`.
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import OutputParam
|
||||
```py
|
||||
class ImageEncodeStep(ModularPipelineBlocks):
|
||||
...
|
||||
|
||||
user_intermediate_outputs = [
|
||||
OutputParam(name="image_latents", description="latents representing the image")
|
||||
]
|
||||
```
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return [
|
||||
OutputParam(name="image_latents", description="latents representing the image"),
|
||||
]
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
The intermediate inputs and outputs share data to connect blocks. They are accessible at any point, allowing you to track the workflow's progress.
|
||||
|
||||
## Computation logic
|
||||
|
||||
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
|
||||
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
|
||||
2. Implement the computation logic on the `inputs`.
|
||||
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
|
||||
4. Return the components and state which becomes available to the next block.
|
||||
|
||||
```py
|
||||
def __call__(self, components, state):
|
||||
# Get a local view of the state variables this block needs
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Your computation logic here
|
||||
# block_state contains all your inputs
|
||||
# Access them like: block_state.image, block_state.processed_image
|
||||
|
||||
# Update the pipeline state with your updated block_states
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
```
|
||||
|
||||
### Components and configs
|
||||
## Components and configs
|
||||
|
||||
The components and pipeline-level configs a block needs are specified in [`ComponentSpec`] and [`~modular_pipelines.ConfigSpec`].
|
||||
|
||||
@@ -82,24 +68,108 @@ The components and pipeline-level configs a block needs are specified in [`Compo
|
||||
- [`~modular_pipelines.ConfigSpec`] contains pipeline-level settings that control behavior across all blocks.
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, ConfigSpec
|
||||
class ImageEncodeStep(ModularPipelineBlocks):
|
||||
...
|
||||
|
||||
expected_components = [
|
||||
ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
|
||||
ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
|
||||
]
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(name="vae", type_hint=AutoencoderKL),
|
||||
]
|
||||
|
||||
expected_config = [
|
||||
ConfigSpec("force_zeros_for_empty_prompt", True)
|
||||
]
|
||||
@property
|
||||
def expected_configs(self):
|
||||
return [
|
||||
ConfigSpec("force_zeros_for_empty_prompt", True),
|
||||
]
|
||||
|
||||
...
|
||||
```
|
||||
|
||||
When the blocks are converted into a pipeline, the components become available to the block as the first argument in `__call__`.
|
||||
|
||||
## Computation logic
|
||||
|
||||
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
|
||||
|
||||
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`.
|
||||
2. Implement the computation logic on the `inputs`.
|
||||
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
|
||||
4. Return the components and state which becomes available to the next block.
|
||||
|
||||
```py
|
||||
def __call__(self, components, state):
|
||||
# Access components using dot notation
|
||||
unet = components.unet
|
||||
vae = components.vae
|
||||
scheduler = components.scheduler
|
||||
class ImageEncodeStep(ModularPipelineBlocks):
|
||||
|
||||
def __call__(self, components, state):
|
||||
# Get a local view of the state variables this block needs
|
||||
block_state = self.get_block_state(state)
|
||||
|
||||
# Your computation logic here
|
||||
# block_state contains all your inputs
|
||||
# Access them like: block_state.image, block_state.processed_image
|
||||
|
||||
# Update the pipeline state with your updated block_states
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
```
|
||||
|
||||
## Putting it all together
|
||||
|
||||
Here is the complete block with all the pieces connected.
|
||||
|
||||
```py
|
||||
from diffusers import ComponentSpec, AutoencoderKL
|
||||
from diffusers.modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam
|
||||
|
||||
|
||||
class ImageEncodeStep(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Encode an image into latent space."
|
||||
|
||||
@property
|
||||
def expected_components(self):
|
||||
return [
|
||||
ComponentSpec(name="vae", type_hint=AutoencoderKL),
|
||||
]
|
||||
|
||||
@property
|
||||
def inputs(self):
|
||||
return [
|
||||
InputParam(name="image", type_hint="PIL.Image", required=True, description="raw input image to process"),
|
||||
]
|
||||
|
||||
@property
|
||||
def intermediate_outputs(self):
|
||||
return [
|
||||
OutputParam(name="image_latents", type_hint="torch.Tensor", description="latents representing the image"),
|
||||
]
|
||||
|
||||
def __call__(self, components, state):
|
||||
block_state = self.get_block_state(state)
|
||||
block_state.image_latents = components.vae.encode(block_state.image)
|
||||
self.set_block_state(state, block_state)
|
||||
return components, state
|
||||
```
|
||||
|
||||
Every block has a `doc` property that is automatically generated from the properties you defined above. It provides a summary of the block's description, components, inputs, and outputs.
|
||||
|
||||
```py
|
||||
block = ImageEncoderStep()
|
||||
print(block.doc)
|
||||
class ImageEncodeStep
|
||||
|
||||
Encode an image into latent space.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKL`)
|
||||
|
||||
Inputs:
|
||||
image (`PIL.Image`):
|
||||
raw input image to process
|
||||
|
||||
Outputs:
|
||||
image_latents (`torch.Tensor`):
|
||||
latents representing the image
|
||||
```
|
||||
@@ -39,17 +39,44 @@ image
|
||||
[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration to learn where to load each component from, but doesn't actually load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded.
|
||||
|
||||
> [!TIP]
|
||||
> [`ComponentsManager`] with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide.
|
||||
> `ComponentsManager` with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide.
|
||||
>
|
||||
> If you don't need offloading, remove the `components_manager` argument and move the pipeline to your device manually with `to("cuda")`.
|
||||
|
||||
Learn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides.
|
||||
|
||||
## Understand the structure
|
||||
|
||||
A [`ModularPipeline`] has two parts:
|
||||
- **State**: the loaded components (models, schedulers, processors) and configuration
|
||||
- **Definition**: the [`ModularPipelineBlocks`] that specify inputs, outputs, expected components and computation logic
|
||||
A [`ModularPipeline`] has two parts: a **definition** (the blocks) and a **state** (the loaded components and configs).
|
||||
|
||||
The blocks define *what* the pipeline does. Access them through `pipe.blocks`.
|
||||
Print the pipeline to see its state — the components and their loading status and configuration.
|
||||
```py
|
||||
print(pipe)
|
||||
```
|
||||
```
|
||||
QwenImageModularPipeline {
|
||||
"_blocks_class_name": "QwenImageAutoBlocks",
|
||||
"_class_name": "QwenImageModularPipeline",
|
||||
"_diffusers_version": "0.37.0.dev0",
|
||||
"transformer": [
|
||||
"diffusers",
|
||||
"QwenImageTransformer2DModel",
|
||||
{
|
||||
"pretrained_model_name_or_path": "Qwen/Qwen-Image",
|
||||
"revision": null,
|
||||
"subfolder": "transformer",
|
||||
"type_hint": [
|
||||
"diffusers",
|
||||
"QwenImageTransformer2DModel"
|
||||
],
|
||||
"variant": null
|
||||
}
|
||||
],
|
||||
...
|
||||
}
|
||||
```
|
||||
|
||||
Access the definition through `pipe.blocks` — this is the [`~modular_pipelines.ModularPipelineBlocks`] that defines the pipeline's workflows, inputs, outputs, and computation logic.
|
||||
```py
|
||||
print(pipe.blocks)
|
||||
```
|
||||
@@ -87,7 +114,8 @@ The output returns:
|
||||
|
||||
### Workflows
|
||||
|
||||
`QwenImageAutoBlocks` is a [`ConditionalPipelineBlocks`], so this pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Let's see this in action with an example.
|
||||
This pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Learn more about how this works under the hood in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.
|
||||
|
||||
```py
|
||||
from diffusers.utils import load_image
|
||||
|
||||
@@ -99,20 +127,21 @@ image = pipe(
|
||||
).images[0]
|
||||
```
|
||||
|
||||
Use `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `"image2image"`, `"inpainting"`, `"controlnet_text2image"`) to get only the blocks relevant to that workflow.
|
||||
Use `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `"image2image"`, `"inpainting"`, `"controlnet_text2image"`) to get only the blocks relevant to that workflow. This is useful when you want to customize or debug a specific workflow. You can check `pipe.blocks.available_workflows` to see all available workflows.
|
||||
```py
|
||||
img2img_blocks = pipe.blocks.get_workflow("image2image")
|
||||
```
|
||||
|
||||
Conditional blocks are convenient for users, but their conditional logic adds complexity when customizing or debugging. Extracting a workflow gives you the specific blocks relevant to your workflow, making it easier to work with. Learn more in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.
|
||||
|
||||
### Sub-blocks
|
||||
|
||||
Blocks can contain other blocks. `pipe.blocks` gives you the top-level block definition (here, `QwenImageAutoBlocks`), while `sub_blocks` lets you access the smaller blocks inside it.
|
||||
|
||||
`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`. Access them through the `sub_blocks` property.
|
||||
`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`.
|
||||
|
||||
The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.
|
||||
These sub-blocks run one after another and data flows linearly from one block to the next — each block's `intermediate_outputs` become available as `inputs` to the next block. This is how [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) work.
|
||||
|
||||
You can access them through the `sub_blocks` property. The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.
|
||||
```py
|
||||
vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"]
|
||||
print(vae_encoder_block.doc)
|
||||
@@ -165,7 +194,7 @@ class CannyBlock
|
||||
Canny map for input image
|
||||
```
|
||||
|
||||
UUse `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
|
||||
Use `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
|
||||
```py
|
||||
# Get the controlnet workflow that we want to work with
|
||||
blocks = pipe.blocks.get_workflow("controlnet_text2image")
|
||||
@@ -182,9 +211,8 @@ class SequentialPipelineBlocks
|
||||
...
|
||||
```
|
||||
|
||||
The extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) - a multi-block type where blocks run one after another and data flows linearly from one block to the next. Each block's `intermediate_outputs` become available as `inputs` to subsequent blocks.
|
||||
|
||||
Currently this workflow requires `control_image` as input. Let's insert the canny block at the beginning so the pipeline accepts a regular image instead.
|
||||
The extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) and it currently requires `control_image` as input. Insert the canny block at the beginning so the pipeline accepts a regular image instead.
|
||||
```py
|
||||
# Insert canny at the beginning
|
||||
blocks.sub_blocks.insert("canny", canny_block, 0)
|
||||
@@ -211,7 +239,7 @@ class SequentialPipelineBlocks
|
||||
|
||||
Now the pipeline takes `image` as input instead of `control_image`. Because blocks in a sequence share data automatically, the canny block's output (`control_image`) flows to the denoise block that needs it, and the canny block's input (`image`) becomes a pipeline input since no earlier block provides it.
|
||||
|
||||
Create a pipeline from the modified blocks and load a ControlNet model.
|
||||
Create a pipeline from the modified blocks and load a ControlNet model. The ControlNet isn't part of the original model repository, so load it separately and add it with [`~ModularPipeline.update_components`].
|
||||
```py
|
||||
pipeline = blocks.init_pipeline("Qwen/Qwen-Image", components_manager=manager)
|
||||
|
||||
@@ -241,6 +269,16 @@ output
|
||||
## Next steps
|
||||
|
||||
<hfoptions id="next">
|
||||
<hfoption id="Learn the basics">
|
||||
|
||||
Understand the core building blocks of Modular Diffusers:
|
||||
|
||||
- [ModularPipelineBlocks](./pipeline_block): The basic unit for defining a step in a pipeline.
|
||||
- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to run in sequence.
|
||||
- [AutoPipelineBlocks](./auto_pipeline_blocks): Create pipelines that support multiple workflows.
|
||||
- [States](./modular_diffusers_states): How data is shared between blocks.
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Build custom blocks">
|
||||
|
||||
Learn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide.
|
||||
|
||||
@@ -91,23 +91,42 @@ class ImageEncoderBlock(ModularPipelineBlocks):
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Connect the two blocks by defining an [`InsertableDict`] to map the block names to the block instances. Blocks are executed in the order they're registered in `blocks_dict`.
|
||||
|
||||
Use [`~modular_pipelines.SequentialPipelineBlocks.from_blocks_dict`] to create a [`~modular_pipelines.SequentialPipelineBlocks`].
|
||||
Connect the two blocks by defining a [`~modular_pipelines.SequentialPipelineBlocks`]. List the block instances in `block_classes` and their corresponding names in `block_names`. The blocks are executed in the order they appear in `block_classes`, and data flows from one block to the next through [`~modular_pipelines.PipelineState`].
|
||||
|
||||
```py
|
||||
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
|
||||
class ImageProcessingStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
# auto_docstring
|
||||
"""
|
||||
model_name = "my_model"
|
||||
block_classes = [InputBlock(), ImageEncoderBlock()]
|
||||
block_names = ["input", "image_encoder"]
|
||||
|
||||
blocks_dict = InsertableDict()
|
||||
blocks_dict["input"] = input_block
|
||||
blocks_dict["image_encoder"] = image_encoder_block
|
||||
|
||||
blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Process text prompts and images for the pipeline. It:\n"
|
||||
" - Determines the batch size from the prompts.\n"
|
||||
" - Encodes the image into latent space."
|
||||
)
|
||||
```
|
||||
|
||||
Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by calling `blocks`, and for more details about the inputs and outputs, access the `docs` attribute.
|
||||
When you create a [`~modular_pipelines.SequentialPipelineBlocks`], properties like `inputs`, `intermediate_outputs`, and `expected_components` are automatically aggregated from the sub-blocks, so there is no need to define them again.
|
||||
|
||||
There are a few properties you should set:
|
||||
|
||||
- `description`: We recommend adding a description for the assembled block to explain what the combined step does.
|
||||
- `model_name`: This is automatically derived from the sub-blocks but isn't always correct, so you may need to override it.
|
||||
- `outputs`: By default this is the same as `intermediate_outputs`, but you can manually set it to control which values appear in the doc. This is useful for showing only the final outputs instead of all intermediate values.
|
||||
|
||||
These properties, together with the aggregated `inputs`, `intermediate_outputs`, and `expected_components`, are used to automatically generate the `doc` property.
|
||||
|
||||
|
||||
Print the `ImageProcessingStep` block to inspect its sub-blocks, and use `doc` for a full summary of the block's inputs, outputs, and components.
|
||||
|
||||
|
||||
```py
|
||||
blocks = ImageProcessingStep()
|
||||
print(blocks)
|
||||
print(blocks.doc)
|
||||
```
|
||||
```
|
||||
@@ -111,7 +111,7 @@ if __name__ == "__main__":
|
||||
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
|
||||
|
||||
```bash
|
||||
torchrun run_distributed.py --nproc_per_node=2
|
||||
torchrun --nproc_per_node=2 run_distributed.py
|
||||
```
|
||||
|
||||
## device_map
|
||||
|
||||
@@ -97,5 +97,32 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
|
||||
> )
|
||||
> ```
|
||||
|
||||
### Saving custom models
|
||||
|
||||
Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.
|
||||
|
||||
```py
|
||||
# my_model.py
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
|
||||
class MyCustomModel(ModelMixin, ConfigMixin):
|
||||
...
|
||||
|
||||
MyCustomModel.register_for_auto_class("AutoModel")
|
||||
|
||||
model = MyCustomModel(...)
|
||||
model.save_pretrained("./my_model")
|
||||
```
|
||||
|
||||
The saved `config.json` will include the `auto_map` field.
|
||||
|
||||
```json
|
||||
{
|
||||
"auto_map": {
|
||||
"AutoModel": "my_model.MyCustomModel"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.
|
||||
@@ -60,7 +60,7 @@ export_to_video(video.frames[0], "output.mp4", fps=8)
|
||||
<tr>
|
||||
<th style="text-align: center;">Face Image</th>
|
||||
<th style="text-align: center;">Video</th>
|
||||
<th style="text-align: center;">Description</th
|
||||
<th style="text-align: center;">Description</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_0.png?download=true" style="height: auto; width: 600px;"></td>
|
||||
|
||||
133
docs/source/en/using-diffusers/helios.md
Normal file
133
docs/source/en/using-diffusers/helios.md
Normal file
@@ -0,0 +1,133 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
# Helios
|
||||
|
||||
[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are:
|
||||
|
||||
- Without commonly used anti-drifting strategies (eg, self-forcing, error-banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence.
|
||||
- Without standard acceleration techniques (eg, KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU.
|
||||
- Introducing optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models.
|
||||
|
||||
This guide will walk you through using Helios for use cases.
|
||||
|
||||
## Load Model Checkpoints
|
||||
|
||||
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HeliosPipeline, HeliosPyramidPipeline
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# For Best Quality
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
|
||||
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Intermediate Weight
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# For Best Efficiency
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
## Text-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Image-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Image</th>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Interactive-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Resources
|
||||
|
||||
Learn more about Helios with the following resources.
|
||||
- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.
|
||||
- The research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) for more details.
|
||||
@@ -132,6 +132,8 @@
|
||||
sections:
|
||||
- local: using-diffusers/consisid
|
||||
title: ConsisID
|
||||
- local: using-diffusers/helios
|
||||
title: Helios
|
||||
|
||||
- title: Resources
|
||||
isExpanded: false
|
||||
|
||||
@@ -26,6 +26,14 @@ http://www.apache.org/licenses/LICENSE-2.0
|
||||
<th>项目名称</th>
|
||||
<th>描述</th>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/PKU-YuanGroup/Helios"> helios </a></td>
|
||||
<td>Helios:比1.3B更低开销、更快且更强的14B的实时长视频生成模型</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/PKU-YuanGroup/ConsisID"> consisid </a></td>
|
||||
<td>ConsisID:零样本身份保持的文本到视频生成模型</td>
|
||||
</tr>
|
||||
<tr style="border-top: 2px solid black">
|
||||
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
|
||||
<td>Stable Diffusion内置到Blender</td>
|
||||
|
||||
@@ -86,10 +86,7 @@ t2i_pipeline.guider
|
||||
|
||||
## 更改引导器参数
|
||||
|
||||
引导器参数可以通过 [`~ComponentSpec.create`] 方法或 [`~ModularPipeline.update_components`] 方法进行调整。下面的示例更改了 `guidance_scale` 值。
|
||||
|
||||
<hfoptions id="switch">
|
||||
<hfoption id="create">
|
||||
引导器参数可以通过 [`~ComponentSpec.create`] 方法以及 [`~ModularPipeline.update_components`] 方法进行调整。下面的示例更改了 `guidance_scale` 值。
|
||||
|
||||
```py
|
||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||
@@ -97,18 +94,6 @@ guider = guider_spec.create(guidance_scale=10)
|
||||
t2i_pipeline.update_components(guider=guider)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="update_components">
|
||||
|
||||
```py
|
||||
guider_spec = t2i_pipeline.get_component_spec("guider")
|
||||
guider_spec.config["guidance_scale"] = 10
|
||||
t2i_pipeline.update_components(guider=guider_spec)
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## 上传自定义引导器
|
||||
|
||||
在自定义引导器上调用 [`~utils.PushToHubMixin.push_to_hub`] 方法,将其分享到 Hub。
|
||||
|
||||
134
docs/source/zh/using-diffusers/helios.md
Normal file
134
docs/source/zh/using-diffusers/helios.md
Normal file
@@ -0,0 +1,134 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
# Helios
|
||||
|
||||
[Helios](https://github.com/PKU-YuanGroup/Helios) 是首个能够在单张 NVIDIA H100 GPU 上以 19.5 FPS 运行的 14B 视频生成模型。它在支持分钟级视频生成的同时,拥有媲美强大基线模型的生成质量,并在统一架构下原生集成了文生视频(T2V)、图生视频(I2V)和视频生视频(V2V)任务。Helios 的主要特性包括:
|
||||
|
||||
- 无需常用的防漂移策略(例如:自强制/self-forcing、误差库/error-banks、关键帧采样或逆采样),我们的模型即可生成高质量且高度连贯的分钟级视频。
|
||||
- 无需标准的加速技术(例如:KV 缓存、因果掩码、稀疏/线性注意力机制、TinyVAE、渐进式噪声调度、隐藏状态缓存或量化),作为一款 14B 规模的视频生成模型,我们在单张 H100 GPU 上的端到端推理速度便达到了 19.5 FPS。
|
||||
- 引入了多项优化方案,在降低显存消耗的同时,显著提升了训练与推理的吞吐量。这些改进使得我们无需借助并行或分片(sharding)等基础设施,即可使用与图像模型相当的批大小(batch sizes)来训练 14B 的视频生成模型。
|
||||
|
||||
本指南将引导您完成 Helios 在不同场景下的使用。
|
||||
|
||||
## Load Model Checkpoints
|
||||
|
||||
模型权重可以存储在Hub上或本地的单独子文件夹中,在这种情况下,您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import HeliosPipeline, HeliosPyramidPipeline
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
# For Best Quality
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
|
||||
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Intermediate Weight
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
|
||||
# For Best Efficiency
|
||||
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
|
||||
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
|
||||
pipe.to("cuda")
|
||||
```
|
||||
|
||||
## Text-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="4000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Image-to-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Image</th>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
|
||||
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
|
||||
</small></td>
|
||||
<td>
|
||||
<video width="2000" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Interactive-Video Showcases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<th style="text-align: center;">Prompt</th>
|
||||
<th style="text-align: center;">Generated Video</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
|
||||
<td>
|
||||
<video width="680" controls>
|
||||
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
|
||||
</video>
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Resources
|
||||
|
||||
通过以下资源了解有关 Helios 的更多信息:
|
||||
|
||||
- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能;
|
||||
- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。
|
||||
@@ -1232,22 +1232,49 @@ def main(args):
|
||||
id_token=args.id_token,
|
||||
)
|
||||
|
||||
def encode_video(video, bar):
|
||||
bar.update(1)
|
||||
def encode_video(video):
|
||||
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
|
||||
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
|
||||
latent_dist = vae.encode(video).latent_dist
|
||||
return latent_dist
|
||||
|
||||
# Distribute video encoding across processes: each process only encodes its own shard
|
||||
num_videos = len(train_dataset.instance_videos)
|
||||
num_procs = accelerator.num_processes
|
||||
local_rank = accelerator.process_index
|
||||
local_count = len(range(local_rank, num_videos, num_procs))
|
||||
|
||||
progress_encode_bar = tqdm(
|
||||
range(0, len(train_dataset.instance_videos)),
|
||||
desc="Loading Encode videos",
|
||||
range(local_count),
|
||||
desc="Encoding videos",
|
||||
disable=not accelerator.is_local_main_process,
|
||||
)
|
||||
train_dataset.instance_videos = [
|
||||
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
|
||||
]
|
||||
|
||||
encoded_videos = [None] * num_videos
|
||||
for i, video in enumerate(train_dataset.instance_videos):
|
||||
if i % num_procs == local_rank:
|
||||
encoded_videos[i] = encode_video(video)
|
||||
progress_encode_bar.update(1)
|
||||
progress_encode_bar.close()
|
||||
|
||||
# Broadcast encoded latent distributions so every process has the full set
|
||||
if num_procs > 1:
|
||||
import torch.distributed as dist
|
||||
|
||||
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
||||
|
||||
ref_params = next(v for v in encoded_videos if v is not None).parameters
|
||||
for i in range(num_videos):
|
||||
src = i % num_procs
|
||||
if encoded_videos[i] is not None:
|
||||
params = encoded_videos[i].parameters.contiguous()
|
||||
else:
|
||||
params = torch.empty_like(ref_params)
|
||||
dist.broadcast(params, src=src)
|
||||
encoded_videos[i] = DiagonalGaussianDistribution(params)
|
||||
|
||||
train_dataset.instance_videos = encoded_videos
|
||||
|
||||
def collate_fn(examples):
|
||||
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
|
||||
@@ -17,6 +17,9 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
|
||||
sys.path.append("..")
|
||||
@@ -30,6 +33,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
|
||||
logger.addHandler(stream_handler)
|
||||
|
||||
|
||||
@unittest.skipIf(is_transformers_version(">=", "4.57.5"), "Size mismatch")
|
||||
class CustomDiffusion(ExamplesTestsAccelerate):
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
|
||||
66
examples/research_projects/autoencoder_rae/README.md
Normal file
66
examples/research_projects/autoencoder_rae/README.md
Normal file
@@ -0,0 +1,66 @@
|
||||
# Training AutoencoderRAE
|
||||
|
||||
This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen.
|
||||
|
||||
It follows the same high-level training recipe as the official RAE stage-1 setup:
|
||||
- frozen encoder
|
||||
- train decoder
|
||||
- pixel reconstruction loss
|
||||
- optional encoder feature consistency loss
|
||||
|
||||
## Quickstart
|
||||
|
||||
### Resume or finetune from pretrained weights
|
||||
|
||||
```bash
|
||||
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
|
||||
--pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
|
||||
--train_data_dir /path/to/imagenet_like_folder \
|
||||
--output_dir /tmp/autoencoder-rae \
|
||||
--resolution 256 \
|
||||
--train_batch_size 8 \
|
||||
--learning_rate 1e-4 \
|
||||
--num_train_epochs 10 \
|
||||
--report_to wandb \
|
||||
--reconstruction_loss_type l1 \
|
||||
--use_encoder_loss \
|
||||
--encoder_loss_weight 0.1
|
||||
```
|
||||
|
||||
### Train from scratch with a pretrained encoder
|
||||
The following command launches RAE training with "facebook/dinov2-with-registers-base" as the base.
|
||||
|
||||
```bash
|
||||
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
|
||||
--train_data_dir /path/to/imagenet_like_folder \
|
||||
--output_dir /tmp/autoencoder-rae \
|
||||
--resolution 256 \
|
||||
--encoder_type dinov2 \
|
||||
--encoder_name_or_path facebook/dinov2-with-registers-base \
|
||||
--encoder_input_size 224 \
|
||||
--patch_size 16 \
|
||||
--image_size 256 \
|
||||
--decoder_hidden_size 1152 \
|
||||
--decoder_num_hidden_layers 28 \
|
||||
--decoder_num_attention_heads 16 \
|
||||
--decoder_intermediate_size 4096 \
|
||||
--train_batch_size 8 \
|
||||
--learning_rate 1e-4 \
|
||||
--num_train_epochs 10 \
|
||||
--report_to wandb \
|
||||
--reconstruction_loss_type l1 \
|
||||
--use_encoder_loss \
|
||||
--encoder_loss_weight 0.1
|
||||
```
|
||||
|
||||
Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`.
|
||||
|
||||
Dataset format is expected to be `ImageFolder`-compatible:
|
||||
|
||||
```text
|
||||
train_data_dir/
|
||||
class_a/
|
||||
img_0001.jpg
|
||||
class_b/
|
||||
img_0002.jpg
|
||||
```
|
||||
@@ -0,0 +1,405 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
from torchvision.datasets import ImageFolder
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import AutoencoderRAE
|
||||
from diffusers.optimization import get_scheduler
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.")
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to an ImageFolder-style dataset root.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model."
|
||||
)
|
||||
parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.")
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
|
||||
parser.add_argument("--resolution", type=int, default=256)
|
||||
parser.add_argument("--center_crop", action="store_true")
|
||||
parser.add_argument("--random_flip", action="store_true")
|
||||
|
||||
parser.add_argument("--train_batch_size", type=int, default=8)
|
||||
parser.add_argument("--dataloader_num_workers", type=int, default=4)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=10)
|
||||
parser.add_argument("--max_train_steps", type=int, default=None)
|
||||
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
|
||||
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
||||
|
||||
parser.add_argument("--learning_rate", type=float, default=1e-4)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9)
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999)
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-8)
|
||||
parser.add_argument("--lr_scheduler", type=str, default="cosine")
|
||||
parser.add_argument("--lr_warmup_steps", type=int, default=500)
|
||||
|
||||
parser.add_argument("--checkpointing_steps", type=int, default=1000)
|
||||
parser.add_argument("--validation_steps", type=int, default=500)
|
||||
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). "
|
||||
"When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path "
|
||||
"into a freshly constructed AutoencoderRAE. Ignored when --pretrained_model_name_or_path is set."
|
||||
),
|
||||
)
|
||||
|
||||
parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2")
|
||||
parser.add_argument("--encoder_hidden_size", type=int, default=768)
|
||||
parser.add_argument("--encoder_patch_size", type=int, default=14)
|
||||
parser.add_argument("--encoder_num_hidden_layers", type=int, default=12)
|
||||
parser.add_argument("--encoder_input_size", type=int, default=224)
|
||||
parser.add_argument("--patch_size", type=int, default=16)
|
||||
parser.add_argument("--image_size", type=int, default=256)
|
||||
parser.add_argument("--num_channels", type=int, default=3)
|
||||
|
||||
parser.add_argument("--decoder_hidden_size", type=int, default=1152)
|
||||
parser.add_argument("--decoder_num_hidden_layers", type=int, default=28)
|
||||
parser.add_argument("--decoder_num_attention_heads", type=int, default=16)
|
||||
parser.add_argument("--decoder_intermediate_size", type=int, default=4096)
|
||||
|
||||
parser.add_argument("--noise_tau", type=float, default=0.0)
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.0)
|
||||
parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True)
|
||||
|
||||
parser.add_argument(
|
||||
"--reconstruction_loss_type",
|
||||
type=str,
|
||||
choices=["l1", "mse"],
|
||||
default="l1",
|
||||
help="Pixel reconstruction loss.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--encoder_loss_weight",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="Weight for encoder feature consistency loss in the training loop.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_encoder_loss",
|
||||
action="store_true",
|
||||
help="Enable encoder feature consistency loss term in the training loop.",
|
||||
)
|
||||
parser.add_argument("--report_to", type=str, default="tensorboard")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def build_transforms(args):
|
||||
image_transforms = [
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
|
||||
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
|
||||
]
|
||||
if args.random_flip:
|
||||
image_transforms.append(transforms.RandomHorizontalFlip())
|
||||
image_transforms.append(transforms.ToTensor())
|
||||
return transforms.Compose(image_transforms)
|
||||
|
||||
|
||||
def compute_losses(
|
||||
model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
|
||||
):
|
||||
decoded = model(pixel_values).sample
|
||||
|
||||
if decoded.shape[-2:] != pixel_values.shape[-2:]:
|
||||
raise ValueError(
|
||||
"Training requires matching reconstruction and target sizes, got "
|
||||
f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}."
|
||||
)
|
||||
|
||||
if reconstruction_loss_type == "l1":
|
||||
reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())
|
||||
else:
|
||||
reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float())
|
||||
|
||||
encoder_loss = torch.zeros_like(reconstruction_loss)
|
||||
if use_encoder_loss and encoder_loss_weight > 0:
|
||||
base_model = model.module if hasattr(model, "module") else model
|
||||
target_encoder_input = base_model._resize_and_normalize(pixel_values)
|
||||
reconstructed_encoder_input = base_model._resize_and_normalize(decoded)
|
||||
|
||||
encoder_forward_kwargs = {"model": base_model.encoder}
|
||||
if base_model.config.encoder_type == "mae":
|
||||
encoder_forward_kwargs["patch_size"] = base_model.config.encoder_patch_size
|
||||
with torch.no_grad():
|
||||
target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs)
|
||||
reconstructed_tokens = base_model._encoder_forward_fn(
|
||||
images=reconstructed_encoder_input, **encoder_forward_kwargs
|
||||
)
|
||||
encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())
|
||||
|
||||
loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss
|
||||
return decoded, loss, reconstruction_loss, encoder_loss
|
||||
|
||||
|
||||
def _strip_final_layernorm_affine(state_dict, prefix=""):
|
||||
"""Remove final layernorm weight/bias so the model keeps its default init (identity)."""
|
||||
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
|
||||
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
|
||||
|
||||
|
||||
def _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path):
|
||||
"""Load pretrained HF transformers encoder weights into the model's encoder."""
|
||||
if encoder_type == "dinov2":
|
||||
from transformers import Dinov2WithRegistersModel
|
||||
|
||||
hf_encoder = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
|
||||
state_dict = hf_encoder.state_dict()
|
||||
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
|
||||
elif encoder_type == "siglip2":
|
||||
from transformers import SiglipModel
|
||||
|
||||
hf_encoder = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
|
||||
state_dict = {f"vision_model.{k}": v for k, v in hf_encoder.state_dict().items()}
|
||||
state_dict = _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.")
|
||||
elif encoder_type == "mae":
|
||||
from transformers import ViTMAEForPreTraining
|
||||
|
||||
hf_encoder = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
|
||||
state_dict = hf_encoder.state_dict()
|
||||
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder_type: {encoder_type}")
|
||||
|
||||
model.encoder.load_state_dict(state_dict, strict=False)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
if args.resolution != args.image_size:
|
||||
raise ValueError(
|
||||
f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) "
|
||||
"for stage-1 reconstruction loss."
|
||||
)
|
||||
|
||||
logging_dir = Path(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
project_config=accelerator_project_config,
|
||||
log_with=args.report_to,
|
||||
)
|
||||
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args))
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example[0] for example in examples]).float()
|
||||
return {"pixel_values": pixel_values}
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
pin_memory=True,
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
if args.pretrained_model_name_or_path is not None:
|
||||
model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path)
|
||||
logger.info(f"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}")
|
||||
else:
|
||||
model = AutoencoderRAE(
|
||||
encoder_type=args.encoder_type,
|
||||
encoder_hidden_size=args.encoder_hidden_size,
|
||||
encoder_patch_size=args.encoder_patch_size,
|
||||
encoder_num_hidden_layers=args.encoder_num_hidden_layers,
|
||||
decoder_hidden_size=args.decoder_hidden_size,
|
||||
decoder_num_hidden_layers=args.decoder_num_hidden_layers,
|
||||
decoder_num_attention_heads=args.decoder_num_attention_heads,
|
||||
decoder_intermediate_size=args.decoder_intermediate_size,
|
||||
patch_size=args.patch_size,
|
||||
encoder_input_size=args.encoder_input_size,
|
||||
image_size=args.image_size,
|
||||
num_channels=args.num_channels,
|
||||
noise_tau=args.noise_tau,
|
||||
reshape_to_2d=args.reshape_to_2d,
|
||||
use_encoder_loss=args.use_encoder_loss,
|
||||
scaling_factor=args.scaling_factor,
|
||||
)
|
||||
if args.encoder_name_or_path is not None:
|
||||
_load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path)
|
||||
logger.info(f"Loaded pretrained encoder weights from {args.encoder_name_or_path}")
|
||||
model.encoder.requires_grad_(False)
|
||||
model.decoder.requires_grad_(True)
|
||||
model.train()
|
||||
|
||||
optimizer = torch.optim.AdamW(
|
||||
(p for p in model.parameters() if p.requires_grad),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
|
||||
num_training_steps=args.max_train_steps * accelerator.num_processes,
|
||||
)
|
||||
|
||||
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
model, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
if overrode_max_train_steps:
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
if accelerator.is_main_process:
|
||||
accelerator.init_trackers("train_autoencoder_rae", config=vars(args))
|
||||
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
global_step = 0
|
||||
|
||||
for epoch in range(args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
pixel_values = batch["pixel_values"]
|
||||
|
||||
_, loss, reconstruction_loss, encoder_loss = compute_losses(
|
||||
model,
|
||||
pixel_values,
|
||||
reconstruction_loss_type=args.reconstruction_loss_type,
|
||||
use_encoder_loss=args.use_encoder_loss,
|
||||
encoder_loss_weight=args.encoder_loss_weight,
|
||||
)
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
|
||||
logs = {
|
||||
"loss": loss.detach().item(),
|
||||
"reconstruction_loss": reconstruction_loss.detach().item(),
|
||||
"encoder_loss": encoder_loss.detach().item(),
|
||||
"lr": lr_scheduler.get_last_lr()[0],
|
||||
}
|
||||
progress_bar.set_postfix(**logs)
|
||||
accelerator.log(logs, step=global_step)
|
||||
|
||||
if global_step % args.validation_steps == 0:
|
||||
with torch.no_grad():
|
||||
_, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses(
|
||||
model,
|
||||
pixel_values,
|
||||
reconstruction_loss_type=args.reconstruction_loss_type,
|
||||
use_encoder_loss=args.use_encoder_loss,
|
||||
encoder_loss_weight=args.encoder_loss_weight,
|
||||
)
|
||||
accelerator.log(
|
||||
{
|
||||
"val/loss": val_loss.detach().item(),
|
||||
"val/reconstruction_loss": val_reconstruction_loss.detach().item(),
|
||||
"val/encoder_loss": val_encoder_loss.detach().item(),
|
||||
},
|
||||
step=global_step,
|
||||
)
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(save_path)
|
||||
logger.info(f"Saved checkpoint to {save_path}")
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
unwrapped_model.save_pretrained(args.output_dir)
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -94,9 +94,15 @@ python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/depth \
|
||||
--output_path converted/transfer/2b/general/depth/pipeline \
|
||||
--save_pipeline
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/depth/models
|
||||
|
||||
# edge
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
|
||||
|
||||
@@ -120,9 +126,15 @@ python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/blur \
|
||||
--output_path converted/transfer/2b/general/blur/pipeline \
|
||||
--save_pipeline
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/blur/models
|
||||
|
||||
# seg
|
||||
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
|
||||
|
||||
@@ -130,8 +142,14 @@ python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/seg \
|
||||
--output_path converted/transfer/2b/general/seg/pipeline \
|
||||
--save_pipeline
|
||||
|
||||
python scripts/convert_cosmos_to_diffusers.py \
|
||||
--transformer_type Cosmos-2.5-Transfer-General-2B \
|
||||
--transformer_ckpt_path $transformer_ckpt_path \
|
||||
--vae_type wan2.1 \
|
||||
--output_path converted/transfer/2b/general/seg/models
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
403
scripts/convert_rae_to_diffusers.py
Normal file
403
scripts/convert_rae_to_diffusers.py
Normal file
@@ -0,0 +1,403 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from huggingface_hub import HfApi, hf_hub_download
|
||||
|
||||
from diffusers import AutoencoderRAE
|
||||
|
||||
|
||||
DECODER_CONFIGS = {
|
||||
"ViTB": {
|
||||
"decoder_hidden_size": 768,
|
||||
"decoder_intermediate_size": 3072,
|
||||
"decoder_num_attention_heads": 12,
|
||||
"decoder_num_hidden_layers": 12,
|
||||
},
|
||||
"ViTL": {
|
||||
"decoder_hidden_size": 1024,
|
||||
"decoder_intermediate_size": 4096,
|
||||
"decoder_num_attention_heads": 16,
|
||||
"decoder_num_hidden_layers": 24,
|
||||
},
|
||||
"ViTXL": {
|
||||
"decoder_hidden_size": 1152,
|
||||
"decoder_intermediate_size": 4096,
|
||||
"decoder_num_attention_heads": 16,
|
||||
"decoder_num_hidden_layers": 28,
|
||||
},
|
||||
}
|
||||
|
||||
ENCODER_DEFAULT_NAME_OR_PATH = {
|
||||
"dinov2": "facebook/dinov2-with-registers-base",
|
||||
"siglip2": "google/siglip2-base-patch16-256",
|
||||
"mae": "facebook/vit-mae-base",
|
||||
}
|
||||
|
||||
ENCODER_HIDDEN_SIZE = {
|
||||
"dinov2": 768,
|
||||
"siglip2": 768,
|
||||
"mae": 768,
|
||||
}
|
||||
|
||||
ENCODER_PATCH_SIZE = {
|
||||
"dinov2": 14,
|
||||
"siglip2": 16,
|
||||
"mae": 16,
|
||||
}
|
||||
|
||||
DEFAULT_DECODER_SUBDIR = {
|
||||
"dinov2": "decoders/dinov2/wReg_base",
|
||||
"mae": "decoders/mae/base_p16",
|
||||
"siglip2": "decoders/siglip2/base_p16_i256",
|
||||
}
|
||||
|
||||
DEFAULT_STATS_SUBDIR = {
|
||||
"dinov2": "stats/dinov2/wReg_base",
|
||||
"mae": "stats/mae/base_p16",
|
||||
"siglip2": "stats/siglip2/base_p16_i256",
|
||||
}
|
||||
|
||||
DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt")
|
||||
STATS_FILE_CANDIDATES = ("stat.pt",)
|
||||
|
||||
|
||||
def dataset_case_candidates(name: str) -> tuple[str, ...]:
|
||||
return (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k")
|
||||
|
||||
|
||||
class RepoAccessor:
|
||||
def __init__(self, repo_or_path: str, cache_dir: str | None = None):
|
||||
self.repo_or_path = repo_or_path
|
||||
self.cache_dir = cache_dir
|
||||
self.local_root: Path | None = None
|
||||
self.repo_id: str | None = None
|
||||
self.repo_files: set[str] | None = None
|
||||
|
||||
root = Path(repo_or_path)
|
||||
if root.exists() and root.is_dir():
|
||||
self.local_root = root
|
||||
else:
|
||||
self.repo_id = repo_or_path
|
||||
self.repo_files = set(HfApi().list_repo_files(repo_or_path))
|
||||
|
||||
def exists(self, relative_path: str) -> bool:
|
||||
relative_path = relative_path.replace("\\", "/")
|
||||
if self.local_root is not None:
|
||||
return (self.local_root / relative_path).is_file()
|
||||
return relative_path in self.repo_files
|
||||
|
||||
def fetch(self, relative_path: str) -> Path:
|
||||
relative_path = relative_path.replace("\\", "/")
|
||||
if self.local_root is not None:
|
||||
return self.local_root / relative_path
|
||||
downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir)
|
||||
return Path(downloaded)
|
||||
|
||||
|
||||
def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]:
|
||||
state_dict = maybe_wrapped
|
||||
for k in ("model", "module", "state_dict"):
|
||||
if isinstance(state_dict, dict) and k in state_dict and isinstance(state_dict[k], dict):
|
||||
state_dict = state_dict[k]
|
||||
|
||||
out = dict(state_dict)
|
||||
if len(out) > 0 and all(key.startswith("module.") for key in out):
|
||||
out = {key[len("module.") :]: value for key, value in out.items()}
|
||||
if len(out) > 0 and all(key.startswith("decoder.") for key in out):
|
||||
out = {key[len("decoder.") :]: value for key, value in out.items()}
|
||||
return out
|
||||
|
||||
|
||||
def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder.
|
||||
|
||||
Example mappings:
|
||||
- `...attention.attention.query.*` -> `...attention.to_q.*`
|
||||
- `...attention.attention.key.*` -> `...attention.to_k.*`
|
||||
- `...attention.attention.value.*` -> `...attention.to_v.*`
|
||||
- `...attention.output.dense.*` -> `...attention.to_out.0.*`
|
||||
"""
|
||||
remapped: dict[str, Any] = {}
|
||||
for key, value in state_dict.items():
|
||||
new_key = key
|
||||
new_key = new_key.replace(".attention.attention.query.", ".attention.to_q.")
|
||||
new_key = new_key.replace(".attention.attention.key.", ".attention.to_k.")
|
||||
new_key = new_key.replace(".attention.attention.value.", ".attention.to_v.")
|
||||
new_key = new_key.replace(".attention.output.dense.", ".attention.to_out.0.")
|
||||
remapped[new_key] = value
|
||||
return remapped
|
||||
|
||||
|
||||
def resolve_decoder_file(
|
||||
accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None
|
||||
) -> str:
|
||||
if decoder_checkpoint is not None:
|
||||
if accessor.exists(decoder_checkpoint):
|
||||
return decoder_checkpoint
|
||||
raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}")
|
||||
|
||||
base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}"
|
||||
for name in DECODER_FILE_CANDIDATES:
|
||||
candidate = f"{base}/{name}"
|
||||
if accessor.exists(candidate):
|
||||
return candidate
|
||||
|
||||
raise FileNotFoundError(
|
||||
f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}"
|
||||
)
|
||||
|
||||
|
||||
def resolve_stats_file(
|
||||
accessor: RepoAccessor,
|
||||
encoder_type: str,
|
||||
dataset_name: str,
|
||||
stats_checkpoint: str | None,
|
||||
) -> str | None:
|
||||
if stats_checkpoint is not None:
|
||||
if accessor.exists(stats_checkpoint):
|
||||
return stats_checkpoint
|
||||
raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}")
|
||||
|
||||
base = DEFAULT_STATS_SUBDIR[encoder_type]
|
||||
for dataset in dataset_case_candidates(dataset_name):
|
||||
for name in STATS_FILE_CANDIDATES:
|
||||
candidate = f"{base}/{dataset}/{name}"
|
||||
if accessor.exists(candidate):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]:
|
||||
if not isinstance(stats_obj, dict):
|
||||
return None, None
|
||||
|
||||
if "latents_mean" in stats_obj or "latents_std" in stats_obj:
|
||||
return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None)
|
||||
|
||||
mean = stats_obj.get("mean", None)
|
||||
var = stats_obj.get("var", None)
|
||||
if mean is None and var is None:
|
||||
return None, None
|
||||
|
||||
latents_std = None
|
||||
if var is not None:
|
||||
if isinstance(var, torch.Tensor):
|
||||
latents_std = torch.sqrt(var + 1e-5)
|
||||
else:
|
||||
latents_std = torch.sqrt(torch.tensor(var) + 1e-5)
|
||||
return mean, latents_std
|
||||
|
||||
|
||||
def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]:
|
||||
"""Remove final layernorm weight/bias from encoder state dict.
|
||||
|
||||
RAE uses non-affine layernorm (weight=1, bias=0 is the default identity).
|
||||
Stripping these keys means the model keeps its default init values, which
|
||||
is functionally equivalent to setting elementwise_affine=False.
|
||||
"""
|
||||
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
|
||||
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
|
||||
|
||||
|
||||
def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]:
|
||||
"""Download the HF encoder and extract the state dict for the inner model."""
|
||||
if encoder_type == "dinov2":
|
||||
from transformers import Dinov2WithRegistersModel
|
||||
|
||||
hf_model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
|
||||
sd = hf_model.state_dict()
|
||||
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
|
||||
elif encoder_type == "siglip2":
|
||||
from transformers import SiglipModel
|
||||
|
||||
# SiglipModel.vision_model is a SiglipVisionTransformer.
|
||||
# Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it
|
||||
# under .vision_model, so we add the prefix to match the diffusers key layout.
|
||||
hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
|
||||
sd = {f"vision_model.{k}": v for k, v in hf_model.state_dict().items()}
|
||||
return _strip_final_layernorm_affine(sd, prefix="vision_model.post_layernorm.")
|
||||
elif encoder_type == "mae":
|
||||
from transformers import ViTMAEForPreTraining
|
||||
|
||||
hf_model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
|
||||
sd = hf_model.state_dict()
|
||||
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder_type: {encoder_type}")
|
||||
|
||||
|
||||
def convert(args: argparse.Namespace) -> None:
|
||||
accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir)
|
||||
encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type]
|
||||
|
||||
decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint)
|
||||
stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint)
|
||||
|
||||
print(f"Using decoder checkpoint: {decoder_relpath}")
|
||||
if stats_relpath is not None:
|
||||
print(f"Using stats checkpoint: {stats_relpath}")
|
||||
else:
|
||||
print("No stats checkpoint found; conversion will proceed without latent stats.")
|
||||
|
||||
if args.dry_run:
|
||||
return
|
||||
|
||||
decoder_path = accessor.fetch(decoder_relpath)
|
||||
decoder_obj = torch.load(decoder_path, map_location="cpu")
|
||||
decoder_state_dict = unwrap_state_dict(decoder_obj)
|
||||
decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict)
|
||||
|
||||
latents_mean, latents_std = None, None
|
||||
if stats_relpath is not None:
|
||||
stats_path = accessor.fetch(stats_relpath)
|
||||
stats_obj = torch.load(stats_path, map_location="cpu")
|
||||
latents_mean, latents_std = extract_latent_stats(stats_obj)
|
||||
|
||||
decoder_cfg = DECODER_CONFIGS[args.decoder_config_name]
|
||||
|
||||
# Read encoder normalization stats from the HF image processor (only place that downloads encoder info)
|
||||
from transformers import AutoConfig, AutoImageProcessor
|
||||
|
||||
proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
|
||||
encoder_norm_mean = list(proc.image_mean)
|
||||
encoder_norm_std = list(proc.image_std)
|
||||
|
||||
# Read encoder hidden size and patch size from HF config
|
||||
encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type]
|
||||
encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type]
|
||||
try:
|
||||
hf_config = AutoConfig.from_pretrained(encoder_name_or_path)
|
||||
# For models like SigLIP that nest vision config
|
||||
if hasattr(hf_config, "vision_config"):
|
||||
hf_config = hf_config.vision_config
|
||||
encoder_hidden_size = hf_config.hidden_size
|
||||
encoder_patch_size = hf_config.patch_size
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Load the actual encoder weights from HF to include in the saved model
|
||||
encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path)
|
||||
|
||||
# Build model on meta device to avoid double init overhead
|
||||
with torch.device("meta"):
|
||||
model = AutoencoderRAE(
|
||||
encoder_type=args.encoder_type,
|
||||
encoder_hidden_size=encoder_hidden_size,
|
||||
encoder_patch_size=encoder_patch_size,
|
||||
encoder_input_size=args.encoder_input_size,
|
||||
patch_size=args.patch_size,
|
||||
image_size=args.image_size,
|
||||
num_channels=args.num_channels,
|
||||
encoder_norm_mean=encoder_norm_mean,
|
||||
encoder_norm_std=encoder_norm_std,
|
||||
decoder_hidden_size=decoder_cfg["decoder_hidden_size"],
|
||||
decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"],
|
||||
decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"],
|
||||
decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"],
|
||||
latents_mean=latents_mean,
|
||||
latents_std=latents_std,
|
||||
scaling_factor=args.scaling_factor,
|
||||
)
|
||||
|
||||
# Assemble full state dict and load with assign=True
|
||||
full_state_dict = {}
|
||||
|
||||
# Encoder weights (prefixed with "encoder.")
|
||||
for k, v in encoder_state_dict.items():
|
||||
full_state_dict[f"encoder.{k}"] = v
|
||||
|
||||
# Decoder weights (prefixed with "decoder.")
|
||||
for k, v in decoder_state_dict.items():
|
||||
full_state_dict[f"decoder.{k}"] = v
|
||||
|
||||
# Buffers from config
|
||||
full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
|
||||
full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
|
||||
if latents_mean is not None:
|
||||
latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean)
|
||||
full_state_dict["_latents_mean"] = latents_mean_t
|
||||
else:
|
||||
full_state_dict["_latents_mean"] = torch.zeros(1)
|
||||
if latents_std is not None:
|
||||
latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std)
|
||||
full_state_dict["_latents_std"] = latents_std_t
|
||||
else:
|
||||
full_state_dict["_latents_std"] = torch.ones(1)
|
||||
|
||||
model.load_state_dict(full_state_dict, strict=False, assign=True)
|
||||
|
||||
# Verify no critical keys are missing
|
||||
model_keys = {name for name, _ in model.named_parameters()}
|
||||
model_keys |= {name for name, _ in model.named_buffers()}
|
||||
loaded_keys = set(full_state_dict.keys())
|
||||
missing = model_keys - loaded_keys
|
||||
# trainable_cls_token and decoder_pos_embed are initialized, not loaded from original checkpoint
|
||||
allowed_missing = {"decoder.trainable_cls_token", "decoder.decoder_pos_embed"}
|
||||
if missing - allowed_missing:
|
||||
print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")
|
||||
|
||||
output_path = Path(args.output_path)
|
||||
output_path.mkdir(parents=True, exist_ok=True)
|
||||
model.save_pretrained(output_path)
|
||||
|
||||
if args.verify_load:
|
||||
print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...")
|
||||
loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False)
|
||||
if not isinstance(loaded_model, AutoencoderRAE):
|
||||
raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.")
|
||||
print("Verification passed.")
|
||||
|
||||
print(f"Saved converted AutoencoderRAE to: {output_path}")
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format")
|
||||
parser.add_argument(
|
||||
"--repo_or_path", type=str, required=True, help="Hub repo id (e.g. nyu-visionx/RAE-collections) or local path"
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model")
|
||||
|
||||
parser.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True)
|
||||
parser.add_argument(
|
||||
"--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override"
|
||||
)
|
||||
|
||||
parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name")
|
||||
parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name")
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path"
|
||||
)
|
||||
|
||||
parser.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL")
|
||||
parser.add_argument("--encoder_input_size", type=int, default=224)
|
||||
parser.add_argument("--patch_size", type=int, default=16)
|
||||
parser.add_argument("--image_size", type=int, default=None)
|
||||
parser.add_argument("--num_channels", type=int, default=3)
|
||||
parser.add_argument("--scaling_factor", type=float, default=1.0)
|
||||
|
||||
parser.add_argument("--cache_dir", type=str, default=None)
|
||||
parser.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files")
|
||||
parser.add_argument(
|
||||
"--verify_load",
|
||||
action="store_true",
|
||||
help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = parse_args()
|
||||
convert(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
6
setup.py
6
setup.py
@@ -101,6 +101,7 @@ _deps = [
|
||||
"datasets",
|
||||
"filelock",
|
||||
"flax>=0.4.1",
|
||||
"ftfy",
|
||||
"hf-doc-builder>=0.3.0",
|
||||
"httpx<1.0.0",
|
||||
"huggingface-hub>=0.34.0,<2.0",
|
||||
@@ -111,7 +112,6 @@ _deps = [
|
||||
"jax>=0.4.1",
|
||||
"jaxlib>=0.4.1",
|
||||
"Jinja2",
|
||||
"k-diffusion==0.0.12",
|
||||
"torchsde",
|
||||
"note_seq",
|
||||
"librosa",
|
||||
@@ -222,13 +222,14 @@ extras["docs"] = deps_list("hf-doc-builder")
|
||||
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
|
||||
extras["test"] = deps_list(
|
||||
"compel",
|
||||
"ftfy",
|
||||
"GitPython",
|
||||
"datasets",
|
||||
"Jinja2",
|
||||
"invisible-watermark",
|
||||
"k-diffusion",
|
||||
"librosa",
|
||||
"parameterized",
|
||||
"protobuf",
|
||||
"pytest",
|
||||
"pytest-timeout",
|
||||
"pytest-xdist",
|
||||
@@ -237,6 +238,7 @@ extras["test"] = deps_list(
|
||||
"sentencepiece",
|
||||
"scipy",
|
||||
"tiktoken",
|
||||
"torchsde",
|
||||
"torchvision",
|
||||
"transformers",
|
||||
"phonemizer",
|
||||
|
||||
@@ -10,7 +10,6 @@ from .utils import (
|
||||
is_bitsandbytes_available,
|
||||
is_flax_available,
|
||||
is_gguf_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_nvidia_modelopt_available,
|
||||
@@ -50,8 +49,6 @@ _import_structure = {
|
||||
"is_flax_available",
|
||||
"is_inflect_available",
|
||||
"is_invisible_watermark_available",
|
||||
"is_k_diffusion_available",
|
||||
"is_k_diffusion_version",
|
||||
"is_librosa_available",
|
||||
"is_note_seq_available",
|
||||
"is_onnx_available",
|
||||
@@ -205,6 +202,7 @@ else:
|
||||
"AutoencoderKLTemporalDecoder",
|
||||
"AutoencoderKLWan",
|
||||
"AutoencoderOobleck",
|
||||
"AutoencoderRAE",
|
||||
"AutoencoderTiny",
|
||||
"AutoModel",
|
||||
"BriaFiboTransformer2DModel",
|
||||
@@ -230,6 +228,7 @@ else:
|
||||
"FluxMultiControlNetModel",
|
||||
"FluxTransformer2DModel",
|
||||
"GlmImageTransformer2DModel",
|
||||
"HeliosTransformer3DModel",
|
||||
"HiDreamImageTransformer2DModel",
|
||||
"HunyuanDiT2DControlNetModel",
|
||||
"HunyuanDiT2DModel",
|
||||
@@ -362,6 +361,8 @@ else:
|
||||
"FlowMatchEulerDiscreteScheduler",
|
||||
"FlowMatchHeunDiscreteScheduler",
|
||||
"FlowMatchLCMScheduler",
|
||||
"HeliosDMDScheduler",
|
||||
"HeliosScheduler",
|
||||
"HeunDiscreteScheduler",
|
||||
"IPNDMScheduler",
|
||||
"KarrasVeScheduler",
|
||||
@@ -518,6 +519,8 @@ else:
|
||||
"FluxPipeline",
|
||||
"FluxPriorReduxPipeline",
|
||||
"GlmImagePipeline",
|
||||
"HeliosPipeline",
|
||||
"HeliosPyramidPipeline",
|
||||
"HiDreamImagePipeline",
|
||||
"HunyuanDiTControlNetPipeline",
|
||||
"HunyuanDiTPAGPipeline",
|
||||
@@ -731,19 +734,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["ConsisIDPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
|
||||
|
||||
_import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
|
||||
name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
|
||||
]
|
||||
|
||||
else:
|
||||
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -985,6 +975,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderRAE,
|
||||
AutoencoderTiny,
|
||||
AutoModel,
|
||||
BriaFiboTransformer2DModel,
|
||||
@@ -1010,6 +1001,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxMultiControlNetModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HeliosTransformer3DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DControlNetModel,
|
||||
HunyuanDiT2DModel,
|
||||
@@ -1138,6 +1130,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FlowMatchEulerDiscreteScheduler,
|
||||
FlowMatchHeunDiscreteScheduler,
|
||||
FlowMatchLCMScheduler,
|
||||
HeliosDMDScheduler,
|
||||
HeliosScheduler,
|
||||
HeunDiscreteScheduler,
|
||||
IPNDMScheduler,
|
||||
KarrasVeScheduler,
|
||||
@@ -1273,6 +1267,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
FluxPipeline,
|
||||
FluxPriorReduxPipeline,
|
||||
GlmImagePipeline,
|
||||
HeliosPipeline,
|
||||
HeliosPyramidPipeline,
|
||||
HiDreamImagePipeline,
|
||||
HunyuanDiTControlNetPipeline,
|
||||
HunyuanDiTPAGPipeline,
|
||||
@@ -1469,14 +1465,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ZImagePipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
|
||||
else:
|
||||
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
|
||||
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
|
||||
# with open(CONFIG, "w") as f:
|
||||
# json.dump(automap, f)
|
||||
with open("requirements.txt", "w") as f:
|
||||
f.write("")
|
||||
|
||||
def _choose_block(self, candidates, chosen=None):
|
||||
for cls, base in candidates:
|
||||
|
||||
@@ -107,6 +107,38 @@ class ConfigMixin:
|
||||
has_compatibles = False
|
||||
|
||||
_deprecated_kwargs = []
|
||||
_auto_class = None
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="AutoModel"):
|
||||
"""
|
||||
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
|
||||
trust_remote_code=True)`.
|
||||
|
||||
When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
|
||||
to this class's module and class name.
|
||||
|
||||
Args:
|
||||
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
|
||||
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
|
||||
Currently only `"AutoModel"` is supported.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
from diffusers import ModelMixin, ConfigMixin
|
||||
|
||||
|
||||
class MyCustomModel(ModelMixin, ConfigMixin): ...
|
||||
|
||||
|
||||
MyCustomModel.register_for_auto_class("AutoModel")
|
||||
```
|
||||
"""
|
||||
if auto_class != "AutoModel":
|
||||
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")
|
||||
|
||||
cls._auto_class = auto_class
|
||||
|
||||
def register_to_config(self, **kwargs):
|
||||
if self.config_name is None:
|
||||
@@ -621,6 +653,12 @@ class ConfigMixin:
|
||||
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
|
||||
_ = config_dict.pop("_pre_quantization_dtype", None)
|
||||
|
||||
if getattr(self, "_auto_class", None) is not None:
|
||||
module = self.__class__.__module__.split(".")[-1]
|
||||
auto_map = config_dict.get("auto_map", {})
|
||||
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
|
||||
config_dict["auto_map"] = auto_map
|
||||
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: str | os.PathLike):
|
||||
|
||||
@@ -8,6 +8,7 @@ deps = {
|
||||
"datasets": "datasets",
|
||||
"filelock": "filelock",
|
||||
"flax": "flax>=0.4.1",
|
||||
"ftfy": "ftfy",
|
||||
"hf-doc-builder": "hf-doc-builder>=0.3.0",
|
||||
"httpx": "httpx<1.0.0",
|
||||
"huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
|
||||
@@ -18,7 +19,6 @@ deps = {
|
||||
"jax": "jax>=0.4.1",
|
||||
"jaxlib": "jaxlib>=0.4.1",
|
||||
"Jinja2": "Jinja2",
|
||||
"k-diffusion": "k-diffusion==0.0.12",
|
||||
"torchsde": "torchsde",
|
||||
"note_seq": "note_seq",
|
||||
"librosa": "librosa",
|
||||
|
||||
@@ -54,7 +54,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
|
||||
self._timestep: torch.LongTensor = None
|
||||
self._count_prepared = 0
|
||||
self._input_fields: dict[str, str | tuple[str, str]] = None
|
||||
self._enabled = True
|
||||
self._enabled = enabled
|
||||
|
||||
if not (0.0 <= start < 1.0):
|
||||
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
|
||||
|
||||
@@ -48,6 +48,7 @@ _GO_LC_SUPPORTED_PYTORCH_LAYERS = (
|
||||
torch.nn.ConvTranspose2d,
|
||||
torch.nn.ConvTranspose3d,
|
||||
torch.nn.Linear,
|
||||
torch.nn.Embedding,
|
||||
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
|
||||
# because of double invocation of the same norm layer in CogVideoXLayerNorm
|
||||
)
|
||||
|
||||
@@ -307,6 +307,17 @@ class GroupOffloadingHook(ModelHook):
|
||||
if self.group.onload_leader == module:
|
||||
if self.group.onload_self:
|
||||
self.group.onload_()
|
||||
else:
|
||||
# onload_self=False means this group relies on prefetching from a previous group.
|
||||
# However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios),
|
||||
# the prefetch chain may not cover them if they were absent during the first forward pass
|
||||
# when the execution order was traced. In that case, their weights remain on offload_device,
|
||||
# so we fall back to a synchronous onload here.
|
||||
params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters)
|
||||
if params and params[0].device == self.group.offload_device:
|
||||
self.group.onload_()
|
||||
if self.group.stream is not None:
|
||||
self.group.stream.synchronize()
|
||||
|
||||
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
|
||||
if should_onload_next_group:
|
||||
|
||||
@@ -78,6 +78,7 @@ if is_torch_available():
|
||||
"SanaLoraLoaderMixin",
|
||||
"Lumina2LoraLoaderMixin",
|
||||
"WanLoraLoaderMixin",
|
||||
"HeliosLoraLoaderMixin",
|
||||
"KandinskyLoraLoaderMixin",
|
||||
"HiDreamImageLoraLoaderMixin",
|
||||
"SkyReelsV2LoraLoaderMixin",
|
||||
@@ -118,6 +119,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
CogView4LoraLoaderMixin,
|
||||
Flux2LoraLoaderMixin,
|
||||
FluxLoraLoaderMixin,
|
||||
HeliosLoraLoaderMixin,
|
||||
HiDreamImageLoraLoaderMixin,
|
||||
HunyuanVideoLoraLoaderMixin,
|
||||
KandinskyLoraLoaderMixin,
|
||||
|
||||
@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
)
|
||||
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
|
||||
|
||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
|
||||
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
|
||||
if has_diffb:
|
||||
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
|
||||
if zero_status_diff_b:
|
||||
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
|
||||
state_dict = {
|
||||
_custom_replace(k, limit_substrings): v
|
||||
for k, v in state_dict.items()
|
||||
if k.startswith(("lora_unet_", "lora_te_"))
|
||||
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
|
||||
}
|
||||
|
||||
if any("text_projection" in k for k in state_dict):
|
||||
@@ -2519,6 +2519,13 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
if has_default:
|
||||
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
|
||||
|
||||
# Normalize ZImage-specific dot-separated module names to underscore form so they
|
||||
# match the diffusers model parameter names (context_refiner, noise_refiner).
|
||||
state_dict = {
|
||||
k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
|
||||
for k, v in state_dict.items()
|
||||
}
|
||||
|
||||
converted_state_dict = {}
|
||||
all_keys = list(state_dict.keys())
|
||||
down_key = ".lora_down.weight"
|
||||
@@ -2529,19 +2536,18 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
|
||||
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
if has_non_diffusers_lora_id:
|
||||
|
||||
def get_alpha_scales(down_weight, alpha_key):
|
||||
rank = down_weight.shape[0]
|
||||
alpha = state_dict.pop(alpha_key).item()
|
||||
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
|
||||
scale_down = scale
|
||||
scale_up = 1.0
|
||||
while scale_down * 2 < scale_up:
|
||||
scale_down *= 2
|
||||
scale_up /= 2
|
||||
return scale_down, scale_up
|
||||
|
||||
for k in all_keys:
|
||||
if k.endswith(down_key):
|
||||
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
|
||||
@@ -2554,13 +2560,69 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
|
||||
converted_state_dict[diffusers_down_key] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up_key] = up_weight * scale_up
|
||||
|
||||
# Already in diffusers format (lora_A/lora_B), just pop
|
||||
# Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
|
||||
elif has_diffusers_lora_id:
|
||||
for k in all_keys:
|
||||
if a_key in k or b_key in k:
|
||||
converted_state_dict[k] = state_dict.pop(k)
|
||||
elif ".alpha" in k:
|
||||
if k.endswith(a_key):
|
||||
diffusers_up_key = k.replace(a_key, b_key)
|
||||
alpha_key = k.replace(a_key, ".alpha")
|
||||
|
||||
down_weight = state_dict.pop(k)
|
||||
up_weight = state_dict.pop(diffusers_up_key)
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[k] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up_key] = up_weight * scale_up
|
||||
|
||||
# Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
|
||||
# Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
|
||||
# lora weight names and also include redundant keys:
|
||||
# - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
|
||||
# - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
|
||||
# - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
|
||||
lora_dot_down_key = ".lora.down.weight"
|
||||
lora_dot_up_key = ".lora.up.weight"
|
||||
has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict)
|
||||
|
||||
if has_lora_dot_format:
|
||||
dot_keys = list(state_dict.keys())
|
||||
for k in dot_keys:
|
||||
if lora_dot_down_key not in k:
|
||||
continue
|
||||
if k not in state_dict:
|
||||
continue # already popped by a prior iteration
|
||||
|
||||
base = k[: -len(lora_dot_down_key)]
|
||||
|
||||
# Skip combined "qkv" projection — individual to.q/k/v keys are also present.
|
||||
if base.endswith(".qkv"):
|
||||
state_dict.pop(k)
|
||||
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
|
||||
state_dict.pop(base + ".alpha", None)
|
||||
continue
|
||||
|
||||
# Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
|
||||
if re.search(r"\.out$", base) and ".to_out" not in base:
|
||||
state_dict.pop(k)
|
||||
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
|
||||
continue
|
||||
|
||||
# Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
|
||||
norm_k = re.sub(
|
||||
r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
|
||||
r".to_\1" + lora_dot_down_key,
|
||||
k,
|
||||
)
|
||||
norm_base = norm_k[: -len(lora_dot_down_key)]
|
||||
alpha_key = norm_base + ".alpha"
|
||||
|
||||
diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight")
|
||||
diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight")
|
||||
|
||||
down_weight = state_dict.pop(k)
|
||||
up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key))
|
||||
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
|
||||
converted_state_dict[diffusers_down] = down_weight * scale_down
|
||||
converted_state_dict[diffusers_up] = up_weight * scale_up
|
||||
|
||||
if len(state_dict) > 0:
|
||||
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")
|
||||
|
||||
@@ -3440,6 +3440,207 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HeliosLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`].
|
||||
"""
|
||||
|
||||
_lora_loadable_modules = ["transformer"]
|
||||
transformer_name = TRANSFORMER_NAME
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def lora_state_dict(
|
||||
cls,
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
|
||||
"""
|
||||
# Load the main state dict first which has the LoRA layers for either of
|
||||
# transformer and text encoder or both.
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
weight_name = kwargs.pop("weight_name", None)
|
||||
use_safetensors = kwargs.pop("use_safetensors", None)
|
||||
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
allow_pickle = True
|
||||
|
||||
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
|
||||
|
||||
state_dict, metadata = _fetch_state_dict(
|
||||
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
|
||||
weight_name=weight_name,
|
||||
use_safetensors=use_safetensors,
|
||||
local_files_only=local_files_only,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
revision=revision,
|
||||
subfolder=subfolder,
|
||||
user_agent=user_agent,
|
||||
allow_pickle=allow_pickle,
|
||||
)
|
||||
if any(k.startswith("diffusion_model.") for k in state_dict):
|
||||
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
|
||||
elif any(k.startswith("lora_unet_") for k in state_dict):
|
||||
state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
|
||||
|
||||
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
|
||||
if is_dora_scale_present:
|
||||
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
out = (state_dict, metadata) if return_lora_metadata else state_dict
|
||||
return out
|
||||
|
||||
def load_lora_weights(
|
||||
self,
|
||||
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
|
||||
adapter_name: str | None = None,
|
||||
hotswap: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
|
||||
"""
|
||||
if not USE_PEFT_BACKEND:
|
||||
raise ValueError("PEFT backend is required for this method.")
|
||||
|
||||
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# if a dict is passed, copy it instead of modifying it inplace
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, dict):
|
||||
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
|
||||
|
||||
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
|
||||
kwargs["return_lora_metadata"] = True
|
||||
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
|
||||
is_correct_format = all("lora" in key for key in state_dict.keys())
|
||||
if not is_correct_format:
|
||||
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
|
||||
|
||||
self.load_lora_into_transformer(
|
||||
state_dict,
|
||||
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=self,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
|
||||
def load_lora_into_transformer(
|
||||
cls,
|
||||
state_dict,
|
||||
transformer,
|
||||
adapter_name=None,
|
||||
_pipeline=None,
|
||||
low_cpu_mem_usage=False,
|
||||
hotswap: bool = False,
|
||||
metadata=None,
|
||||
):
|
||||
"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
|
||||
"""
|
||||
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
|
||||
raise ValueError(
|
||||
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
|
||||
)
|
||||
|
||||
# Load the layers corresponding to transformer.
|
||||
logger.info(f"Loading {cls.transformer_name}.")
|
||||
transformer.load_lora_adapter(
|
||||
state_dict,
|
||||
network_alphas=None,
|
||||
adapter_name=adapter_name,
|
||||
metadata=metadata,
|
||||
_pipeline=_pipeline,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
hotswap=hotswap,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
|
||||
def save_lora_weights(
|
||||
cls,
|
||||
save_directory: str | os.PathLike,
|
||||
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
|
||||
is_main_process: bool = True,
|
||||
weight_name: str = None,
|
||||
save_function: Callable = None,
|
||||
safe_serialization: bool = True,
|
||||
transformer_lora_adapter_metadata: dict | None = None,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
|
||||
"""
|
||||
lora_layers = {}
|
||||
lora_metadata = {}
|
||||
|
||||
if transformer_lora_layers:
|
||||
lora_layers[cls.transformer_name] = transformer_lora_layers
|
||||
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
|
||||
|
||||
if not lora_layers:
|
||||
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
|
||||
|
||||
cls._save_lora_weights(
|
||||
save_directory=save_directory,
|
||||
lora_layers=lora_layers,
|
||||
lora_metadata=lora_metadata,
|
||||
is_main_process=is_main_process,
|
||||
weight_name=weight_name,
|
||||
save_function=save_function,
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
|
||||
def fuse_lora(
|
||||
self,
|
||||
components: list[str] = ["transformer"],
|
||||
lora_scale: float = 1.0,
|
||||
safe_fusing: bool = False,
|
||||
adapter_names: list[str] | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
|
||||
"""
|
||||
super().fuse_lora(
|
||||
components=components,
|
||||
lora_scale=lora_scale,
|
||||
safe_fusing=safe_fusing,
|
||||
adapter_names=adapter_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
|
||||
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
|
||||
r"""
|
||||
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
|
||||
"""
|
||||
super().unfuse_lora(components=components, **kwargs)
|
||||
|
||||
|
||||
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
|
||||
r"""
|
||||
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
|
||||
@@ -5472,6 +5673,10 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
|
||||
logger.warning(warn_msg)
|
||||
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
|
||||
|
||||
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
|
||||
if is_peft_format:
|
||||
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
|
||||
|
||||
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
|
||||
if is_ai_toolkit:
|
||||
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)
|
||||
|
||||
@@ -51,6 +51,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
|
||||
"FluxTransformer2DModel": lambda model_cls, weights: weights,
|
||||
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"MochiTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
|
||||
|
||||
@@ -22,7 +22,12 @@ from tokenizers import Tokenizer as TokenizerFast
|
||||
from torch import nn
|
||||
|
||||
from ..models.modeling_utils import load_state_dict
|
||||
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
|
||||
from ..utils import (
|
||||
_get_model_file,
|
||||
is_accelerate_available,
|
||||
is_transformers_available,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
|
||||
@@ -49,6 +49,7 @@ if is_torch_available():
|
||||
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
|
||||
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
|
||||
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
|
||||
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
|
||||
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
|
||||
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
|
||||
_import_structure["autoencoders.vq_model"] = ["VQModel"]
|
||||
@@ -100,6 +101,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_helios"] = ["HeliosTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
|
||||
@@ -167,6 +169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
AutoencoderKLTemporalDecoder,
|
||||
AutoencoderKLWan,
|
||||
AutoencoderOobleck,
|
||||
AutoencoderRAE,
|
||||
AutoencoderTiny,
|
||||
ConsistencyDecoderVAE,
|
||||
VQModel,
|
||||
@@ -212,6 +215,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
Flux2Transformer2DModel,
|
||||
FluxTransformer2DModel,
|
||||
GlmImageTransformer2DModel,
|
||||
HeliosTransformer3DModel,
|
||||
HiDreamImageTransformer2DModel,
|
||||
HunyuanDiT2DModel,
|
||||
HunyuanImageTransformer2DModel,
|
||||
|
||||
@@ -62,6 +62,8 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
|
||||
_REQUIRED_XLA_VERSION = "2.2"
|
||||
_REQUIRED_XFORMERS_VERSION = "0.0.29"
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
|
||||
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
|
||||
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
|
||||
@@ -73,8 +75,18 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
try:
|
||||
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
||||
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
# Handle ABI mismatch or other import failures gracefully.
|
||||
# This can happen when flash_attn was compiled against a different PyTorch version.
|
||||
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
|
||||
_CAN_USE_FLASH_ATTN = False
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
_wrapped_flash_attn_backward = None
|
||||
_wrapped_flash_attn_forward = None
|
||||
else:
|
||||
flash_attn_func = None
|
||||
flash_attn_varlen_func = None
|
||||
@@ -83,26 +95,47 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLASH_ATTN_3:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
try:
|
||||
from flash_attn_interface import flash_attn_func as flash_attn_3_func
|
||||
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLASH_ATTN_3 = False
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
else:
|
||||
flash_attn_3_func = None
|
||||
flash_attn_3_varlen_func = None
|
||||
|
||||
if _CAN_USE_AITER_ATTN:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
try:
|
||||
from aiter import flash_attn_func as aiter_flash_attn_func
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_AITER_ATTN = False
|
||||
aiter_flash_attn_func = None
|
||||
else:
|
||||
aiter_flash_attn_func = None
|
||||
|
||||
if _CAN_USE_SAGE_ATTN:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
try:
|
||||
from sageattention import (
|
||||
sageattn,
|
||||
sageattn_qk_int8_pv_fp8_cuda,
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90,
|
||||
sageattn_qk_int8_pv_fp16_cuda,
|
||||
sageattn_qk_int8_pv_fp16_triton,
|
||||
sageattn_varlen,
|
||||
)
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_SAGE_ATTN = False
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp8_cuda = None
|
||||
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
sageattn_qk_int8_pv_fp16_triton = None
|
||||
sageattn_varlen = None
|
||||
else:
|
||||
sageattn = None
|
||||
sageattn_qk_int8_pv_fp16_cuda = None
|
||||
@@ -113,26 +146,48 @@ else:
|
||||
|
||||
|
||||
if _CAN_USE_FLEX_ATTN:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
try:
|
||||
# We cannot import the flex_attention function from the package directly because it is expected (from the
|
||||
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
|
||||
# compiled function.
|
||||
import torch.nn.attention.flex_attention as flex_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_FLEX_ATTN = False
|
||||
flex_attention = None
|
||||
else:
|
||||
flex_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_NPU_ATTN:
|
||||
from torch_npu import npu_fusion_attention
|
||||
try:
|
||||
from torch_npu import npu_fusion_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_NPU_ATTN = False
|
||||
npu_fusion_attention = None
|
||||
else:
|
||||
npu_fusion_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XLA_ATTN:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
try:
|
||||
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XLA_ATTN = False
|
||||
xla_flash_attention = None
|
||||
else:
|
||||
xla_flash_attention = None
|
||||
|
||||
|
||||
if _CAN_USE_XFORMERS_ATTN:
|
||||
import xformers.ops as xops
|
||||
try:
|
||||
import xformers.ops as xops
|
||||
except (ImportError, OSError, RuntimeError) as e:
|
||||
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
|
||||
_CAN_USE_XFORMERS_ATTN = False
|
||||
xops = None
|
||||
else:
|
||||
xops = None
|
||||
|
||||
@@ -158,8 +213,6 @@ else:
|
||||
_register_fake = register_fake_no_op
|
||||
|
||||
|
||||
logger = get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# TODO(aryan): Add support for the following:
|
||||
# - Sage Attention++
|
||||
# - block sparse, radial and other attention methods
|
||||
@@ -266,13 +319,21 @@ class _HubKernelConfig:
|
||||
function_attr: str
|
||||
revision: str | None = None
|
||||
kernel_fn: Callable | None = None
|
||||
wrapped_forward_attr: str | None = None
|
||||
wrapped_backward_attr: str | None = None
|
||||
wrapped_forward_fn: Callable | None = None
|
||||
wrapped_backward_fn: Callable | None = None
|
||||
|
||||
|
||||
# Registry for hub-based attention kernels
|
||||
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# TODO: temporary revision for now. Remove when merged upstream into `main`.
|
||||
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
|
||||
repo_id="kernels-community/flash-attn3",
|
||||
function_attr="flash_attn_func",
|
||||
revision="fake-ops-return-probs",
|
||||
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
|
||||
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
|
||||
),
|
||||
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn3",
|
||||
@@ -280,7 +341,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
|
||||
# revision="fake-ops-return-probs",
|
||||
),
|
||||
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
|
||||
repo_id="kernels-community/flash-attn2",
|
||||
function_attr="flash_attn_func",
|
||||
revision=None,
|
||||
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
|
||||
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
|
||||
),
|
||||
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
|
||||
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
|
||||
@@ -605,22 +670,39 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
||||
|
||||
|
||||
# ===== Helpers for downloading kernels =====
|
||||
def _resolve_kernel_attr(module, attr_path: str):
|
||||
target = module
|
||||
for attr in attr_path.split("."):
|
||||
if not hasattr(target, attr):
|
||||
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
|
||||
target = getattr(target, attr)
|
||||
return target
|
||||
|
||||
|
||||
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
|
||||
if backend not in _HUB_KERNELS_REGISTRY:
|
||||
return
|
||||
config = _HUB_KERNELS_REGISTRY[backend]
|
||||
|
||||
if config.kernel_fn is not None:
|
||||
needs_kernel = config.kernel_fn is None
|
||||
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
|
||||
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
|
||||
|
||||
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
|
||||
return
|
||||
|
||||
try:
|
||||
from kernels import get_kernel
|
||||
|
||||
kernel_module = get_kernel(config.repo_id, revision=config.revision)
|
||||
kernel_func = getattr(kernel_module, config.function_attr)
|
||||
if needs_kernel:
|
||||
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
|
||||
|
||||
# Cache the downloaded kernel function in the config object
|
||||
config.kernel_fn = kernel_func
|
||||
if needs_wrapped_forward:
|
||||
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
|
||||
|
||||
if needs_wrapped_backward:
|
||||
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
|
||||
@@ -651,7 +733,7 @@ def _wrapped_flash_attn_3(
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
# Hardcoded for now because pytorch does not support tuple/int type hints
|
||||
window_size = (-1, -1)
|
||||
out, lse, *_ = flash_attn_3_func(
|
||||
result = flash_attn_3_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
@@ -668,7 +750,9 @@ def _wrapped_flash_attn_3(
|
||||
pack_gqa=pack_gqa,
|
||||
deterministic=deterministic,
|
||||
sm_margin=sm_margin,
|
||||
return_attn_probs=True,
|
||||
)
|
||||
out, lse, *_ = result
|
||||
lse = lse.permute(0, 2, 1)
|
||||
return out, lse
|
||||
|
||||
@@ -1071,6 +1155,258 @@ def _flash_attention_backward_op(
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
|
||||
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_forward_fn = config.wrapped_forward_fn
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_forward_fn is None or wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
|
||||
"for context parallel execution."
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
window_size = (-1, -1)
|
||||
softcap = 0.0
|
||||
alibi_slopes = None
|
||||
deterministic = False
|
||||
grad_enabled = any(x.requires_grad for x in (query, key, value))
|
||||
|
||||
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
|
||||
dropout_p = dropout_p if dropout_p > 0 else 1e-30
|
||||
|
||||
with torch.set_grad_enabled(grad_enabled):
|
||||
out, lse, S_dmask, rng_state = wrapped_forward_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
dropout_p,
|
||||
scale,
|
||||
is_causal,
|
||||
window_size[0],
|
||||
window_size[1],
|
||||
softcap,
|
||||
alibi_slopes,
|
||||
return_lse,
|
||||
)
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, lse, rng_state)
|
||||
ctx.dropout_p = dropout_p
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.alibi_slopes = alibi_slopes
|
||||
ctx.deterministic = deterministic
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
|
||||
)
|
||||
|
||||
query, key, value, out, lse, rng_state = ctx.saved_tensors
|
||||
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
|
||||
|
||||
_ = wrapped_backward_fn(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
lse,
|
||||
grad_query,
|
||||
grad_key,
|
||||
grad_value,
|
||||
ctx.dropout_p,
|
||||
ctx.scale,
|
||||
ctx.is_causal,
|
||||
ctx.window_size[0],
|
||||
ctx.window_size[1],
|
||||
ctx.softcap,
|
||||
ctx.alibi_slopes,
|
||||
ctx.deterministic,
|
||||
rng_state,
|
||||
)
|
||||
|
||||
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _flash_attention_3_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
*,
|
||||
window_size: tuple[int, int] = (-1, -1),
|
||||
softcap: float = 0.0,
|
||||
num_splits: int = 1,
|
||||
pack_gqa: bool | None = None,
|
||||
deterministic: bool = False,
|
||||
sm_margin: int = 0,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if dropout_p != 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
|
||||
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
|
||||
wrapped_forward_fn = config.wrapped_forward_fn
|
||||
if wrapped_forward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
|
||||
"for context parallel execution."
|
||||
)
|
||||
|
||||
if scale is None:
|
||||
scale = query.shape[-1] ** (-0.5)
|
||||
|
||||
out, softmax_lse, *_ = wrapped_forward_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
None, # k_new, v_new
|
||||
None, # qv
|
||||
None, # out
|
||||
None,
|
||||
None,
|
||||
None, # cu_seqlens_q/k/k_new
|
||||
None,
|
||||
None, # seqused_q/k
|
||||
None,
|
||||
None, # max_seqlen_q/k
|
||||
None,
|
||||
None,
|
||||
None, # page_table, kv_batch_idx, leftpad_k
|
||||
None,
|
||||
None,
|
||||
None, # rotary_cos/sin, seqlens_rotary
|
||||
None,
|
||||
None,
|
||||
None, # q_descale, k_descale, v_descale
|
||||
scale,
|
||||
causal=is_causal,
|
||||
window_size_left=window_size[0],
|
||||
window_size_right=window_size[1],
|
||||
attention_chunk=0,
|
||||
softcap=softcap,
|
||||
num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
sm_margin=sm_margin,
|
||||
)
|
||||
|
||||
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
|
||||
|
||||
if _save_ctx:
|
||||
ctx.save_for_backward(query, key, value, out, softmax_lse)
|
||||
ctx.scale = scale
|
||||
ctx.is_causal = is_causal
|
||||
ctx.window_size = window_size
|
||||
ctx.softcap = softcap
|
||||
ctx.deterministic = deterministic
|
||||
ctx.sm_margin = sm_margin
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _flash_attention_3_hub_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
|
||||
wrapped_backward_fn = config.wrapped_backward_fn
|
||||
if wrapped_backward_fn is None:
|
||||
raise RuntimeError(
|
||||
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
|
||||
"for context parallel execution."
|
||||
)
|
||||
|
||||
query, key, value, out, softmax_lse = ctx.saved_tensors
|
||||
grad_query = torch.empty_like(query)
|
||||
grad_key = torch.empty_like(key)
|
||||
grad_value = torch.empty_like(value)
|
||||
|
||||
wrapped_backward_fn(
|
||||
grad_out,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
softmax_lse,
|
||||
None,
|
||||
None, # cu_seqlens_q, cu_seqlens_k
|
||||
None,
|
||||
None, # seqused_q, seqused_k
|
||||
None,
|
||||
None, # max_seqlen_q, max_seqlen_k
|
||||
grad_query,
|
||||
grad_key,
|
||||
grad_value,
|
||||
ctx.scale,
|
||||
ctx.is_causal,
|
||||
ctx.window_size[0],
|
||||
ctx.window_size[1],
|
||||
ctx.softcap,
|
||||
ctx.deterministic,
|
||||
ctx.sm_margin,
|
||||
)
|
||||
|
||||
grad_query = grad_query[..., : grad_out.shape[-1]]
|
||||
grad_key = grad_key[..., : grad_out.shape[-1]]
|
||||
grad_value = grad_value[..., : grad_out.shape[-1]]
|
||||
|
||||
return grad_query, grad_key, grad_value
|
||||
|
||||
|
||||
def _sage_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1109,6 +1445,46 @@ def _sage_attention_forward_op(
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_hub_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: torch.Tensor | None = None,
|
||||
dropout_p: float = 0.0,
|
||||
is_causal: bool = False,
|
||||
scale: float | None = None,
|
||||
enable_gqa: bool = False,
|
||||
return_lse: bool = False,
|
||||
_save_ctx: bool = True,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
):
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
|
||||
if dropout_p > 0.0:
|
||||
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
|
||||
if enable_gqa:
|
||||
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
tensor_layout="NHD",
|
||||
is_causal=is_causal,
|
||||
sm_scale=scale,
|
||||
return_lse=return_lse,
|
||||
)
|
||||
|
||||
lse = None
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
lse = lse.permute(0, 2, 1).contiguous()
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
def _sage_attention_backward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
grad_out: torch.Tensor,
|
||||
@@ -1117,6 +1493,26 @@ def _sage_attention_backward_op(
|
||||
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
|
||||
|
||||
|
||||
def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
|
||||
# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
|
||||
if attn_mask is not None and torch.all(attn_mask != 0):
|
||||
attn_mask = None
|
||||
|
||||
# Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, sqe_len_q, seq_len_k]
|
||||
# https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
|
||||
if (
|
||||
attn_mask is not None
|
||||
and attn_mask.ndim == 2
|
||||
and attn_mask.shape[0] == query.shape[0]
|
||||
and attn_mask.shape[1] == key.shape[1]
|
||||
):
|
||||
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
|
||||
attn_mask = ~attn_mask.to(torch.bool)
|
||||
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
|
||||
|
||||
return attn_mask
|
||||
|
||||
|
||||
def _npu_attention_forward_op(
|
||||
ctx: torch.autograd.function.FunctionCtx,
|
||||
query: torch.Tensor,
|
||||
@@ -1134,11 +1530,14 @@ def _npu_attention_forward_op(
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
|
||||
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
|
||||
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
atten_mask=attn_mask,
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
@@ -1942,7 +2341,7 @@ def _flash_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.FLASH_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _flash_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -1960,17 +2359,35 @@ def _flash_attention_hub(
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
dropout_p=dropout_p,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
dropout_p,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_flash_attention_hub_forward_op,
|
||||
backward_op=_flash_attention_hub_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
@@ -2117,7 +2534,7 @@ def _flash_attention_3(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName._FLASH_3_HUB,
|
||||
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _flash_attention_3_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -2132,33 +2549,68 @@ def _flash_attention_3_hub(
|
||||
return_attn_probs: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if _parallel_config:
|
||||
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
|
||||
|
||||
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
if _parallel_config is None:
|
||||
out = func(
|
||||
q=query,
|
||||
k=key,
|
||||
v=value,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
qv=None,
|
||||
q_descale=None,
|
||||
k_descale=None,
|
||||
v_descale=None,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
|
||||
forward_op = functools.partial(
|
||||
_flash_attention_3_hub_forward_op,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
return_attn_probs=return_attn_probs,
|
||||
)
|
||||
# When `return_attn_probs` is True, the above returns a tuple of
|
||||
# actual outputs and lse.
|
||||
return (out[0], out[1]) if return_attn_probs else out
|
||||
backward_op = functools.partial(
|
||||
_flash_attention_3_hub_backward_op,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
num_splits=1,
|
||||
pack_gqa=None,
|
||||
deterministic=deterministic,
|
||||
sm_margin=0,
|
||||
)
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_attn_probs,
|
||||
forward_op=forward_op,
|
||||
backward_op=backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_attn_probs:
|
||||
out, lse = out
|
||||
return out, lse
|
||||
|
||||
return out
|
||||
|
||||
|
||||
@_AttentionBackendRegistry.register(
|
||||
@@ -2251,7 +2703,7 @@ def _flash_varlen_attention_3(
|
||||
key_packed = torch.cat(key_valid, dim=0)
|
||||
value_packed = torch.cat(value_valid, dim=0)
|
||||
|
||||
out, lse, *_ = flash_attn_3_varlen_func(
|
||||
result = flash_attn_3_varlen_func(
|
||||
q=query_packed,
|
||||
k=key_packed,
|
||||
v=value_packed,
|
||||
@@ -2261,7 +2713,13 @@ def _flash_varlen_attention_3(
|
||||
max_seqlen_k=max_seqlen_k,
|
||||
softmax_scale=scale,
|
||||
causal=is_causal,
|
||||
return_attn_probs=return_lse,
|
||||
)
|
||||
if isinstance(result, tuple):
|
||||
out, lse, *_ = result
|
||||
else:
|
||||
out = result
|
||||
lse = None
|
||||
out = out.unflatten(0, (batch_size, -1))
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
@@ -2668,16 +3126,17 @@ def _native_npu_attention(
|
||||
return_lse: bool = False,
|
||||
_parallel_config: "ParallelConfig" | None = None,
|
||||
) -> torch.Tensor:
|
||||
if attn_mask is not None:
|
||||
raise ValueError("`attn_mask` is not supported for NPU attention")
|
||||
if return_lse:
|
||||
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
|
||||
if _parallel_config is None:
|
||||
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
|
||||
|
||||
out = npu_fusion_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
query.size(2), # num_heads
|
||||
atten_mask=attn_mask,
|
||||
input_layout="BSND",
|
||||
pse=None,
|
||||
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
|
||||
@@ -2692,7 +3151,7 @@ def _native_npu_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
attn_mask,
|
||||
dropout_p,
|
||||
None,
|
||||
scale,
|
||||
@@ -2789,7 +3248,7 @@ def _sage_attention(
|
||||
@_AttentionBackendRegistry.register(
|
||||
AttentionBackendName.SAGE_HUB,
|
||||
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
|
||||
supports_context_parallel=False,
|
||||
supports_context_parallel=True,
|
||||
)
|
||||
def _sage_attention_hub(
|
||||
query: torch.Tensor,
|
||||
@@ -2817,6 +3276,23 @@ def _sage_attention_hub(
|
||||
)
|
||||
if return_lse:
|
||||
out, lse, *_ = out
|
||||
else:
|
||||
out = _templated_context_parallel_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
None,
|
||||
0.0,
|
||||
is_causal,
|
||||
scale,
|
||||
False,
|
||||
return_lse,
|
||||
forward_op=_sage_attention_hub_forward_op,
|
||||
backward_op=_sage_attention_backward_op,
|
||||
_parallel_config=_parallel_config,
|
||||
)
|
||||
if return_lse:
|
||||
out, lse = out
|
||||
|
||||
return (out, lse) if return_lse else out
|
||||
|
||||
|
||||
@@ -30,10 +30,126 @@ class AutoModel(ConfigMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise EnvironmentError(
|
||||
f"{self.__class__.__name__} is designed to be instantiated "
|
||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
|
||||
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`, "
|
||||
f"`{self.__class__.__name__}.from_config(config)`, or "
|
||||
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, pretrained_model_name_or_path_or_dict: str | os.PathLike | dict | None = None, **kwargs):
|
||||
r"""
|
||||
Instantiate a model from a config dictionary or a pretrained model configuration file with random weights (no
|
||||
pretrained weights are loaded).
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path_or_dict (`str`, `os.PathLike`, or `dict`):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model
|
||||
configuration hosted on the Hub.
|
||||
- A path to a *directory* (for example `./my_model_directory`) containing a model configuration
|
||||
file.
|
||||
- A config dictionary.
|
||||
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
||||
is not used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model configuration, overriding the cached version if
|
||||
it exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether to only load local model configuration files or not.
|
||||
token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files.
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use.
|
||||
trust_remote_code (`bool`, *optional*, defaults to `False`):
|
||||
Whether to trust remote code.
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
The subfolder location of a model file within a larger model repository on the Hub or locally.
|
||||
|
||||
Returns:
|
||||
A model object instantiated from the config with random weights.
|
||||
|
||||
Example:
|
||||
|
||||
```py
|
||||
from diffusers import AutoModel
|
||||
|
||||
model = AutoModel.from_config("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
|
||||
```
|
||||
"""
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
trust_remote_code = kwargs.pop("trust_remote_code", False)
|
||||
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"revision",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
|
||||
|
||||
if pretrained_model_name_or_path_or_dict is None:
|
||||
raise ValueError(
|
||||
"Please provide a `pretrained_model_name_or_path_or_dict` as the first positional argument."
|
||||
)
|
||||
|
||||
if isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
|
||||
pretrained_model_name_or_path = pretrained_model_name_or_path_or_dict
|
||||
config = cls.load_config(pretrained_model_name_or_path, subfolder=subfolder, **hub_kwargs)
|
||||
else:
|
||||
config = pretrained_model_name_or_path_or_dict
|
||||
pretrained_model_name_or_path = config.get("_name_or_path", None)
|
||||
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
|
||||
if has_remote_code and trust_remote_code:
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
model_cls = get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path,
|
||||
subfolder=subfolder,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
**hub_kwargs,
|
||||
)
|
||||
else:
|
||||
if "_class_name" in config:
|
||||
class_name = config["_class_name"]
|
||||
library = "diffusers"
|
||||
elif "model_type" in config:
|
||||
class_name = "AutoModel"
|
||||
library = "transformers"
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Couldn't find a model class associated with the config: {config}. Make sure the config "
|
||||
"contains a `_class_name` or `model_type` key."
|
||||
)
|
||||
|
||||
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
|
||||
|
||||
model_cls, _ = get_class_obj_and_candidates(
|
||||
library_name=library,
|
||||
class_name=class_name,
|
||||
importable_classes=ALL_IMPORTABLE_CLASSES,
|
||||
pipelines=None,
|
||||
is_pipeline_module=False,
|
||||
)
|
||||
|
||||
if model_cls is None:
|
||||
raise ValueError(f"AutoModel can't find a model linked to {class_name}.")
|
||||
|
||||
return model_cls.from_config(config, **kwargs)
|
||||
|
||||
@classmethod
|
||||
@validate_hf_hub_args
|
||||
def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = None, **kwargs):
|
||||
|
||||
@@ -18,6 +18,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
|
||||
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
|
||||
from .autoencoder_kl_wan import AutoencoderKLWan
|
||||
from .autoencoder_oobleck import AutoencoderOobleck
|
||||
from .autoencoder_rae import AutoencoderRAE
|
||||
from .autoencoder_tiny import AutoencoderTiny
|
||||
from .consistency_decoder_vae import ConsistencyDecoderVAE
|
||||
from .vq_model import VQModel
|
||||
|
||||
692
src/diffusers/models/autoencoders/autoencoder_rae.py
Normal file
692
src/diffusers/models/autoencoders/autoencoder_rae.py
Normal file
@@ -0,0 +1,692 @@
|
||||
# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from math import sqrt
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import BaseOutput, logging
|
||||
from ...utils.accelerate_utils import apply_forward_hook
|
||||
from ...utils.import_utils import is_transformers_available
|
||||
from ...utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
from transformers import (
|
||||
Dinov2WithRegistersConfig,
|
||||
Dinov2WithRegistersModel,
|
||||
SiglipVisionConfig,
|
||||
SiglipVisionModel,
|
||||
ViTMAEConfig,
|
||||
ViTMAEModel,
|
||||
)
|
||||
|
||||
from ..activations import get_activation
|
||||
from ..attention import AttentionMixin
|
||||
from ..attention_processor import Attention
|
||||
from ..embeddings import get_2d_sincos_pos_embed
|
||||
from ..modeling_utils import ModelMixin
|
||||
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-encoder forward functions
|
||||
# ---------------------------------------------------------------------------
|
||||
# Each function takes the raw transformers model + images and returns patch
|
||||
# tokens of shape (B, N, C), stripping CLS / register tokens as needed.
|
||||
|
||||
|
||||
def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
|
||||
outputs = model(images, output_hidden_states=True)
|
||||
unused_token_num = 5 # 1 CLS + 4 register tokens
|
||||
return outputs.last_hidden_state[:, unused_token_num:]
|
||||
|
||||
|
||||
def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
|
||||
outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True)
|
||||
return outputs.last_hidden_state
|
||||
|
||||
|
||||
def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor:
|
||||
h, w = images.shape[2], images.shape[3]
|
||||
patch_num = int(h * w // patch_size**2)
|
||||
if patch_num * patch_size**2 != h * w:
|
||||
raise ValueError("Image size should be divisible by patch size.")
|
||||
noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype)
|
||||
outputs = model(images, noise, interpolate_pos_encoding=True)
|
||||
return outputs.last_hidden_state[:, 1:] # remove cls token
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Encoder construction helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_encoder(
|
||||
encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
|
||||
) -> nn.Module:
|
||||
"""Build a frozen encoder from config (no pretrained download)."""
|
||||
num_attention_heads = hidden_size // head_dim # all supported encoders use head_dim=64
|
||||
|
||||
if encoder_type == "dinov2":
|
||||
config = Dinov2WithRegistersConfig(
|
||||
hidden_size=hidden_size,
|
||||
patch_size=patch_size,
|
||||
image_size=518,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
)
|
||||
model = Dinov2WithRegistersModel(config)
|
||||
# RAE strips the final layernorm affine params (identity LN). Remove them from
|
||||
# the architecture so `from_pretrained` doesn't leave them on the meta device.
|
||||
model.layernorm.weight = None
|
||||
model.layernorm.bias = None
|
||||
elif encoder_type == "siglip2":
|
||||
config = SiglipVisionConfig(
|
||||
hidden_size=hidden_size,
|
||||
patch_size=patch_size,
|
||||
image_size=256,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
)
|
||||
model = SiglipVisionModel(config)
|
||||
# See dinov2 comment above.
|
||||
model.vision_model.post_layernorm.weight = None
|
||||
model.vision_model.post_layernorm.bias = None
|
||||
elif encoder_type == "mae":
|
||||
config = ViTMAEConfig(
|
||||
hidden_size=hidden_size,
|
||||
patch_size=patch_size,
|
||||
image_size=224,
|
||||
num_attention_heads=num_attention_heads,
|
||||
num_hidden_layers=num_hidden_layers,
|
||||
mask_ratio=0.0,
|
||||
)
|
||||
model = ViTMAEModel(config)
|
||||
# See dinov2 comment above.
|
||||
model.layernorm.weight = None
|
||||
model.layernorm.bias = None
|
||||
else:
|
||||
raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae")
|
||||
|
||||
model.requires_grad_(False)
|
||||
return model
|
||||
|
||||
|
||||
_ENCODER_FORWARD_FNS = {
|
||||
"dinov2": _dinov2_encoder_forward,
|
||||
"siglip2": _siglip2_encoder_forward,
|
||||
"mae": _mae_encoder_forward,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class RAEDecoderOutput(BaseOutput):
|
||||
"""
|
||||
Output of `RAEDecoder`.
|
||||
|
||||
Args:
|
||||
logits (`torch.Tensor`):
|
||||
Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`.
|
||||
"""
|
||||
|
||||
logits: torch.Tensor
|
||||
|
||||
|
||||
class ViTMAEIntermediate(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(hidden_size, intermediate_size)
|
||||
self.intermediate_act_fn = get_activation(hidden_act)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ViTMAEOutput(nn.Module):
|
||||
def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
|
||||
super().__init__()
|
||||
self.dense = nn.Linear(intermediate_size, hidden_size)
|
||||
self.dropout = nn.Dropout(hidden_dropout_prob)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.dense(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = hidden_states + input_tensor
|
||||
return hidden_states
|
||||
|
||||
|
||||
class ViTMAELayer(nn.Module):
|
||||
"""
|
||||
This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
hidden_size: int,
|
||||
num_attention_heads: int,
|
||||
intermediate_size: int,
|
||||
qkv_bias: bool = True,
|
||||
layer_norm_eps: float = 1e-12,
|
||||
hidden_dropout_prob: float = 0.0,
|
||||
attention_probs_dropout_prob: float = 0.0,
|
||||
hidden_act: str = "gelu",
|
||||
):
|
||||
super().__init__()
|
||||
if hidden_size % num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}"
|
||||
)
|
||||
self.attention = Attention(
|
||||
query_dim=hidden_size,
|
||||
heads=num_attention_heads,
|
||||
dim_head=hidden_size // num_attention_heads,
|
||||
dropout=attention_probs_dropout_prob,
|
||||
bias=qkv_bias,
|
||||
)
|
||||
self.intermediate = ViTMAEIntermediate(
|
||||
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act
|
||||
)
|
||||
self.output = ViTMAEOutput(
|
||||
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob
|
||||
)
|
||||
self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
attention_output = self.attention(self.layernorm_before(hidden_states))
|
||||
hidden_states = attention_output + hidden_states
|
||||
|
||||
layer_output = self.layernorm_after(hidden_states)
|
||||
layer_output = self.intermediate(layer_output)
|
||||
layer_output = self.output(layer_output, hidden_states)
|
||||
return layer_output
|
||||
|
||||
|
||||
class RAEDecoder(nn.Module):
|
||||
"""Lightweight RAE decoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 768,
|
||||
decoder_hidden_size: int = 512,
|
||||
decoder_num_hidden_layers: int = 8,
|
||||
decoder_num_attention_heads: int = 16,
|
||||
decoder_intermediate_size: int = 2048,
|
||||
num_patches: int = 256,
|
||||
patch_size: int = 16,
|
||||
num_channels: int = 3,
|
||||
image_size: int = 256,
|
||||
qkv_bias: bool = True,
|
||||
layer_norm_eps: float = 1e-12,
|
||||
hidden_dropout_prob: float = 0.0,
|
||||
attention_probs_dropout_prob: float = 0.0,
|
||||
hidden_act: str = "gelu",
|
||||
):
|
||||
super().__init__()
|
||||
self.decoder_hidden_size = decoder_hidden_size
|
||||
self.patch_size = patch_size
|
||||
self.num_channels = num_channels
|
||||
self.image_size = image_size
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
|
||||
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))
|
||||
|
||||
self.decoder_layers = nn.ModuleList(
|
||||
[
|
||||
ViTMAELayer(
|
||||
hidden_size=decoder_hidden_size,
|
||||
num_attention_heads=decoder_num_attention_heads,
|
||||
intermediate_size=decoder_intermediate_size,
|
||||
qkv_bias=qkv_bias,
|
||||
layer_norm_eps=layer_norm_eps,
|
||||
hidden_dropout_prob=hidden_dropout_prob,
|
||||
attention_probs_dropout_prob=attention_probs_dropout_prob,
|
||||
hidden_act=hidden_act,
|
||||
)
|
||||
for _ in range(decoder_num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)
|
||||
self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True)
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
self._initialize_weights(num_patches)
|
||||
self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
|
||||
|
||||
def _initialize_weights(self, num_patches: int):
|
||||
# Skip initialization when parameters are on meta device (e.g. during
|
||||
# accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
|
||||
# The weights are initialized.
|
||||
if self.decoder_pos_embed.device.type == "meta":
|
||||
return
|
||||
|
||||
grid_size = int(num_patches**0.5)
|
||||
pos_embed = get_2d_sincos_pos_embed(
|
||||
self.decoder_pos_embed.shape[-1],
|
||||
grid_size,
|
||||
cls_token=True,
|
||||
extra_tokens=1,
|
||||
output_type="pt",
|
||||
device=self.decoder_pos_embed.device,
|
||||
)
|
||||
self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype))
|
||||
|
||||
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
|
||||
embeddings_positions = embeddings.shape[1] - 1
|
||||
num_positions = self.decoder_pos_embed.shape[1] - 1
|
||||
|
||||
class_pos_embed = self.decoder_pos_embed[:, 0, :]
|
||||
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
|
||||
dim = self.decoder_pos_embed.shape[-1]
|
||||
|
||||
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
|
||||
patch_pos_embed = F.interpolate(
|
||||
patch_pos_embed,
|
||||
scale_factor=(1, embeddings_positions / num_positions),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
||||
|
||||
def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
|
||||
b, l, c = x.shape
|
||||
if l == self.num_patches:
|
||||
return x
|
||||
h = w = int(l**0.5)
|
||||
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
|
||||
target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
|
||||
x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
|
||||
x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
|
||||
return x
|
||||
|
||||
def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):
|
||||
patch_size, num_channels = self.patch_size, self.num_channels
|
||||
original_image_size = (
|
||||
original_image_size if original_image_size is not None else (self.image_size, self.image_size)
|
||||
)
|
||||
original_height, original_width = original_image_size
|
||||
num_patches_h = original_height // patch_size
|
||||
num_patches_w = original_width // patch_size
|
||||
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
|
||||
raise ValueError(
|
||||
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
|
||||
)
|
||||
|
||||
batch_size = patchified_pixel_values.shape[0]
|
||||
patchified_pixel_values = patchified_pixel_values.reshape(
|
||||
batch_size,
|
||||
num_patches_h,
|
||||
num_patches_w,
|
||||
patch_size,
|
||||
patch_size,
|
||||
num_channels,
|
||||
)
|
||||
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
|
||||
pixel_values = patchified_pixel_values.reshape(
|
||||
batch_size,
|
||||
num_channels,
|
||||
num_patches_h * patch_size,
|
||||
num_patches_w * patch_size,
|
||||
)
|
||||
return pixel_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
*,
|
||||
interpolate_pos_encoding: bool = False,
|
||||
drop_cls_token: bool = False,
|
||||
return_dict: bool = True,
|
||||
) -> RAEDecoderOutput | tuple[torch.Tensor]:
|
||||
x = self.decoder_embed(hidden_states)
|
||||
if drop_cls_token:
|
||||
x_ = x[:, 1:, :]
|
||||
x_ = self.interpolate_latent(x_)
|
||||
else:
|
||||
x_ = self.interpolate_latent(x)
|
||||
|
||||
cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
|
||||
x = torch.cat([cls_token, x_], dim=1)
|
||||
|
||||
if interpolate_pos_encoding:
|
||||
if not drop_cls_token:
|
||||
raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
|
||||
decoder_pos_embed = self.interpolate_pos_encoding(x)
|
||||
else:
|
||||
decoder_pos_embed = self.decoder_pos_embed
|
||||
|
||||
hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
for layer_module in self.decoder_layers:
|
||||
hidden_states = layer_module(hidden_states)
|
||||
|
||||
hidden_states = self.decoder_norm(hidden_states)
|
||||
logits = self.decoder_pred(hidden_states)
|
||||
logits = logits[:, 1:, :]
|
||||
|
||||
if not return_dict:
|
||||
return (logits,)
|
||||
return RAEDecoderOutput(logits=logits)
|
||||
|
||||
|
||||
class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
|
||||
r"""
|
||||
Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.
|
||||
|
||||
This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
|
||||
images from learned representations.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
|
||||
all models (such as downloading or saving).
|
||||
|
||||
Args:
|
||||
encoder_type (`str`, *optional*, defaults to `"dinov2"`):
|
||||
Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
|
||||
encoder_hidden_size (`int`, *optional*, defaults to `768`):
|
||||
Hidden size of the encoder model.
|
||||
encoder_patch_size (`int`, *optional*, defaults to `14`):
|
||||
Patch size of the encoder model.
|
||||
encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
|
||||
Number of hidden layers in the encoder model.
|
||||
patch_size (`int`, *optional*, defaults to `16`):
|
||||
Decoder patch size (used for unpatchify and decoder head).
|
||||
encoder_input_size (`int`, *optional*, defaults to `224`):
|
||||
Input size expected by the encoder.
|
||||
image_size (`int`, *optional*):
|
||||
Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
|
||||
RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
|
||||
encoder_patch_size) ** 2`.
|
||||
num_channels (`int`, *optional*, defaults to `3`):
|
||||
Number of input/output channels.
|
||||
encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
|
||||
Channel-wise mean for encoder input normalization (ImageNet defaults).
|
||||
encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
|
||||
Channel-wise std for encoder input normalization (ImageNet defaults).
|
||||
latents_mean (`list` or `tuple`, *optional*):
|
||||
Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
|
||||
lists.
|
||||
latents_std (`list` or `tuple`, *optional*):
|
||||
Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
|
||||
config-serializable lists.
|
||||
noise_tau (`float`, *optional*, defaults to `0.0`):
|
||||
Noise level for training (adds noise to latents during training).
|
||||
reshape_to_2d (`bool`, *optional*, defaults to `True`):
|
||||
Whether to reshape latents to 2D (B, C, H, W) format.
|
||||
use_encoder_loss (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use encoder hidden states in the loss (for advanced training).
|
||||
"""
|
||||
|
||||
# NOTE: gradient checkpointing is not wired up for this model yet.
|
||||
_supports_gradient_checkpointing = False
|
||||
_no_split_modules = ["ViTMAELayer"]
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
encoder_type: str = "dinov2",
|
||||
encoder_hidden_size: int = 768,
|
||||
encoder_patch_size: int = 14,
|
||||
encoder_num_hidden_layers: int = 12,
|
||||
decoder_hidden_size: int = 512,
|
||||
decoder_num_hidden_layers: int = 8,
|
||||
decoder_num_attention_heads: int = 16,
|
||||
decoder_intermediate_size: int = 2048,
|
||||
patch_size: int = 16,
|
||||
encoder_input_size: int = 224,
|
||||
image_size: int | None = None,
|
||||
num_channels: int = 3,
|
||||
encoder_norm_mean: list | None = None,
|
||||
encoder_norm_std: list | None = None,
|
||||
latents_mean: list | tuple | torch.Tensor | None = None,
|
||||
latents_std: list | tuple | torch.Tensor | None = None,
|
||||
noise_tau: float = 0.0,
|
||||
reshape_to_2d: bool = True,
|
||||
use_encoder_loss: bool = False,
|
||||
scaling_factor: float = 1.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if encoder_type not in _ENCODER_FORWARD_FNS:
|
||||
raise ValueError(
|
||||
f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
|
||||
)
|
||||
|
||||
if encoder_input_size % encoder_patch_size != 0:
|
||||
raise ValueError(
|
||||
f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
|
||||
)
|
||||
|
||||
decoder_patch_size = patch_size
|
||||
if decoder_patch_size <= 0:
|
||||
raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
|
||||
|
||||
num_patches = (encoder_input_size // encoder_patch_size) ** 2
|
||||
grid = int(sqrt(num_patches))
|
||||
if grid * grid != num_patches:
|
||||
raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")
|
||||
|
||||
derived_image_size = decoder_patch_size * grid
|
||||
if image_size is None:
|
||||
image_size = derived_image_size
|
||||
else:
|
||||
image_size = int(image_size)
|
||||
if image_size != derived_image_size:
|
||||
raise ValueError(
|
||||
f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
|
||||
f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
|
||||
)
|
||||
|
||||
def _to_config_compatible(value: Any) -> Any:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.detach().cpu().tolist()
|
||||
if isinstance(value, tuple):
|
||||
return [_to_config_compatible(v) for v in value]
|
||||
if isinstance(value, list):
|
||||
return [_to_config_compatible(v) for v in value]
|
||||
return value
|
||||
|
||||
def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.detach().clone()
|
||||
return torch.tensor(value, dtype=torch.float32)
|
||||
|
||||
latents_std_tensor = _as_optional_tensor(latents_std)
|
||||
|
||||
# Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
|
||||
self.register_to_config(
|
||||
latents_mean=_to_config_compatible(latents_mean),
|
||||
latents_std=_to_config_compatible(latents_std),
|
||||
)
|
||||
|
||||
# Frozen representation encoder (built from config, no downloads)
|
||||
self.encoder: nn.Module = _build_encoder(
|
||||
encoder_type=encoder_type,
|
||||
hidden_size=encoder_hidden_size,
|
||||
patch_size=encoder_patch_size,
|
||||
num_hidden_layers=encoder_num_hidden_layers,
|
||||
)
|
||||
self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
|
||||
num_patches = (encoder_input_size // encoder_patch_size) ** 2
|
||||
|
||||
# Encoder input normalization stats (ImageNet defaults)
|
||||
if encoder_norm_mean is None:
|
||||
encoder_norm_mean = [0.485, 0.456, 0.406]
|
||||
if encoder_norm_std is None:
|
||||
encoder_norm_std = [0.229, 0.224, 0.225]
|
||||
encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
|
||||
encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
|
||||
|
||||
self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True)
|
||||
self.register_buffer("encoder_std", encoder_std_tensor, persistent=True)
|
||||
|
||||
# Latent normalization buffers (defaults are no-ops; actual values come from checkpoint)
|
||||
latents_mean_tensor = _as_optional_tensor(latents_mean)
|
||||
if latents_mean_tensor is None:
|
||||
latents_mean_tensor = torch.zeros(1)
|
||||
self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)
|
||||
|
||||
if latents_std_tensor is None:
|
||||
latents_std_tensor = torch.ones(1)
|
||||
self.register_buffer("_latents_std", latents_std_tensor, persistent=True)
|
||||
|
||||
# ViT-MAE style decoder
|
||||
self.decoder = RAEDecoder(
|
||||
hidden_size=int(encoder_hidden_size),
|
||||
decoder_hidden_size=int(decoder_hidden_size),
|
||||
decoder_num_hidden_layers=int(decoder_num_hidden_layers),
|
||||
decoder_num_attention_heads=int(decoder_num_attention_heads),
|
||||
decoder_intermediate_size=int(decoder_intermediate_size),
|
||||
num_patches=int(num_patches),
|
||||
patch_size=int(decoder_patch_size),
|
||||
num_channels=int(num_channels),
|
||||
image_size=int(image_size),
|
||||
)
|
||||
|
||||
self.num_patches = int(num_patches)
|
||||
self.decoder_patch_size = int(decoder_patch_size)
|
||||
self.decoder_image_size = int(image_size)
|
||||
|
||||
# Slicing support (batch dimension) similar to other diffusers autoencoders
|
||||
self.use_slicing = False
|
||||
|
||||
def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
|
||||
# Per-sample random sigma in [0, noise_tau]
|
||||
noise_sigma = self.config.noise_tau * torch.rand(
|
||||
(x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator
|
||||
)
|
||||
return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)
|
||||
|
||||
def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
|
||||
_, _, h, w = x.shape
|
||||
if h != self.config.encoder_input_size or w != self.config.encoder_input_size:
|
||||
x = F.interpolate(
|
||||
x,
|
||||
size=(self.config.encoder_input_size, self.config.encoder_input_size),
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
|
||||
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
|
||||
return (x - mean) / std
|
||||
|
||||
def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor:
|
||||
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
|
||||
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
|
||||
return x * std + mean
|
||||
|
||||
def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor:
|
||||
latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
|
||||
latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
|
||||
return (z - latents_mean) / (latents_std + 1e-5)
|
||||
|
||||
def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor:
|
||||
latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
|
||||
latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
|
||||
return z * (latents_std + 1e-5) + latents_mean
|
||||
|
||||
def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
|
||||
x = self._resize_and_normalize(x)
|
||||
|
||||
if self.config.encoder_type == "mae":
|
||||
tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size)
|
||||
else:
|
||||
tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C)
|
||||
|
||||
if self.training and self.config.noise_tau > 0:
|
||||
tokens = self._noising(tokens, generator=generator)
|
||||
|
||||
if self.config.reshape_to_2d:
|
||||
b, n, c = tokens.shape
|
||||
side = int(sqrt(n))
|
||||
if side * side != n:
|
||||
raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.")
|
||||
z = tokens.transpose(1, 2).contiguous().view(b, c, side, side) # (B, C, h, w)
|
||||
else:
|
||||
z = tokens
|
||||
|
||||
z = self._normalize_latents(z)
|
||||
|
||||
# Follow diffusers convention: optionally scale latents for diffusion
|
||||
if self.config.scaling_factor != 1.0:
|
||||
z = z * self.config.scaling_factor
|
||||
|
||||
return z
|
||||
|
||||
@apply_forward_hook
|
||||
def encode(
|
||||
self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
|
||||
) -> EncoderOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and x.shape[0] > 1:
|
||||
latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0)
|
||||
else:
|
||||
latents = self._encode(x, generator=generator)
|
||||
|
||||
if not return_dict:
|
||||
return (latents,)
|
||||
return EncoderOutput(latent=latents)
|
||||
|
||||
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
||||
# Undo scaling factor if applied at encode time
|
||||
if self.config.scaling_factor != 1.0:
|
||||
z = z / self.config.scaling_factor
|
||||
|
||||
z = self._denormalize_latents(z)
|
||||
|
||||
if self.config.reshape_to_2d:
|
||||
b, c, h, w = z.shape
|
||||
tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C)
|
||||
else:
|
||||
tokens = z
|
||||
|
||||
logits = self.decoder(tokens, return_dict=True).logits
|
||||
x_rec = self.decoder.unpatchify(logits)
|
||||
x_rec = self._denormalize_image(x_rec)
|
||||
return x_rec
|
||||
|
||||
@apply_forward_hook
|
||||
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
if self.use_slicing and z.shape[0] > 1:
|
||||
decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0)
|
||||
else:
|
||||
decoded = self._decode(z)
|
||||
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
|
||||
def forward(
|
||||
self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
|
||||
) -> DecoderOutput | tuple[torch.Tensor]:
|
||||
latents = self.encode(sample, return_dict=False, generator=generator)[0]
|
||||
decoded = self.decode(latents, return_dict=False)[0]
|
||||
if not return_dict:
|
||||
return (decoded,)
|
||||
return DecoderOutput(sample=decoded)
|
||||
@@ -191,7 +191,12 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
||||
dim=1,
|
||||
)
|
||||
|
||||
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
|
||||
if condition_mask is not None:
|
||||
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
|
||||
else:
|
||||
control_hidden_states = torch.cat(
|
||||
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
|
||||
)
|
||||
|
||||
padding_mask_resized = transforms.functional.resize(
|
||||
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST
|
||||
|
||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
||||
from .transformer_flux import FluxTransformer2DModel
|
||||
from .transformer_flux2 import Flux2Transformer2DModel
|
||||
from .transformer_glm_image import GlmImageTransformer2DModel
|
||||
from .transformer_helios import HeliosTransformer3DModel
|
||||
from .transformer_hidream_image import HiDreamImageTransformer2DModel
|
||||
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
|
||||
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel
|
||||
|
||||
@@ -424,7 +424,7 @@ class Flux2SingleTransformerBlock(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None,
|
||||
temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
temb_mod: torch.Tensor,
|
||||
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
split_hidden_states: bool = False,
|
||||
@@ -436,7 +436,7 @@ class Flux2SingleTransformerBlock(nn.Module):
|
||||
text_seq_len = encoder_hidden_states.shape[1]
|
||||
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
||||
|
||||
mod_shift, mod_scale, mod_gate = temb_mod_params
|
||||
mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]
|
||||
|
||||
norm_hidden_states = self.norm(hidden_states)
|
||||
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
|
||||
@@ -498,16 +498,18 @@ class Flux2TransformerBlock(nn.Module):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
|
||||
temb_mod_img: torch.Tensor,
|
||||
temb_mod_txt: torch.Tensor,
|
||||
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
joint_attention_kwargs: dict[str, Any] | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
joint_attention_kwargs = joint_attention_kwargs or {}
|
||||
|
||||
# Modulation parameters shape: [1, 1, self.dim]
|
||||
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
|
||||
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
|
||||
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = Flux2Modulation.split(temb_mod_img, 2)
|
||||
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = Flux2Modulation.split(
|
||||
temb_mod_txt, 2
|
||||
)
|
||||
|
||||
# Img stream
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
@@ -627,15 +629,19 @@ class Flux2Modulation(nn.Module):
|
||||
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
||||
def forward(self, temb: torch.Tensor) -> torch.Tensor:
|
||||
mod = self.act_fn(temb)
|
||||
mod = self.linear(mod)
|
||||
return mod
|
||||
|
||||
@staticmethod
|
||||
# split inside the transformer blocks, to avoid passing tuples into checkpoints https://github.com/huggingface/diffusers/issues/12776
|
||||
def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
|
||||
if mod.ndim == 2:
|
||||
mod = mod.unsqueeze(1)
|
||||
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
|
||||
mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
|
||||
# Return tuple of 3-tuples of modulation params shift/scale/gate
|
||||
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
|
||||
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
|
||||
|
||||
|
||||
class Flux2Transformer2DModel(
|
||||
@@ -824,7 +830,7 @@ class Flux2Transformer2DModel(
|
||||
|
||||
double_stream_mod_img = self.double_stream_modulation_img(temb)
|
||||
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
|
||||
single_stream_mod = self.single_stream_modulation(temb)[0]
|
||||
single_stream_mod = self.single_stream_modulation(temb)
|
||||
|
||||
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
|
||||
hidden_states = self.x_embedder(hidden_states)
|
||||
@@ -861,8 +867,8 @@ class Flux2Transformer2DModel(
|
||||
encoder_hidden_states, hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
temb_mod_params_img=double_stream_mod_img,
|
||||
temb_mod_params_txt=double_stream_mod_txt,
|
||||
temb_mod_img=double_stream_mod_img,
|
||||
temb_mod_txt=double_stream_mod_txt,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
@@ -884,7 +890,7 @@ class Flux2Transformer2DModel(
|
||||
hidden_states = block(
|
||||
hidden_states=hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
temb_mod_params=single_stream_mod,
|
||||
temb_mod=single_stream_mod,
|
||||
image_rotary_emb=concat_rotary_emb,
|
||||
joint_attention_kwargs=joint_attention_kwargs,
|
||||
)
|
||||
|
||||
814
src/diffusers/models/transformers/transformer_helios.py
Normal file
814
src/diffusers/models/transformers/transformer_helios.py
Normal file
@@ -0,0 +1,814 @@
|
||||
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
||||
from ...utils import apply_lora_scale, logging
|
||||
from ...utils.torch_utils import maybe_allow_in_graph
|
||||
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..cache_utils import CacheMixin
|
||||
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import FP32LayerNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def pad_for_3d_conv(x, kernel_size):
|
||||
b, c, t, h, w = x.shape
|
||||
pt, ph, pw = kernel_size
|
||||
pad_t = (pt - (t % pt)) % pt
|
||||
pad_h = (ph - (h % ph)) % ph
|
||||
pad_w = (pw - (w % pw)) % pw
|
||||
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
|
||||
|
||||
|
||||
def center_down_sample_3d(x, kernel_size):
|
||||
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
|
||||
|
||||
|
||||
def apply_rotary_emb_transposed(
|
||||
hidden_states: torch.Tensor,
|
||||
freqs_cis: torch.Tensor,
|
||||
):
|
||||
x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
|
||||
cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
|
||||
out = torch.empty_like(hidden_states)
|
||||
out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
|
||||
out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
|
||||
return out.type_as(hidden_states)
|
||||
|
||||
|
||||
def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
|
||||
# encoder_hidden_states is only passed for cross-attention
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
|
||||
if attn.fused_projections:
|
||||
if not attn.is_cross_attention:
|
||||
# In self-attention layers, we can fuse the entire QKV projection into a single linear
|
||||
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
|
||||
else:
|
||||
# In cross-attention layers, we can only fuse the KV projections into a single linear
|
||||
query = attn.to_q(hidden_states)
|
||||
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
return query, key, value
|
||||
|
||||
|
||||
class HeliosOutputNorm(nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
|
||||
super().__init__()
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
|
||||
self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int):
|
||||
temb = temb[:, -original_context_length:, :]
|
||||
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
|
||||
shift, scale = shift.squeeze(2).to(hidden_states.device), scale.squeeze(2).to(hidden_states.device)
|
||||
hidden_states = hidden_states[:, -original_context_length:, :]
|
||||
hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HeliosAttnProcessor:
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError(
|
||||
"HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "HeliosAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
original_context_length: int = None,
|
||||
) -> torch.Tensor:
|
||||
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
|
||||
|
||||
query = attn.norm_q(query)
|
||||
key = attn.norm_k(key)
|
||||
|
||||
query = query.unflatten(2, (attn.heads, -1))
|
||||
key = key.unflatten(2, (attn.heads, -1))
|
||||
value = value.unflatten(2, (attn.heads, -1))
|
||||
|
||||
if rotary_emb is not None:
|
||||
query = apply_rotary_emb_transposed(query, rotary_emb)
|
||||
key = apply_rotary_emb_transposed(key, rotary_emb)
|
||||
|
||||
if not attn.is_cross_attention and attn.is_amplify_history:
|
||||
history_seq_len = hidden_states.shape[1] - original_context_length
|
||||
|
||||
if history_seq_len > 0:
|
||||
scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0)
|
||||
if attn.history_scale_mode == "per_head":
|
||||
scale_key = scale_key.view(1, 1, -1, 1)
|
||||
key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1)
|
||||
|
||||
hidden_states = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attention_mask,
|
||||
dropout_p=0.0,
|
||||
is_causal=False,
|
||||
backend=self._attention_backend,
|
||||
# Reference: https://github.com/huggingface/diffusers/pull/12909
|
||||
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
|
||||
)
|
||||
hidden_states = hidden_states.flatten(2, 3)
|
||||
hidden_states = hidden_states.type_as(query)
|
||||
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HeliosAttention(torch.nn.Module, AttentionModuleMixin):
|
||||
_default_processor_cls = HeliosAttnProcessor
|
||||
_available_processors = [HeliosAttnProcessor]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
eps: float = 1e-5,
|
||||
dropout: float = 0.0,
|
||||
added_kv_proj_dim: int | None = None,
|
||||
cross_attention_dim_head: int | None = None,
|
||||
processor=None,
|
||||
is_cross_attention=None,
|
||||
is_amplify_history=False,
|
||||
history_scale_mode="per_head", # [scalar, per_head]
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.inner_dim = dim_head * heads
|
||||
self.heads = heads
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
self.cross_attention_dim_head = cross_attention_dim_head
|
||||
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
|
||||
|
||||
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
|
||||
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
|
||||
self.to_out = torch.nn.ModuleList(
|
||||
[
|
||||
torch.nn.Linear(self.inner_dim, dim, bias=True),
|
||||
torch.nn.Dropout(dropout),
|
||||
]
|
||||
)
|
||||
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.add_k_proj = self.add_v_proj = None
|
||||
if added_kv_proj_dim is not None:
|
||||
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
|
||||
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
|
||||
|
||||
if is_cross_attention is not None:
|
||||
self.is_cross_attention = is_cross_attention
|
||||
else:
|
||||
self.is_cross_attention = cross_attention_dim_head is not None
|
||||
|
||||
self.set_processor(processor)
|
||||
|
||||
self.is_amplify_history = is_amplify_history
|
||||
if is_amplify_history:
|
||||
if history_scale_mode == "scalar":
|
||||
self.history_key_scale = nn.Parameter(torch.ones(1))
|
||||
elif history_scale_mode == "per_head":
|
||||
self.history_key_scale = nn.Parameter(torch.ones(heads))
|
||||
else:
|
||||
raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}")
|
||||
self.history_scale_mode = history_scale_mode
|
||||
self.max_scale = 10.0
|
||||
|
||||
def fuse_projections(self):
|
||||
if getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if not self.is_cross_attention:
|
||||
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_qkv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_qkv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
else:
|
||||
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
|
||||
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_kv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_kv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
|
||||
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
|
||||
out_features, in_features = concatenated_weights.shape
|
||||
with torch.device("meta"):
|
||||
self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
|
||||
self.to_added_kv.load_state_dict(
|
||||
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
|
||||
)
|
||||
|
||||
self.fused_projections = True
|
||||
|
||||
@torch.no_grad()
|
||||
def unfuse_projections(self):
|
||||
if not getattr(self, "fused_projections", False):
|
||||
return
|
||||
|
||||
if hasattr(self, "to_qkv"):
|
||||
delattr(self, "to_qkv")
|
||||
if hasattr(self, "to_kv"):
|
||||
delattr(self, "to_kv")
|
||||
if hasattr(self, "to_added_kv"):
|
||||
delattr(self, "to_added_kv")
|
||||
|
||||
self.fused_projections = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
attention_mask: torch.Tensor | None = None,
|
||||
rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
|
||||
original_context_length: int = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
attention_mask,
|
||||
rotary_emb,
|
||||
original_context_length,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
class HeliosTimeTextEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
time_freq_dim: int,
|
||||
time_proj_dim: int,
|
||||
text_embed_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
|
||||
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
|
||||
self.act_fn = nn.SiLU()
|
||||
self.time_proj = nn.Linear(dim, time_proj_dim)
|
||||
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor | None = None,
|
||||
is_return_encoder_hidden_states: bool = True,
|
||||
):
|
||||
timestep = self.timesteps_proj(timestep)
|
||||
|
||||
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
|
||||
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
|
||||
timestep = timestep.to(time_embedder_dtype)
|
||||
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
|
||||
timestep_proj = self.time_proj(self.act_fn(temb))
|
||||
|
||||
if encoder_hidden_states is not None and is_return_encoder_hidden_states:
|
||||
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
|
||||
|
||||
return temb, timestep_proj, encoder_hidden_states
|
||||
|
||||
|
||||
class HeliosRotaryPosEmbed(nn.Module):
|
||||
def __init__(self, rope_dim, theta):
|
||||
super().__init__()
|
||||
self.DT, self.DY, self.DX = rope_dim
|
||||
self.theta = theta
|
||||
self.register_buffer("freqs_base_t", self._get_freqs_base(self.DT), persistent=False)
|
||||
self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False)
|
||||
self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False)
|
||||
|
||||
def _get_freqs_base(self, dim):
|
||||
return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim))
|
||||
|
||||
@torch.no_grad()
|
||||
def get_frequency_batched(self, freqs_base, pos):
|
||||
freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos)
|
||||
freqs = freqs.repeat_interleave(2, dim=0)
|
||||
return freqs.cos(), freqs.sin()
|
||||
|
||||
@torch.no_grad()
|
||||
@lru_cache(maxsize=32)
|
||||
def _get_spatial_meshgrid(self, height, width, device_str):
|
||||
device = torch.device(device_str)
|
||||
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)
|
||||
grid_x_coords = torch.arange(width, device=device, dtype=torch.float32)
|
||||
grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij")
|
||||
return grid_y, grid_x
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, frame_indices, height, width, device):
|
||||
batch_size = frame_indices.shape[0]
|
||||
num_frames = frame_indices.shape[1]
|
||||
|
||||
frame_indices = frame_indices.to(device=device, dtype=torch.float32)
|
||||
grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device))
|
||||
|
||||
grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width)
|
||||
grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1)
|
||||
grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1)
|
||||
|
||||
freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t)
|
||||
freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch)
|
||||
freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch)
|
||||
|
||||
result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0)
|
||||
|
||||
return result.permute(1, 0, 2, 3, 4)
|
||||
|
||||
|
||||
@maybe_allow_in_graph
|
||||
class HeliosTransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
ffn_dim: int,
|
||||
num_heads: int,
|
||||
qk_norm: str = "rms_norm_across_heads",
|
||||
cross_attn_norm: bool = False,
|
||||
eps: float = 1e-6,
|
||||
added_kv_proj_dim: int | None = None,
|
||||
guidance_cross_attn: bool = False,
|
||||
is_amplify_history: bool = False,
|
||||
history_scale_mode: str = "per_head", # [scalar, per_head]
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# 1. Self-attention
|
||||
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
self.attn1 = HeliosAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
eps=eps,
|
||||
cross_attention_dim_head=None,
|
||||
processor=HeliosAttnProcessor(),
|
||||
is_amplify_history=is_amplify_history,
|
||||
history_scale_mode=history_scale_mode,
|
||||
)
|
||||
|
||||
# 2. Cross-attention
|
||||
self.attn2 = HeliosAttention(
|
||||
dim=dim,
|
||||
heads=num_heads,
|
||||
dim_head=dim // num_heads,
|
||||
eps=eps,
|
||||
added_kv_proj_dim=added_kv_proj_dim,
|
||||
cross_attention_dim_head=dim // num_heads,
|
||||
processor=HeliosAttnProcessor(),
|
||||
)
|
||||
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
|
||||
|
||||
# 3. Feed-forward
|
||||
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
|
||||
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
|
||||
|
||||
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
|
||||
|
||||
# 4. Guidance cross-attention
|
||||
self.guidance_cross_attn = guidance_cross_attn
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
rotary_emb: torch.Tensor,
|
||||
original_context_length: int = None,
|
||||
) -> torch.Tensor:
|
||||
if temb.ndim == 4:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table.unsqueeze(0) + temb.float()
|
||||
).chunk(6, dim=2)
|
||||
# batch_size, seq_len, 1, inner_dim
|
||||
shift_msa = shift_msa.squeeze(2)
|
||||
scale_msa = scale_msa.squeeze(2)
|
||||
gate_msa = gate_msa.squeeze(2)
|
||||
c_shift_msa = c_shift_msa.squeeze(2)
|
||||
c_scale_msa = c_scale_msa.squeeze(2)
|
||||
c_gate_msa = c_gate_msa.squeeze(2)
|
||||
else:
|
||||
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
|
||||
self.scale_shift_table + temb.float()
|
||||
).chunk(6, dim=1)
|
||||
|
||||
# 1. Self-attention
|
||||
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
None,
|
||||
None,
|
||||
rotary_emb,
|
||||
original_context_length,
|
||||
)
|
||||
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
|
||||
|
||||
# 2. Cross-attention
|
||||
if self.guidance_cross_attn:
|
||||
history_seq_len = hidden_states.shape[1] - original_context_length
|
||||
|
||||
history_hidden_states, hidden_states = torch.split(
|
||||
hidden_states, [history_seq_len, original_context_length], dim=1
|
||||
)
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
None,
|
||||
original_context_length,
|
||||
)
|
||||
hidden_states = hidden_states + attn_output
|
||||
hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1)
|
||||
else:
|
||||
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states,
|
||||
None,
|
||||
None,
|
||||
original_context_length,
|
||||
)
|
||||
hidden_states = hidden_states + attn_output
|
||||
|
||||
# 3. Feed-forward
|
||||
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
|
||||
hidden_states
|
||||
)
|
||||
ff_output = self.ffn(norm_hidden_states)
|
||||
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class HeliosTransformer3DModel(
|
||||
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
|
||||
):
|
||||
r"""
|
||||
A Transformer model for video-like data used in the Helios model.
|
||||
|
||||
Args:
|
||||
patch_size (`tuple[int]`, defaults to `(1, 2, 2)`):
|
||||
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
|
||||
num_attention_heads (`int`, defaults to `40`):
|
||||
Fixed length for text embeddings.
|
||||
attention_head_dim (`int`, defaults to `128`):
|
||||
The number of channels in each head.
|
||||
in_channels (`int`, defaults to `16`):
|
||||
The number of channels in the input.
|
||||
out_channels (`int`, defaults to `16`):
|
||||
The number of channels in the output.
|
||||
text_dim (`int`, defaults to `512`):
|
||||
Input dimension for text embeddings.
|
||||
freq_dim (`int`, defaults to `256`):
|
||||
Dimension for sinusoidal time embeddings.
|
||||
ffn_dim (`int`, defaults to `13824`):
|
||||
Intermediate dimension in feed-forward network.
|
||||
num_layers (`int`, defaults to `40`):
|
||||
The number of layers of transformer blocks to use.
|
||||
window_size (`tuple[int]`, defaults to `(-1, -1)`):
|
||||
Window size for local attention (-1 indicates global attention).
|
||||
cross_attn_norm (`bool`, defaults to `True`):
|
||||
Enable cross-attention normalization.
|
||||
qk_norm (`bool`, defaults to `True`):
|
||||
Enable query/key normalization.
|
||||
eps (`float`, defaults to `1e-6`):
|
||||
Epsilon value for normalization layers.
|
||||
add_img_emb (`bool`, defaults to `False`):
|
||||
Whether to use img_emb.
|
||||
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
|
||||
The number of channels to use for the added key and value projections. If `None`, no projection is used.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
_skip_layerwise_casting_patterns = [
|
||||
"patch_embedding",
|
||||
"patch_short",
|
||||
"patch_mid",
|
||||
"patch_long",
|
||||
"condition_embedder",
|
||||
"norm",
|
||||
]
|
||||
_no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"]
|
||||
_keep_in_fp32_modules = [
|
||||
"time_embedder",
|
||||
"scale_shift_table",
|
||||
"norm1",
|
||||
"norm2",
|
||||
"norm3",
|
||||
"history_key_scale",
|
||||
]
|
||||
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
|
||||
_repeated_blocks = ["HeliosTransformerBlock"]
|
||||
_cp_plan = {
|
||||
"blocks.0": {
|
||||
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.*": {
|
||||
"temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False),
|
||||
"rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
|
||||
},
|
||||
"blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3),
|
||||
}
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: tuple[int, ...] = (1, 2, 2),
|
||||
num_attention_heads: int = 40,
|
||||
attention_head_dim: int = 128,
|
||||
in_channels: int = 16,
|
||||
out_channels: int = 16,
|
||||
text_dim: int = 4096,
|
||||
freq_dim: int = 256,
|
||||
ffn_dim: int = 13824,
|
||||
num_layers: int = 40,
|
||||
cross_attn_norm: bool = True,
|
||||
qk_norm: str | None = "rms_norm_across_heads",
|
||||
eps: float = 1e-6,
|
||||
added_kv_proj_dim: int | None = None,
|
||||
rope_dim: tuple[int, ...] = (44, 42, 42),
|
||||
rope_theta: float = 10000.0,
|
||||
guidance_cross_attn: bool = True,
|
||||
zero_history_timestep: bool = True,
|
||||
has_multi_term_memory_patch: bool = True,
|
||||
is_amplify_history: bool = False,
|
||||
history_scale_mode: str = "per_head", # [scalar, per_head]
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
# 1. Patch & position embedding
|
||||
self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta)
|
||||
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
# 2. Initial Multi Term Memory Patch
|
||||
self.zero_history_timestep = zero_history_timestep
|
||||
if has_multi_term_memory_patch:
|
||||
self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
|
||||
self.patch_mid = nn.Conv3d(
|
||||
in_channels,
|
||||
inner_dim,
|
||||
kernel_size=tuple(2 * p for p in patch_size),
|
||||
stride=tuple(2 * p for p in patch_size),
|
||||
)
|
||||
self.patch_long = nn.Conv3d(
|
||||
in_channels,
|
||||
inner_dim,
|
||||
kernel_size=tuple(4 * p for p in patch_size),
|
||||
stride=tuple(4 * p for p in patch_size),
|
||||
)
|
||||
|
||||
# 3. Condition embeddings
|
||||
self.condition_embedder = HeliosTimeTextEmbedding(
|
||||
dim=inner_dim,
|
||||
time_freq_dim=freq_dim,
|
||||
time_proj_dim=inner_dim * 6,
|
||||
text_embed_dim=text_dim,
|
||||
)
|
||||
|
||||
# 4. Transformer blocks
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
HeliosTransformerBlock(
|
||||
inner_dim,
|
||||
ffn_dim,
|
||||
num_attention_heads,
|
||||
qk_norm,
|
||||
cross_attn_norm,
|
||||
eps,
|
||||
added_kv_proj_dim,
|
||||
guidance_cross_attn=guidance_cross_attn,
|
||||
is_amplify_history=is_amplify_history,
|
||||
history_scale_mode=history_scale_mode,
|
||||
)
|
||||
for _ in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 5. Output norm & projection
|
||||
self.norm_out = HeliosOutputNorm(inner_dim, eps, elementwise_affine=False)
|
||||
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
@apply_lora_scale("attention_kwargs")
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.LongTensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
# ------------ Stage 1 ------------
|
||||
indices_hidden_states=None,
|
||||
indices_latents_history_short=None,
|
||||
indices_latents_history_mid=None,
|
||||
indices_latents_history_long=None,
|
||||
latents_history_short=None,
|
||||
latents_history_mid=None,
|
||||
latents_history_long=None,
|
||||
return_dict: bool = True,
|
||||
attention_kwargs: dict[str, Any] | None = None,
|
||||
) -> torch.Tensor | dict[str, torch.Tensor]:
|
||||
# 1. Input
|
||||
batch_size = hidden_states.shape[0]
|
||||
p_t, p_h, p_w = self.config.patch_size
|
||||
|
||||
# 2. Process noisy latents
|
||||
hidden_states = self.patch_embedding(hidden_states)
|
||||
_, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape
|
||||
|
||||
if indices_hidden_states is None:
|
||||
indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
hidden_states = hidden_states.flatten(2).transpose(1, 2)
|
||||
rotary_emb = self.rope(
|
||||
frame_indices=indices_hidden_states,
|
||||
height=post_patch_height,
|
||||
width=post_patch_width,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
rotary_emb = rotary_emb.flatten(2).transpose(1, 2)
|
||||
original_context_length = hidden_states.shape[1]
|
||||
|
||||
# 3. Process short history latents
|
||||
if latents_history_short is not None and indices_latents_history_short is not None:
|
||||
latents_history_short = self.patch_short(latents_history_short)
|
||||
_, _, _, H1, W1 = latents_history_short.shape
|
||||
latents_history_short = latents_history_short.flatten(2).transpose(1, 2)
|
||||
|
||||
rotary_emb_history_short = self.rope(
|
||||
frame_indices=indices_latents_history_short,
|
||||
height=H1,
|
||||
width=W1,
|
||||
device=latents_history_short.device,
|
||||
)
|
||||
rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)
|
||||
rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1)
|
||||
|
||||
# 4. Process mid history latents
|
||||
if latents_history_mid is not None and indices_latents_history_mid is not None:
|
||||
latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))
|
||||
latents_history_mid = self.patch_mid(latents_history_mid)
|
||||
latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)
|
||||
|
||||
rotary_emb_history_mid = self.rope(
|
||||
frame_indices=indices_latents_history_mid,
|
||||
height=H1,
|
||||
width=W1,
|
||||
device=latents_history_mid.device,
|
||||
)
|
||||
rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2))
|
||||
rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2))
|
||||
rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)
|
||||
rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1)
|
||||
|
||||
# 5. Process long history latents
|
||||
if latents_history_long is not None and indices_latents_history_long is not None:
|
||||
latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))
|
||||
latents_history_long = self.patch_long(latents_history_long)
|
||||
latents_history_long = latents_history_long.flatten(2).transpose(1, 2)
|
||||
|
||||
rotary_emb_history_long = self.rope(
|
||||
frame_indices=indices_latents_history_long,
|
||||
height=H1,
|
||||
width=W1,
|
||||
device=latents_history_long.device,
|
||||
)
|
||||
rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4))
|
||||
rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4))
|
||||
rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2)
|
||||
|
||||
hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)
|
||||
rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1)
|
||||
|
||||
history_context_length = hidden_states.shape[1] - original_context_length
|
||||
|
||||
if indices_hidden_states is not None and self.zero_history_timestep:
|
||||
timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device)
|
||||
temb_t0, timestep_proj_t0, _ = self.condition_embedder(
|
||||
timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False
|
||||
)
|
||||
temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1)
|
||||
timestep_proj_t0 = (
|
||||
timestep_proj_t0.unflatten(-1, (6, -1))
|
||||
.view(1, 6, 1, -1)
|
||||
.expand(batch_size, -1, history_context_length, -1)
|
||||
)
|
||||
|
||||
temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
|
||||
timestep_proj = timestep_proj.unflatten(-1, (6, -1))
|
||||
|
||||
if indices_hidden_states is not None and not self.zero_history_timestep:
|
||||
main_repeat_size = hidden_states.shape[1]
|
||||
else:
|
||||
main_repeat_size = original_context_length
|
||||
temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)
|
||||
timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1)
|
||||
|
||||
if indices_hidden_states is not None and self.zero_history_timestep:
|
||||
temb = torch.cat([temb_t0, temb], dim=1)
|
||||
timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)
|
||||
|
||||
if timestep_proj.ndim == 4:
|
||||
timestep_proj = timestep_proj.permute(0, 2, 1, 3)
|
||||
|
||||
# 6. Transformer blocks
|
||||
hidden_states = hidden_states.contiguous()
|
||||
encoder_hidden_states = encoder_hidden_states.contiguous()
|
||||
rotary_emb = rotary_emb.contiguous()
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
for block in self.blocks:
|
||||
hidden_states = self._gradient_checkpointing_func(
|
||||
block,
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep_proj,
|
||||
rotary_emb,
|
||||
original_context_length,
|
||||
)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
hidden_states = block(
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
timestep_proj,
|
||||
rotary_emb,
|
||||
original_context_length,
|
||||
)
|
||||
|
||||
# 7. Normalization
|
||||
hidden_states = self.norm_out(hidden_states, temb, original_context_length)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
|
||||
# 8. Unpatchify
|
||||
hidden_states = hidden_states.reshape(
|
||||
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
|
||||
)
|
||||
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
||||
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
@@ -164,7 +164,11 @@ def compute_text_seq_len_from_mask(
|
||||
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
|
||||
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
|
||||
has_active = encoder_hidden_states_mask.any(dim=1)
|
||||
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
|
||||
per_sample_len = torch.where(
|
||||
has_active,
|
||||
active_positions.max(dim=1).values + 1,
|
||||
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
|
||||
)
|
||||
return text_seq_len, per_sample_len, encoder_hidden_states_mask
|
||||
|
||||
|
||||
|
||||
@@ -21,21 +21,8 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["encoders"] = ["FluxTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"AUTO_BLOCKS_KONTEXT",
|
||||
"FLUX_KONTEXT_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"FluxAutoBeforeDenoiseStep",
|
||||
"FluxAutoBlocks",
|
||||
"FluxAutoDecodeStep",
|
||||
"FluxAutoDenoiseStep",
|
||||
"FluxKontextAutoBlocks",
|
||||
"FluxKontextAutoDenoiseStep",
|
||||
"FluxKontextBeforeDenoiseStep",
|
||||
]
|
||||
_import_structure["modular_blocks_flux"] = ["FluxAutoBlocks"]
|
||||
_import_structure["modular_blocks_flux_kontext"] = ["FluxKontextAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -45,21 +32,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .encoders import FluxTextEncoderStep
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
AUTO_BLOCKS_KONTEXT,
|
||||
FLUX_KONTEXT_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
FluxAutoBeforeDenoiseStep,
|
||||
FluxAutoBlocks,
|
||||
FluxAutoDecodeStep,
|
||||
FluxAutoDenoiseStep,
|
||||
FluxKontextAutoBlocks,
|
||||
FluxKontextAutoDenoiseStep,
|
||||
FluxKontextBeforeDenoiseStep,
|
||||
)
|
||||
from .modular_blocks_flux import FluxAutoBlocks
|
||||
from .modular_blocks_flux_kontext import FluxKontextAutoBlocks
|
||||
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -205,7 +205,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
|
||||
class FluxVaeEncoderStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -121,7 +121,7 @@ class FluxTextInputStep(ModularPipelineBlocks):
|
||||
|
||||
|
||||
# Adapted from `QwenImageAdditionalInputsStep`
|
||||
class FluxInputsDynamicStep(ModularPipelineBlocks):
|
||||
class FluxAdditionalInputsStep(ModularPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
def __init__(
|
||||
@@ -243,7 +243,7 @@ class FluxInputsDynamicStep(ModularPipelineBlocks):
|
||||
return components, state
|
||||
|
||||
|
||||
class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
|
||||
class FluxKontextAdditionalInputsStep(FluxAdditionalInputsStep):
|
||||
model_name = "flux-kontext"
|
||||
|
||||
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
|
||||
@@ -256,7 +256,7 @@ class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
|
||||
continue
|
||||
|
||||
# 1. Calculate height/width from latents
|
||||
# Unlike the `FluxInputsDynamicStep`, we don't overwrite the `block.height` and `block.width`
|
||||
# Unlike the `FluxAdditionalInputsStep`, we don't overwrite the `block.height` and `block.width`
|
||||
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
|
||||
if not hasattr(block_state, "image_height"):
|
||||
block_state.image_height = height
|
||||
@@ -303,6 +303,7 @@ class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
|
||||
class FluxKontextSetResolutionStep(ModularPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Determines the height and width to be used during the subsequent computations.\n"
|
||||
|
||||
@@ -1,446 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
FluxImg2ImgPrepareLatentsStep,
|
||||
FluxImg2ImgSetTimestepsStep,
|
||||
FluxKontextRoPEInputsStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxRoPEInputsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
|
||||
from .encoders import (
|
||||
FluxKontextProcessImagesInputStep,
|
||||
FluxProcessImagesInputStep,
|
||||
FluxTextEncoderStep,
|
||||
FluxVaeEncoderDynamicStep,
|
||||
)
|
||||
from .inputs import (
|
||||
FluxInputsDynamicStep,
|
||||
FluxKontextInputsDynamicStep,
|
||||
FluxKontextSetResolutionStep,
|
||||
FluxTextInputStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# vae encoder (run before before_denoise)
|
||||
FluxImg2ImgVaeEncoderBlocks = InsertableDict(
|
||||
[("preprocess", FluxProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep())]
|
||||
)
|
||||
|
||||
|
||||
class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
block_classes = FluxImg2ImgVaeEncoderBlocks.values()
|
||||
block_names = FluxImg2ImgVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxImg2ImgVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# Flux Kontext vae encoder (run before before_denoise)
|
||||
|
||||
FluxKontextVaeEncoderBlocks = InsertableDict(
|
||||
[("preprocess", FluxKontextProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax"))]
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = FluxKontextVaeEncoderBlocks.values()
|
||||
block_names = FluxKontextVaeEncoderBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxKontextVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img
|
||||
FluxBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxSetTimestepsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = FluxBeforeDenoiseBlocks.values()
|
||||
block_names = FluxBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
|
||||
|
||||
|
||||
# before_denoise: img2img
|
||||
FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
|
||||
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
|
||||
block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs for the denoise step for img2img task."
|
||||
|
||||
|
||||
# before_denoise: all task (text2img, img2img)
|
||||
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
|
||||
block_names = ["img2img", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: FluxKontext
|
||||
|
||||
FluxKontextBeforeDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxSetTimestepsStep()),
|
||||
("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = FluxKontextBeforeDenoiseBlocks.values()
|
||||
block_names = FluxKontextBeforeDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step\n"
|
||||
"for img2img/text2img task for Flux Kontext."
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
|
||||
block_names = ["img2img", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# denoise: text2image
|
||||
class FluxAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxDenoiseStep]
|
||||
block_names = ["denoise"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2image and img2img tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
|
||||
)
|
||||
|
||||
|
||||
# denoise: Flux Kontext
|
||||
|
||||
|
||||
class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxKontextDenoiseStep]
|
||||
block_names = ["denoise"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents for Flux Kontext. "
|
||||
"This is a auto pipeline block that works for text2image and img2img tasks."
|
||||
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
|
||||
)
|
||||
|
||||
|
||||
# decode: all task (text2img, img2img)
|
||||
class FluxAutoDecodeStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxDecodeStep]
|
||||
block_names = ["non-inpaint"]
|
||||
block_trigger_inputs = [None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
|
||||
|
||||
|
||||
# inputs: text2image/img2img
|
||||
FluxImg2ImgBlocks = InsertableDict(
|
||||
[("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
|
||||
)
|
||||
|
||||
|
||||
class FluxImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
model_name = "flux"
|
||||
block_classes = FluxImg2ImgBlocks.values()
|
||||
block_names = FluxImg2ImgBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
class FluxAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
|
||||
block_names = ["img2img", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
|
||||
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
|
||||
+ " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# inputs: Flux Kontext
|
||||
|
||||
FluxKontextBlocks = InsertableDict(
|
||||
[
|
||||
("set_resolution", FluxKontextSetResolutionStep()),
|
||||
("text_inputs", FluxTextInputStep()),
|
||||
("additional_inputs", FluxKontextInputsDynamicStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextInputStep(SequentialPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
block_classes = FluxKontextBlocks.values()
|
||||
block_names = FluxKontextBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextAutoInputStep(AutoPipelineBlocks):
|
||||
block_classes = [FluxKontextInputStep, FluxTextInputStep]
|
||||
# block_classes = [FluxKontextInputStep]
|
||||
block_names = ["img2img", "text2img"]
|
||||
# block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
# block_trigger_inputs = ["image_latents"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
|
||||
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
|
||||
+ " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present."
|
||||
)
|
||||
|
||||
|
||||
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux"
|
||||
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
|
||||
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
|
||||
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux-kontext"
|
||||
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process. \n"
|
||||
+ " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
|
||||
+ " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
|
||||
+ " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
|
||||
)
|
||||
|
||||
|
||||
# Auto blocks (text2image and img2img)
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxAutoVaeEncoderStep()),
|
||||
("denoise", FluxCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS_KONTEXT = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
|
||||
("denoise", FluxKontextCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
|
||||
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
|
||||
)
|
||||
|
||||
|
||||
class FluxKontextAutoBlocks(FluxAutoBlocks):
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = AUTO_BLOCKS_KONTEXT.values()
|
||||
block_names = AUTO_BLOCKS_KONTEXT.keys()
|
||||
|
||||
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("input", FluxTextInputStep()),
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxSetTimestepsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
("denoise", FluxDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxVaeEncoderDynamicStep()),
|
||||
("input", FluxImg2ImgInputStep()),
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
|
||||
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
|
||||
("prepare_rope_inputs", FluxRoPEInputsStep()),
|
||||
("denoise", FluxDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
FLUX_KONTEXT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
|
||||
("input", FluxKontextInputStep()),
|
||||
("prepare_latents", FluxPrepareLatentsStep()),
|
||||
("set_timesteps", FluxSetTimestepsStep()),
|
||||
("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
|
||||
("denoise", FluxKontextDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"text2image": TEXT2IMAGE_BLOCKS,
|
||||
"img2img": IMAGE2IMAGE_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
"auto_kontext": AUTO_BLOCKS_KONTEXT,
|
||||
"kontext": FLUX_KONTEXT_BLOCKS,
|
||||
}
|
||||
586
src/diffusers/modular_pipelines/flux/modular_blocks_flux.py
Normal file
586
src/diffusers/modular_pipelines/flux/modular_blocks_flux.py
Normal file
@@ -0,0 +1,586 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
FluxImg2ImgPrepareLatentsStep,
|
||||
FluxImg2ImgSetTimestepsStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxRoPEInputsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxDenoiseStep
|
||||
from .encoders import (
|
||||
FluxProcessImagesInputStep,
|
||||
FluxTextEncoderStep,
|
||||
FluxVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
FluxAdditionalInputsStep,
|
||||
FluxTextInputStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# vae encoder (run before before_denoise)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Vae encoder step that preprocess andencode the image inputs into their latent representations.
|
||||
|
||||
Components:
|
||||
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
|
||||
|
||||
Inputs:
|
||||
resized_image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
processed_image (`None`):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`):
|
||||
The latents representing the reference image
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
|
||||
block_classes = [FluxProcessImagesInputStep(), FluxVaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Vae encoder step that encode the image inputs into their latent representations.
|
||||
This is an auto pipeline block that works for img2img tasks.
|
||||
- `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided. - if `image` is not provided,
|
||||
step will be skipped.
|
||||
|
||||
Components:
|
||||
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
|
||||
|
||||
Inputs:
|
||||
resized_image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
processed_image (`None`):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`):
|
||||
The latents representing the reference image
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
block_classes = [FluxImg2ImgVaeEncoderStep]
|
||||
block_names = ["img2img"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for img2img tasks.\n"
|
||||
+ " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
|
||||
+ " - if `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img
|
||||
# auto_docstring
|
||||
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepares the inputs for the denoise step in text-to-image generation.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
|
||||
Can be generated in input step.
|
||||
dtype (`dtype`, *optional*):
|
||||
The dtype of the model inputs
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for inference
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time
|
||||
guidance (`Tensor`):
|
||||
Optional guidance to be used.
|
||||
txt_ids (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation.
|
||||
img_ids (`list`):
|
||||
The sequence lengths of the image latents, used for RoPE calculation.
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()]
|
||||
block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
|
||||
|
||||
|
||||
# before_denoise: img2img
|
||||
# auto_docstring
|
||||
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs for the denoise step for img2img task.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
|
||||
Can be generated in input step.
|
||||
dtype (`dtype`, *optional*):
|
||||
The dtype of the model inputs
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.6):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`):
|
||||
The image latents to use for the denoising process. Can be generated in vae encoder and packed in input
|
||||
step.
|
||||
prompt_embeds (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for inference
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time
|
||||
guidance (`Tensor`):
|
||||
Optional guidance to be used.
|
||||
initial_noise (`Tensor`):
|
||||
The initial random noised used for inpainting denoising.
|
||||
txt_ids (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation.
|
||||
img_ids (`list`):
|
||||
The sequence lengths of the image latents, used for RoPE calculation.
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
block_classes = [
|
||||
FluxPrepareLatentsStep(),
|
||||
FluxImg2ImgSetTimestepsStep(),
|
||||
FluxImg2ImgPrepareLatentsStep(),
|
||||
FluxRoPEInputsStep(),
|
||||
]
|
||||
block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents", "prepare_rope_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepare the inputs for the denoise step for img2img task."
|
||||
|
||||
|
||||
# before_denoise: all task (text2img, img2img)
|
||||
# auto_docstring
|
||||
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs for the denoise step.
|
||||
This is an auto pipeline block that works for text2image.
|
||||
- `FluxBeforeDenoiseStep` (text2image) is used.
|
||||
- `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
height (`int`):
|
||||
TODO: Add description.
|
||||
width (`int`):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
|
||||
Can be generated in input step.
|
||||
dtype (`dtype`, *optional*):
|
||||
The dtype of the model inputs
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.6):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
The image latents to use for the denoising process. Can be generated in vae encoder and packed in input
|
||||
step.
|
||||
prompt_embeds (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for inference
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time
|
||||
guidance (`Tensor`):
|
||||
Optional guidance to be used.
|
||||
initial_noise (`Tensor`):
|
||||
The initial random noised used for inpainting denoising.
|
||||
txt_ids (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation.
|
||||
img_ids (`list`):
|
||||
The sequence lengths of the image latents, used for RoPE calculation.
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
|
||||
block_names = ["img2img", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# inputs: text2image/img2img
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxImg2ImgInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the img2img denoising step. It:
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
pooled_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
|
||||
dtype (`dtype`):
|
||||
Data type of model tensor inputs (determined by `prompt_embeds`)
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation
|
||||
pooled_prompt_embeds (`Tensor`):
|
||||
pooled text embeddings used to guide the image generation
|
||||
image_height (`int`):
|
||||
The height of the image latents
|
||||
image_width (`int`):
|
||||
The width of the image latents
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
block_classes = [FluxTextInputStep(), FluxAdditionalInputsStep()]
|
||||
block_names = ["text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxAutoInputStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size,
|
||||
and patchified.
|
||||
This is an auto pipeline block that works for text2image/img2img tasks.
|
||||
- `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.
|
||||
- `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
pooled_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
|
||||
dtype (`dtype`):
|
||||
Data type of model tensor inputs (determined by `prompt_embeds`)
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation
|
||||
pooled_prompt_embeds (`Tensor`):
|
||||
pooled text embeddings used to guide the image generation
|
||||
image_height (`int`):
|
||||
The height of the image latents
|
||||
image_width (`int`):
|
||||
The width of the image latents
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
|
||||
block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
|
||||
block_names = ["img2img", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
|
||||
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
|
||||
+ " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
|
||||
+ " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core step that performs the denoising process for Flux.
|
||||
This step supports text-to-image and image-to-image tasks for Flux:
|
||||
- for image-to-image generation, you need to provide `image_latents`
|
||||
- for text-to-image generation, all you need to provide is prompt embeddings.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
pooled_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.6):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxDenoiseStep]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process for Flux.\n"
|
||||
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
|
||||
+ " - for image-to-image generation, you need to provide `image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# Auto blocks (text2image and img2img)
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxAutoVaeEncoderStep()),
|
||||
("denoise", FluxCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-image and image-to-image using Flux.
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image2image`: requires `image`, `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2
|
||||
(`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler
|
||||
(`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
prompt_2 (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
resized_image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.6):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
output_type (`None`, *optional*, defaults to pil):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "flux"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"image": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto Modular pipeline for text-to-image and image-to-image using Flux."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("images")]
|
||||
@@ -0,0 +1,585 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
FluxKontextRoPEInputsStep,
|
||||
FluxPrepareLatentsStep,
|
||||
FluxRoPEInputsStep,
|
||||
FluxSetTimestepsStep,
|
||||
)
|
||||
from .decoders import FluxDecodeStep
|
||||
from .denoise import FluxKontextDenoiseStep
|
||||
from .encoders import (
|
||||
FluxKontextProcessImagesInputStep,
|
||||
FluxTextEncoderStep,
|
||||
FluxVaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
FluxKontextAdditionalInputsStep,
|
||||
FluxKontextSetResolutionStep,
|
||||
FluxTextInputStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# Flux Kontext vae encoder (run before before_denoise)
|
||||
# auto_docstring
|
||||
class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Vae encoder step that preprocess andencode the image inputs into their latent representations.
|
||||
|
||||
Components:
|
||||
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
_auto_resize (`bool`, *optional*, defaults to True):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
processed_image (`None`):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`):
|
||||
The latents representing the reference image
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = [FluxKontextProcessImagesInputStep(), FluxVaeEncoderStep(sample_mode="argmax")]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Vae encoder step that encode the image inputs into their latent representations.
|
||||
This is an auto pipeline block that works for image-conditioned tasks.
|
||||
- `FluxKontextVaeEncoderStep` (image_conditioned) is used when only `image` is provided. - if `image` is not
|
||||
provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
_auto_resize (`bool`, *optional*, defaults to True):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
processed_image (`None`):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`):
|
||||
The latents representing the reference image
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = [FluxKontextVaeEncoderStep]
|
||||
block_names = ["image_conditioned"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Vae encoder step that encode the image inputs into their latent representations.\n"
|
||||
+ "This is an auto pipeline block that works for image-conditioned tasks.\n"
|
||||
+ " - `FluxKontextVaeEncoderStep` (image_conditioned) is used when only `image` is provided."
|
||||
+ " - if `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
# before_denoise: text2img
|
||||
# auto_docstring
|
||||
class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepares the inputs for the denoise step for Flux Kontext
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
|
||||
Can be generated in input step.
|
||||
dtype (`dtype`, *optional*):
|
||||
The dtype of the model inputs
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for inference
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time
|
||||
guidance (`Tensor`):
|
||||
Optional guidance to be used.
|
||||
txt_ids (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation.
|
||||
img_ids (`list`):
|
||||
The sequence lengths of the image latents, used for RoPE calculation.
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()]
|
||||
block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Before denoise step that prepares the inputs for the denoise step for Flux Kontext\n"
|
||||
"for text-to-image tasks."
|
||||
|
||||
|
||||
# before_denoise: image-conditioned
|
||||
# auto_docstring
|
||||
class FluxKontextImageConditionedBeforeDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs for the denoise step for Flux Kontext
|
||||
for image-conditioned tasks.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
|
||||
Can be generated in input step.
|
||||
dtype (`dtype`, *optional*):
|
||||
The dtype of the model inputs
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
image_height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for inference
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time
|
||||
guidance (`Tensor`):
|
||||
Optional guidance to be used.
|
||||
txt_ids (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation.
|
||||
img_ids (`list`):
|
||||
The sequence lengths of the image latents, used for RoPE calculation.
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxKontextRoPEInputsStep()]
|
||||
block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step for Flux Kontext\n"
|
||||
"for image-conditioned tasks."
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Before denoise step that prepare the inputs for the denoise step.
|
||||
This is an auto pipeline block that works for text2image.
|
||||
- `FluxKontextBeforeDenoiseStep` (text2image) is used.
|
||||
- `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when only `image_latents` is
|
||||
provided.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
|
||||
Can be generated in input step.
|
||||
dtype (`dtype`, *optional*):
|
||||
The dtype of the model inputs
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
image_height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
timesteps (`Tensor`):
|
||||
The timesteps to use for inference
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps to perform at inference time
|
||||
guidance (`Tensor`):
|
||||
Optional guidance to be used.
|
||||
txt_ids (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation.
|
||||
img_ids (`list`):
|
||||
The sequence lengths of the image latents, used for RoPE calculation.
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = [FluxKontextImageConditionedBeforeDenoiseStep, FluxKontextBeforeDenoiseStep]
|
||||
block_names = ["image_conditioned", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Before denoise step that prepare the inputs for the denoise step.\n"
|
||||
+ "This is an auto pipeline block that works for text2image.\n"
|
||||
+ " - `FluxKontextBeforeDenoiseStep` (text2image) is used.\n"
|
||||
+ " - `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when only `image_latents` is provided.\n"
|
||||
)
|
||||
|
||||
|
||||
# inputs: Flux Kontext
|
||||
# auto_docstring
|
||||
class FluxKontextInputStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Input step that prepares the inputs for the both text2img and img2img denoising step. It:
|
||||
- make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).
|
||||
- update height/width based `image_latents`, patchify `image_latents`.
|
||||
|
||||
Inputs:
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_area (`int`, *optional*, defaults to 1048576):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
pooled_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
height (`int`):
|
||||
The height of the initial noisy latents
|
||||
width (`int`):
|
||||
The width of the initial noisy latents
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
|
||||
dtype (`dtype`):
|
||||
Data type of model tensor inputs (determined by `prompt_embeds`)
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation
|
||||
pooled_prompt_embeds (`Tensor`):
|
||||
pooled text embeddings used to guide the image generation
|
||||
image_height (`int`):
|
||||
The height of the image latents
|
||||
image_width (`int`):
|
||||
The width of the image latents
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
block_classes = [FluxKontextSetResolutionStep(), FluxTextInputStep(), FluxKontextAdditionalInputsStep()]
|
||||
block_names = ["set_resolution", "text_inputs", "additional_inputs"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n"
|
||||
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
|
||||
" - update height/width based `image_latents`, patchify `image_latents`."
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxKontextAutoInputStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size,
|
||||
and patchified.
|
||||
This is an auto pipeline block that works for text2image/img2img tasks.
|
||||
- `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided.
|
||||
- `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present.
|
||||
|
||||
Inputs:
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_area (`int`, *optional*, defaults to 1048576):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
pooled_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
height (`int`):
|
||||
The height of the initial noisy latents
|
||||
width (`int`):
|
||||
The width of the initial noisy latents
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
|
||||
dtype (`dtype`):
|
||||
Data type of model tensor inputs (determined by `prompt_embeds`)
|
||||
prompt_embeds (`Tensor`):
|
||||
text embeddings used to guide the image generation
|
||||
pooled_prompt_embeds (`Tensor`):
|
||||
pooled text embeddings used to guide the image generation
|
||||
image_height (`int`):
|
||||
The height of the image latents
|
||||
image_width (`int`):
|
||||
The width of the image latents
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
block_classes = [FluxKontextInputStep, FluxTextInputStep]
|
||||
block_names = ["image_conditioned", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
|
||||
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
|
||||
+ " - `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided.\n"
|
||||
+ " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present."
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core step that performs the denoising process for Flux Kontext.
|
||||
This step supports text-to-image and image-conditioned tasks for Flux Kontext:
|
||||
- for image-conditioned generation, you need to provide `image_latents`
|
||||
- for text-to-image generation, all you need to provide is prompt embeddings.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_area (`int`, *optional*, defaults to 1048576):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
pooled_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextDenoiseStep]
|
||||
block_names = ["input", "before_denoise", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core step that performs the denoising process for Flux Kontext.\n"
|
||||
+ "This step supports text-to-image and image-conditioned tasks for Flux Kontext:\n"
|
||||
+ " - for image-conditioned generation, you need to provide `image_latents`\n"
|
||||
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
AUTO_BLOCKS_KONTEXT = InsertableDict(
|
||||
[
|
||||
("text_encoder", FluxTextEncoderStep()),
|
||||
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
|
||||
("denoise", FluxKontextCoreDenoiseStep()),
|
||||
("decode", FluxDecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class FluxKontextAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Modular pipeline for image-to-image using Flux Kontext.
|
||||
|
||||
Supported workflows:
|
||||
- `image_conditioned`: requires `image`, `prompt`
|
||||
- `text2image`: requires `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2
|
||||
(`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler
|
||||
(`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
prompt_2 (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
_auto_resize (`bool`, *optional*, defaults to True):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_area (`int`, *optional*, defaults to 1048576):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 3.5):
|
||||
TODO: Add description.
|
||||
output_type (`None`, *optional*, defaults to pil):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "flux-kontext"
|
||||
|
||||
block_classes = AUTO_BLOCKS_KONTEXT.values()
|
||||
block_names = AUTO_BLOCKS_KONTEXT.keys()
|
||||
_workflow_map = {
|
||||
"image_conditioned": {"image": True, "prompt": True},
|
||||
"text2image": {"prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Modular pipeline for image-to-image using Flux Kontext."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("images")]
|
||||
@@ -21,44 +21,14 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["encoders"] = [
|
||||
"Flux2TextEncoderStep",
|
||||
"Flux2RemoteTextEncoderStep",
|
||||
"Flux2VaeEncoderStep",
|
||||
]
|
||||
_import_structure["before_denoise"] = [
|
||||
"Flux2SetTimestepsStep",
|
||||
"Flux2PrepareLatentsStep",
|
||||
"Flux2RoPEInputsStep",
|
||||
"Flux2PrepareImageLatentsStep",
|
||||
]
|
||||
_import_structure["denoise"] = [
|
||||
"Flux2LoopDenoiser",
|
||||
"Flux2LoopAfterDenoiser",
|
||||
"Flux2DenoiseLoopWrapper",
|
||||
"Flux2DenoiseStep",
|
||||
]
|
||||
_import_structure["decoders"] = ["Flux2DecodeStep"]
|
||||
_import_structure["inputs"] = [
|
||||
"Flux2ProcessImagesInputStep",
|
||||
"Flux2TextInputStep",
|
||||
]
|
||||
_import_structure["modular_blocks_flux2"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"REMOTE_AUTO_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"IMAGE_CONDITIONED_BLOCKS",
|
||||
"Flux2AutoBlocks",
|
||||
"Flux2AutoVaeEncoderStep",
|
||||
"Flux2CoreDenoiseStep",
|
||||
"Flux2VaeEncoderSequentialStep",
|
||||
]
|
||||
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
|
||||
_import_structure["encoders"] = ["Flux2RemoteTextEncoderStep"]
|
||||
_import_structure["modular_blocks_flux2"] = ["Flux2AutoBlocks"]
|
||||
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks"]
|
||||
_import_structure["modular_blocks_flux2_klein_base"] = ["Flux2KleinBaseAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"Flux2ModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2KleinBaseModularPipeline",
|
||||
"Flux2KleinModularPipeline",
|
||||
"Flux2ModularPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -68,43 +38,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .before_denoise import (
|
||||
Flux2PrepareImageLatentsStep,
|
||||
Flux2PrepareLatentsStep,
|
||||
Flux2RoPEInputsStep,
|
||||
Flux2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import Flux2DecodeStep
|
||||
from .denoise import (
|
||||
Flux2DenoiseLoopWrapper,
|
||||
Flux2DenoiseStep,
|
||||
Flux2LoopAfterDenoiser,
|
||||
Flux2LoopDenoiser,
|
||||
)
|
||||
from .encoders import (
|
||||
Flux2RemoteTextEncoderStep,
|
||||
Flux2TextEncoderStep,
|
||||
Flux2VaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
Flux2ProcessImagesInputStep,
|
||||
Flux2TextInputStep,
|
||||
)
|
||||
from .modular_blocks_flux2 import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
IMAGE_CONDITIONED_BLOCKS,
|
||||
REMOTE_AUTO_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
Flux2AutoBlocks,
|
||||
Flux2AutoVaeEncoderStep,
|
||||
Flux2CoreDenoiseStep,
|
||||
Flux2VaeEncoderSequentialStep,
|
||||
)
|
||||
from .modular_blocks_flux2_klein import (
|
||||
Flux2KleinAutoBlocks,
|
||||
Flux2KleinBaseAutoBlocks,
|
||||
)
|
||||
from .encoders import Flux2RemoteTextEncoderStep
|
||||
from .modular_blocks_flux2 import Flux2AutoBlocks
|
||||
from .modular_blocks_flux2_klein import Flux2KleinAutoBlocks
|
||||
from .modular_blocks_flux2_klein_base import Flux2KleinBaseAutoBlocks
|
||||
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -12,10 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
@@ -30,7 +26,6 @@ from .before_denoise import (
|
||||
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
|
||||
from .denoise import Flux2DenoiseStep
|
||||
from .encoders import (
|
||||
Flux2RemoteTextEncoderStep,
|
||||
Flux2TextEncoderStep,
|
||||
Flux2VaeEncoderStep,
|
||||
)
|
||||
@@ -43,26 +38,69 @@ from .inputs import (
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
Flux2VaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", Flux2ProcessImagesInputStep()),
|
||||
("encode", Flux2VaeEncoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning.
|
||||
|
||||
Components:
|
||||
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
condition_images (`list`):
|
||||
TODO: Add description.
|
||||
image_latents (`list`):
|
||||
List of latent representations for each reference image
|
||||
"""
|
||||
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = Flux2VaeEncoderBlocks.values()
|
||||
block_names = Flux2VaeEncoderBlocks.keys()
|
||||
block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that encodes the image inputs into their latent representations.
|
||||
This is an auto pipeline block that works for image conditioning tasks.
|
||||
- `Flux2VaeEncoderSequentialStep` is used when `image` is provided.
|
||||
- If `image` is not provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
condition_images (`list`):
|
||||
TODO: Add description.
|
||||
image_latents (`list`):
|
||||
List of latent representations for each reference image
|
||||
"""
|
||||
|
||||
block_classes = [Flux2VaeEncoderSequentialStep]
|
||||
block_names = ["img_conditioning"]
|
||||
block_trigger_inputs = ["image"]
|
||||
@@ -80,7 +118,6 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
Flux2CoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2TextInputStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_guidance", Flux2PrepareGuidanceStep()),
|
||||
@@ -91,7 +128,47 @@ Flux2CoreDenoiseBlocks = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoise step that performs the denoising process for Flux2-dev.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 4.0):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
Packed image latents for conditioning. Shape: (B, img_seq_len, C)
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = Flux2CoreDenoiseBlocks.values()
|
||||
@@ -99,94 +176,18 @@ class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-dev.\n"
|
||||
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
|
||||
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
|
||||
" - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
)
|
||||
return "Core denoise step that performs the denoising process for Flux2-dev."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
Flux2ImageConditionedCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("text_encoder", Flux2TextEncoderStep()),
|
||||
("vae_encoder", Flux2AutoVaeEncoderStep()),
|
||||
("denoise", Flux2CoreDenoiseStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
REMOTE_AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", Flux2RemoteTextEncoderStep()),
|
||||
("vae_encoder", Flux2AutoVaeEncoderStep()),
|
||||
("denoise", Flux2CoreDenoiseStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class Flux2AutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
|
||||
"- For text-to-image generation, all you need to provide is `prompt`.\n"
|
||||
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", Flux2TextEncoderStep()),
|
||||
("text_input", Flux2TextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_guidance", Flux2PrepareGuidanceStep()),
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2DenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE_CONDITIONED_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", Flux2TextEncoderStep()),
|
||||
("text_input", Flux2TextInputStep()),
|
||||
("preprocess_images", Flux2ProcessImagesInputStep()),
|
||||
("vae_encoder", Flux2VaeEncoderStep()),
|
||||
("input", Flux2TextInputStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
@@ -194,13 +195,162 @@ IMAGE_CONDITIONED_BLOCKS = InsertableDict(
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2DenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2ImageConditionedCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoise step that performs the denoising process for Flux2-dev with image conditioning.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 4.0):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = Flux2ImageConditionedCoreDenoiseBlocks.values()
|
||||
block_names = Flux2ImageConditionedCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise step that performs the denoising process for Flux2-dev with image conditioning."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
class Flux2AutoCoreDenoiseStep(AutoPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = [Flux2ImageConditionedCoreDenoiseStep, Flux2CoreDenoiseStep]
|
||||
block_names = ["image_conditioned", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto core denoise step that performs the denoising process for Flux2-dev."
|
||||
"This is an auto pipeline block that works for text-to-image and image-conditioned generation."
|
||||
" - `Flux2CoreDenoiseStep` is used for text-to-image generation.\n"
|
||||
" - `Flux2ImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n"
|
||||
)
|
||||
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", Flux2TextEncoderStep()),
|
||||
("vae_encoder", Flux2AutoVaeEncoderStep()),
|
||||
("denoise", Flux2AutoCoreDenoiseStep()),
|
||||
("decode", Flux2DecodeStep()),
|
||||
]
|
||||
)
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"text2image": TEXT2IMAGE_BLOCKS,
|
||||
"image_conditioned": IMAGE_CONDITIONED_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
"remote": REMOTE_AUTO_BLOCKS,
|
||||
}
|
||||
|
||||
# auto_docstring
|
||||
class Flux2AutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image_conditioned`: requires `image`, `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`Mistral3ForConditionalGeneration`) tokenizer (`AutoProcessor`) image_processor
|
||||
(`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer
|
||||
(`Flux2Transformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
text_encoder_out_layers (`tuple`, *optional*, defaults to (10, 20, 30)):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`):
|
||||
TODO: Add description.
|
||||
timesteps (`None`):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
guidance_scale (`None`, *optional*, defaults to 4.0):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
output_type (`None`, *optional*, defaults to pil):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image_conditioned": {"image": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("images"),
|
||||
]
|
||||
|
||||
@@ -12,30 +12,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
Flux2KleinBaseRoPEInputsStep,
|
||||
Flux2PrepareImageLatentsStep,
|
||||
Flux2PrepareLatentsStep,
|
||||
Flux2RoPEInputsStep,
|
||||
Flux2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
|
||||
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
|
||||
from .denoise import Flux2KleinDenoiseStep
|
||||
from .encoders import (
|
||||
Flux2KleinBaseTextEncoderStep,
|
||||
Flux2KleinTextEncoderStep,
|
||||
Flux2VaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
Flux2KleinBaseTextInputStep,
|
||||
Flux2ProcessImagesInputStep,
|
||||
Flux2TextInputStep,
|
||||
)
|
||||
@@ -47,26 +40,72 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
# VAE encoder
|
||||
################
|
||||
|
||||
Flux2KleinVaeEncoderBlocks = InsertableDict(
|
||||
[
|
||||
("preprocess", Flux2ProcessImagesInputStep()),
|
||||
("encode", Flux2VaeEncoderStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2"
|
||||
"""
|
||||
VAE encoder step that preprocesses and encodes the image inputs into their latent representations.
|
||||
|
||||
block_classes = Flux2KleinVaeEncoderBlocks.values()
|
||||
block_names = Flux2KleinVaeEncoderBlocks.keys()
|
||||
Components:
|
||||
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
condition_images (`list`):
|
||||
TODO: Add description.
|
||||
image_latents (`list`):
|
||||
List of latent representations for each reference image
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
|
||||
block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that encodes the image inputs into their latent representations.
|
||||
This is an auto pipeline block that works for image conditioning tasks.
|
||||
- `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.
|
||||
- If `image` is not provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
condition_images (`list`):
|
||||
TODO: Add description.
|
||||
image_latents (`list`):
|
||||
List of latent representations for each reference image
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
|
||||
block_classes = [Flux2KleinVaeEncoderSequentialStep]
|
||||
block_names = ["img_conditioning"]
|
||||
block_trigger_inputs = ["image"]
|
||||
@@ -86,6 +125,74 @@ class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
###
|
||||
|
||||
Flux2KleinCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2TextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2RoPEInputsStep()),
|
||||
("denoise", Flux2KleinDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image
|
||||
generation.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
Packed image latents for conditioning. Shape: (B, img_seq_len, C)
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
|
||||
block_classes = Flux2KleinCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image generation."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
Flux2KleinImageConditionedCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2TextInputStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
@@ -98,135 +205,196 @@ Flux2KleinCoreDenoiseBlocks = InsertableDict(
|
||||
)
|
||||
|
||||
|
||||
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
# auto_docstring
|
||||
class Flux2KleinImageConditionedCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
|
||||
block_classes = Flux2KleinCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinCoreDenoiseBlocks.keys()
|
||||
block_classes = Flux2KleinImageConditionedCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinImageConditionedCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
|
||||
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
|
||||
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
)
|
||||
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2KleinBaseTextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
|
||||
("denoise", Flux2KleinBaseDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
# auto_docstring
|
||||
class Flux2KleinAutoCoreDenoiseStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Auto core denoise step that performs the denoising process for Flux2-Klein.
|
||||
This is an auto pipeline block that works for text-to-image and image-conditioned generation.
|
||||
- `Flux2KleinCoreDenoiseStep` is used for text-to-image generation.
|
||||
- `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`):
|
||||
TODO: Add description.
|
||||
timesteps (`None`):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
|
||||
block_classes = [Flux2KleinImageConditionedCoreDenoiseStep, Flux2KleinCoreDenoiseStep]
|
||||
block_names = ["image_conditioned", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
|
||||
return (
|
||||
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
|
||||
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
|
||||
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
|
||||
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
|
||||
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
|
||||
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
|
||||
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
|
||||
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
|
||||
"Auto core denoise step that performs the denoising process for Flux2-Klein.\n"
|
||||
"This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n"
|
||||
" - `Flux2KleinCoreDenoiseStep` is used for text-to-image generation.\n"
|
||||
" - `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
description="The latents from the denoising step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
###
|
||||
### Auto blocks
|
||||
###
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image_conditioned`: requires `image`, `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) image_processor (`Flux2ImageProcessor`)
|
||||
vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer
|
||||
(`Flux2Transformer2DModel`)
|
||||
|
||||
Configs:
|
||||
is_distilled (default: True)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`):
|
||||
TODO: Add description.
|
||||
timesteps (`None`):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
output_type (`None`, *optional*, defaults to pil):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [
|
||||
Flux2KleinTextEncoderStep(),
|
||||
Flux2KleinAutoVaeEncoderStep(),
|
||||
Flux2KleinCoreDenoiseStep(),
|
||||
Flux2KleinAutoCoreDenoiseStep(),
|
||||
Flux2DecodeStep(),
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image_conditioned": {"image": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
|
||||
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
|
||||
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
|
||||
)
|
||||
return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [
|
||||
Flux2KleinBaseTextEncoderStep(),
|
||||
Flux2KleinAutoVaeEncoderStep(),
|
||||
Flux2KleinBaseCoreDenoiseStep(),
|
||||
Flux2DecodeStep(),
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
|
||||
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
|
||||
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
|
||||
)
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam(
|
||||
name="images",
|
||||
type_hint=List[PIL.Image.Image],
|
||||
description="The images from the decoding step.",
|
||||
)
|
||||
OutputParam.template("images"),
|
||||
]
|
||||
|
||||
@@ -0,0 +1,413 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict, OutputParam
|
||||
from .before_denoise import (
|
||||
Flux2KleinBaseRoPEInputsStep,
|
||||
Flux2PrepareImageLatentsStep,
|
||||
Flux2PrepareLatentsStep,
|
||||
Flux2SetTimestepsStep,
|
||||
)
|
||||
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
|
||||
from .denoise import Flux2KleinBaseDenoiseStep
|
||||
from .encoders import (
|
||||
Flux2KleinBaseTextEncoderStep,
|
||||
Flux2VaeEncoderStep,
|
||||
)
|
||||
from .inputs import (
|
||||
Flux2KleinBaseTextInputStep,
|
||||
Flux2ProcessImagesInputStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
################
|
||||
# VAE encoder
|
||||
################
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinBaseVaeEncoderSequentialStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that preprocesses and encodes the image inputs into their latent representations.
|
||||
|
||||
Components:
|
||||
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
condition_images (`list`):
|
||||
TODO: Add description.
|
||||
image_latents (`list`):
|
||||
List of latent representations for each reference image
|
||||
"""
|
||||
|
||||
model_name = "flux2"
|
||||
|
||||
block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()]
|
||||
block_names = ["preprocess", "encode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinBaseAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
VAE encoder step that encodes the image inputs into their latent representations.
|
||||
This is an auto pipeline block that works for image conditioning tasks.
|
||||
- `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided.
|
||||
- If `image` is not provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
|
||||
|
||||
Inputs:
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
condition_images (`list`):
|
||||
TODO: Add description.
|
||||
image_latents (`list`):
|
||||
List of latent representations for each reference image
|
||||
"""
|
||||
|
||||
block_classes = [Flux2KleinBaseVaeEncoderSequentialStep]
|
||||
block_names = ["img_conditioning"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"VAE encoder step that encodes the image inputs into their latent representations.\n"
|
||||
"This is an auto pipeline block that works for image conditioning tasks.\n"
|
||||
" - `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided.\n"
|
||||
" - If `image` is not provided, step will be skipped."
|
||||
)
|
||||
|
||||
|
||||
###
|
||||
### Core denoise
|
||||
###
|
||||
|
||||
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2KleinBaseTextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
|
||||
("denoise", Flux2KleinBaseDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider
|
||||
(`ClassifierFreeGuidance`)
|
||||
|
||||
Configs:
|
||||
is_distilled (default: False)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
Packed image latents for conditioning. Shape: (B, img_seq_len, C)
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
Flux2KleinBaseImageConditionedCoreDenoiseBlocks = InsertableDict(
|
||||
[
|
||||
("input", Flux2KleinBaseTextInputStep()),
|
||||
("prepare_latents", Flux2PrepareLatentsStep()),
|
||||
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
|
||||
("set_timesteps", Flux2SetTimestepsStep()),
|
||||
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
|
||||
("denoise", Flux2KleinBaseDenoiseStep()),
|
||||
("after_denoise", Flux2UnpackLatentsStep()),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinBaseImageConditionedCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider
|
||||
(`ClassifierFreeGuidance`)
|
||||
|
||||
Configs:
|
||||
is_distilled (default: False)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
block_classes = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.values()
|
||||
block_names = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.keys()
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("latents"),
|
||||
]
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinBaseAutoCoreDenoiseStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Auto core denoise step that performs the denoising process for Flux2-Klein (base model).
|
||||
This is an auto pipeline block that works for text-to-image and image-conditioned generation.
|
||||
- `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation.
|
||||
- `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation.
|
||||
|
||||
Components:
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider
|
||||
(`ClassifierFreeGuidance`)
|
||||
|
||||
Configs:
|
||||
is_distilled (default: False)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`):
|
||||
TODO: Add description.
|
||||
timesteps (`None`):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [Flux2KleinBaseImageConditionedCoreDenoiseStep, Flux2KleinBaseCoreDenoiseStep]
|
||||
block_names = ["image_conditioned", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
|
||||
"This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n"
|
||||
" - `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation.\n"
|
||||
" - `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n"
|
||||
)
|
||||
|
||||
|
||||
###
|
||||
### Auto blocks
|
||||
###
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image_conditioned`: requires `image`, `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) guider (`ClassifierFreeGuidance`)
|
||||
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler
|
||||
(`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
|
||||
|
||||
Configs:
|
||||
is_distilled (default: False)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`int`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
image_latents (`list`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`):
|
||||
TODO: Add description.
|
||||
timesteps (`None`):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
joint_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latent_ids (`Tensor`, *optional*):
|
||||
Position IDs for image latents. Shape: (B, img_seq_len, 4)
|
||||
output_type (`None`, *optional*, defaults to pil):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "flux2-klein"
|
||||
block_classes = [
|
||||
Flux2KleinBaseTextEncoderStep(),
|
||||
Flux2KleinBaseAutoVaeEncoderStep(),
|
||||
Flux2KleinBaseAutoCoreDenoiseStep(),
|
||||
Flux2DecodeStep(),
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image_conditioned": {"image": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model)."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [
|
||||
OutputParam.template("images"),
|
||||
]
|
||||
@@ -14,6 +14,7 @@
|
||||
import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
@@ -28,10 +29,16 @@ from tqdm.auto import tqdm
|
||||
from typing_extensions import Self
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
|
||||
from ..pipelines.pipeline_loading_utils import (
|
||||
LOADABLE_CLASSES,
|
||||
_fetch_class_library_tuple,
|
||||
_unwrap_model,
|
||||
simple_get_class_obj,
|
||||
)
|
||||
from ..utils import PushToHubMixin, is_accelerate_available, logging
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from ..utils.torch_utils import is_compiled_module
|
||||
from .components_manager import ComponentsManager
|
||||
from .modular_pipeline_utils import (
|
||||
MODULAR_MODEL_CARD_TEMPLATE,
|
||||
@@ -40,8 +47,12 @@ from .modular_pipeline_utils import (
|
||||
InputParam,
|
||||
InsertableDict,
|
||||
OutputParam,
|
||||
_validate_requirements,
|
||||
combine_inputs,
|
||||
combine_outputs,
|
||||
format_components,
|
||||
format_configs,
|
||||
format_workflow,
|
||||
generate_modular_model_card_content,
|
||||
make_doc_string,
|
||||
)
|
||||
@@ -287,6 +298,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
|
||||
config_name = "modular_config.json"
|
||||
model_name = None
|
||||
_requirements: dict[str, str] | None = None
|
||||
_workflow_map = None
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
@@ -342,6 +355,35 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
def outputs(self) -> list[OutputParam]:
|
||||
return self._get_outputs()
|
||||
|
||||
# currentlyonly ConditionalPipelineBlocks and SequentialPipelineBlocks support `get_execution_blocks`
|
||||
def get_execution_blocks(self, **kwargs):
|
||||
"""
|
||||
Get the block(s) that would execute given the inputs. Must be implemented by subclasses that support
|
||||
conditional block selection.
|
||||
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
"""
|
||||
raise NotImplementedError(f"`get_execution_blocks` is not implemented for {self.__class__.__name__}")
|
||||
|
||||
# currently only SequentialPipelineBlocks support workflows
|
||||
@property
|
||||
def available_workflows(self):
|
||||
"""
|
||||
Returns a list of available workflow names. Must be implemented by subclasses that define `_workflow_map`.
|
||||
"""
|
||||
raise NotImplementedError(f"`available_workflows` is not implemented for {self.__class__.__name__}")
|
||||
|
||||
def get_workflow(self, workflow_name: str):
|
||||
"""
|
||||
Get the execution blocks for a specific workflow. Must be implemented by subclasses that define
|
||||
`_workflow_map`.
|
||||
|
||||
Args:
|
||||
workflow_name: Name of the workflow to retrieve.
|
||||
"""
|
||||
raise NotImplementedError(f"`get_workflow` is not implemented for {self.__class__.__name__}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
@@ -371,6 +413,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
|
||||
)
|
||||
|
||||
if "requirements" in config and config["requirements"] is not None:
|
||||
_ = _validate_requirements(config["requirements"])
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
@@ -395,8 +440,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
|
||||
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
|
||||
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
|
||||
|
||||
self.register_to_config(auto_map=auto_map)
|
||||
|
||||
# resolve requirements
|
||||
requirements = _validate_requirements(getattr(self, "_requirements", None))
|
||||
if requirements:
|
||||
self.register_to_config(requirements=requirements)
|
||||
|
||||
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
|
||||
config = dict(self.config)
|
||||
self._internal_dict = FrozenDict(config)
|
||||
@@ -480,72 +530,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
if current_value is not param: # Using identity comparison to check if object was modified
|
||||
state.set(param_name, param, input_param.kwargs_type)
|
||||
|
||||
@staticmethod
|
||||
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
|
||||
"""
|
||||
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
|
||||
current default value is None and new default value is not None. Warns if multiple non-None default values
|
||||
exist for the same input.
|
||||
|
||||
Args:
|
||||
named_input_lists: list of tuples containing (block_name, input_param_list) pairs
|
||||
|
||||
Returns:
|
||||
list[InputParam]: Combined list of unique InputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> InputParam
|
||||
value_sources = {} # name -> block_name
|
||||
|
||||
for block_name, inputs in named_input_lists:
|
||||
for input_param in inputs:
|
||||
if input_param.name is None and input_param.kwargs_type is not None:
|
||||
input_name = "*_" + input_param.kwargs_type
|
||||
else:
|
||||
input_name = input_param.name
|
||||
if input_name in combined_dict:
|
||||
current_param = combined_dict[input_name]
|
||||
if (
|
||||
current_param.default is not None
|
||||
and input_param.default is not None
|
||||
and current_param.default != input_param.default
|
||||
):
|
||||
warnings.warn(
|
||||
f"Multiple different default values found for input '{input_name}': "
|
||||
f"{current_param.default} (from block '{value_sources[input_name]}') and "
|
||||
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
|
||||
)
|
||||
if current_param.default is None and input_param.default is not None:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
else:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
@staticmethod
|
||||
def combine_outputs(*named_output_lists: list[tuple[str, list[OutputParam]]]) -> list[OutputParam]:
|
||||
"""
|
||||
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
|
||||
occurrence of each output name.
|
||||
|
||||
Args:
|
||||
named_output_lists: list of tuples containing (block_name, output_param_list) pairs
|
||||
|
||||
Returns:
|
||||
list[OutputParam]: Combined list of unique OutputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> OutputParam
|
||||
|
||||
for block_name, outputs in named_output_lists:
|
||||
for output_param in outputs:
|
||||
if (output_param.name not in combined_dict) or (
|
||||
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
|
||||
):
|
||||
combined_dict[output_param.name] = output_param
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
@property
|
||||
def input_names(self) -> list[str]:
|
||||
return [input_param.name for input_param in self.inputs if input_param.name is not None]
|
||||
@@ -577,7 +561,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
|
||||
`select_block` method to define the logic for selecting the block.
|
||||
`select_block` method to define the logic for selecting the block. Currently, we only support selection logic based
|
||||
on the presence or absence of inputs (i.e., whether they are `None` or not)
|
||||
|
||||
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipeline blocks (such as loading or saving etc.)
|
||||
@@ -585,15 +570,20 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
> [!WARNING] > This is an experimental feature and is likely to change in the future.
|
||||
|
||||
Attributes:
|
||||
block_classes: List of block classes to be used
|
||||
block_names: List of prefixes for each block
|
||||
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
|
||||
block_classes: List of block classes to be used. Must have the same length as `block_names`.
|
||||
block_names: List of names for each block. Must have the same length as `block_classes`.
|
||||
block_trigger_inputs: List of input names that `select_block()` uses to determine which block to run.
|
||||
For `ConditionalPipelineBlocks`, this does not need to correspond to `block_names` and `block_classes`. For
|
||||
`AutoPipelineBlocks`, this must have the same length as `block_names` and `block_classes`, where each
|
||||
element specifies the trigger input for the corresponding block.
|
||||
default_block_name: Name of the default block to run when no trigger inputs match.
|
||||
If None, this block can be skipped entirely when no trigger inputs are provided.
|
||||
"""
|
||||
|
||||
block_classes = []
|
||||
block_names = []
|
||||
block_trigger_inputs = []
|
||||
default_block_name = None # name of the default block if no trigger inputs are provided, if None, this block can be skipped if no trigger inputs are provided
|
||||
default_block_name = None
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
@@ -657,7 +647,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
@property
|
||||
def inputs(self) -> list[tuple[str, Any]]:
|
||||
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
|
||||
combined_inputs = self.combine_inputs(*named_inputs)
|
||||
combined_inputs = combine_inputs(*named_inputs)
|
||||
# mark Required inputs only if that input is required by all the blocks
|
||||
for input_param in combined_inputs:
|
||||
if input_param.name in self.required_inputs:
|
||||
@@ -669,15 +659,25 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
@property
|
||||
def outputs(self) -> list[str]:
|
||||
named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
|
||||
def _requirements(self) -> dict[str, str]:
|
||||
requirements = {}
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if getattr(block, "_requirements", None):
|
||||
requirements[block_name] = block._requirements
|
||||
return requirements
|
||||
|
||||
# used for `__repr__`
|
||||
def _get_trigger_inputs(self) -> set:
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in this block and nested blocks.
|
||||
@@ -706,16 +706,16 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return all_triggers
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
"""All trigger inputs including from nested blocks."""
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def select_block(self, **kwargs) -> str | None:
|
||||
"""
|
||||
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
|
||||
for selecting the block.
|
||||
|
||||
Note: When trigger inputs include intermediate outputs from earlier blocks, the selection logic should only
|
||||
depend on the presence or absence of the input (i.e., whether it is None or not), not on its actual value. This
|
||||
is because `get_execution_blocks()` resolves conditions statically by propagating intermediate output names
|
||||
without their runtime values.
|
||||
|
||||
Args:
|
||||
**kwargs: Trigger input names and their values from the state.
|
||||
|
||||
@@ -750,6 +750,39 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
logger.error(error_msg)
|
||||
raise
|
||||
|
||||
def get_execution_blocks(self, **kwargs) -> ModularPipelineBlocks | None:
|
||||
"""
|
||||
Get the block(s) that would execute given the inputs.
|
||||
|
||||
Recursively resolves nested ConditionalPipelineBlocks until reaching either:
|
||||
- A leaf block (no sub_blocks or LoopSequentialPipelineBlocks) → returns single `ModularPipelineBlocks`
|
||||
- A `SequentialPipelineBlocks` → delegates to its `get_execution_blocks()` which returns
|
||||
a `SequentialPipelineBlocks` containing the resolved execution blocks
|
||||
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
|
||||
Returns:
|
||||
- `ModularPipelineBlocks`: A leaf block or resolved `SequentialPipelineBlocks`
|
||||
- `None`: If this block would be skipped (no trigger matched and no default)
|
||||
"""
|
||||
trigger_kwargs = {name: kwargs.get(name) for name in self.block_trigger_inputs if name is not None}
|
||||
block_name = self.select_block(**trigger_kwargs)
|
||||
|
||||
if block_name is None:
|
||||
block_name = self.default_block_name
|
||||
|
||||
if block_name is None:
|
||||
return None
|
||||
|
||||
block = self.sub_blocks[block_name]
|
||||
|
||||
# Recursively resolve until we hit a leaf block
|
||||
if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
|
||||
return block.get_execution_blocks(**kwargs)
|
||||
|
||||
return block
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
base_class = self.__class__.__bases__[0].__name__
|
||||
@@ -757,11 +790,11 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
|
||||
)
|
||||
|
||||
if self.trigger_inputs:
|
||||
if self._get_trigger_inputs():
|
||||
header += "\n"
|
||||
header += " " + "=" * 100 + "\n"
|
||||
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
|
||||
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
|
||||
header += f" Trigger Inputs: {sorted(self._get_trigger_inputs())}\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
# Format description with proper indentation
|
||||
@@ -828,24 +861,56 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
class AutoPipelineBlocks(ConditionalPipelineBlocks):
|
||||
"""
|
||||
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
|
||||
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
|
||||
|
||||
This is a specialized version of `ConditionalPipelineBlocks` where:
|
||||
- Each block has one corresponding trigger input (1:1 mapping)
|
||||
- Block selection is automatic: the first block whose trigger input is present gets selected
|
||||
- `block_trigger_inputs` must have the same length as `block_names` and `block_classes`
|
||||
- Use `None` in `block_trigger_inputs` to specify the default block, i.e the block that will run if no trigger
|
||||
inputs are present
|
||||
|
||||
Attributes:
|
||||
block_classes:
|
||||
List of block classes to be used. Must have the same length as `block_names` and
|
||||
`block_trigger_inputs`.
|
||||
block_names:
|
||||
List of names for each block. Must have the same length as `block_classes` and `block_trigger_inputs`.
|
||||
block_trigger_inputs:
|
||||
List of input names where each element specifies the trigger input for the corresponding block. Use
|
||||
`None` to mark the default block.
|
||||
|
||||
Example:
|
||||
```python
|
||||
class MyAutoBlock(AutoPipelineBlocks):
|
||||
block_classes = [InpaintEncoderBlock, ImageEncoderBlock, TextEncoderBlock]
|
||||
block_names = ["inpaint", "img2img", "text2img"]
|
||||
block_trigger_inputs = ["mask_image", "image", None] # text2img is the default
|
||||
```
|
||||
|
||||
With this definition:
|
||||
- As long as `mask_image` is provided, "inpaint" block runs (regardless of `image` being provided or not)
|
||||
- If `mask_image` is not provided but `image` is provided, "img2img" block runs
|
||||
- Otherwise, "text2img" block runs (default, trigger is `None`)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
if self.default_block_name is not None:
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, do not set `default_block_name` for AutoPipelineBlocks. "
|
||||
f"Use `None` in `block_trigger_inputs` to specify the default block."
|
||||
)
|
||||
|
||||
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
|
||||
raise ValueError(
|
||||
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
|
||||
)
|
||||
|
||||
@property
|
||||
def default_block_name(self) -> str | None:
|
||||
"""Derive default_block_name from block_trigger_inputs (None entry)."""
|
||||
if None in self.block_trigger_inputs:
|
||||
idx = self.block_trigger_inputs.index(None)
|
||||
return self.block_names[idx]
|
||||
return None
|
||||
self.default_block_name = self.block_names[idx]
|
||||
|
||||
def select_block(self, **kwargs) -> str | None:
|
||||
"""Select block based on which trigger input is present (not None)."""
|
||||
@@ -899,6 +964,29 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
expected_configs.append(config)
|
||||
return expected_configs
|
||||
|
||||
@property
|
||||
def available_workflows(self):
|
||||
if self._workflow_map is None:
|
||||
raise NotImplementedError(
|
||||
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
return list(self._workflow_map.keys())
|
||||
|
||||
def get_workflow(self, workflow_name: str):
|
||||
if self._workflow_map is None:
|
||||
raise NotImplementedError(
|
||||
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
|
||||
)
|
||||
|
||||
if workflow_name not in self._workflow_map:
|
||||
raise ValueError(f"Workflow {workflow_name} not found in {self.__class__.__name__}")
|
||||
|
||||
trigger_inputs = self._workflow_map[workflow_name]
|
||||
workflow_blocks = self.get_execution_blocks(**trigger_inputs)
|
||||
|
||||
return workflow_blocks
|
||||
|
||||
@classmethod
|
||||
def from_blocks_dict(
|
||||
cls, blocks_dict: dict[str, Any], description: str | None = None
|
||||
@@ -994,7 +1082,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
# filter out them here so they do not end up as intermediate_outputs
|
||||
if name not in inp_names:
|
||||
named_outputs.append((name, block.intermediate_outputs))
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
# YiYi TODO: I think we can remove the outputs property
|
||||
@@ -1018,6 +1106,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
raise
|
||||
return pipeline, state
|
||||
|
||||
# used for `__repr__`
|
||||
def _get_trigger_inputs(self):
|
||||
"""
|
||||
Returns a set of all unique trigger input values found in the blocks.
|
||||
@@ -1041,89 +1130,56 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
return fn_recursive_get_trigger(self.sub_blocks)
|
||||
|
||||
@property
|
||||
def trigger_inputs(self):
|
||||
return self._get_trigger_inputs()
|
||||
|
||||
def _traverse_trigger_blocks(self, active_inputs):
|
||||
def get_execution_blocks(self, **kwargs) -> "SequentialPipelineBlocks":
|
||||
"""
|
||||
Traverse blocks and select which ones would run given the active inputs.
|
||||
Get the blocks that would execute given the specified inputs.
|
||||
|
||||
As the traversal walks through sequential blocks, intermediate outputs from resolved blocks are added to the
|
||||
active inputs. This means conditional blocks that depend on intermediates (e.g., "run img2img if image_latents
|
||||
is present") will resolve correctly, as long as the condition is based on presence/absence (None or not None),
|
||||
not on the actual value.
|
||||
|
||||
|
||||
Args:
|
||||
active_inputs: Dict of input names to values that are "present"
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
|
||||
Returns:
|
||||
OrderedDict of block_name -> block that would execute
|
||||
SequentialPipelineBlocks containing only the blocks that would execute
|
||||
"""
|
||||
# Copy kwargs so we can add outputs as we traverse
|
||||
active_inputs = dict(kwargs)
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_inputs):
|
||||
result_blocks = OrderedDict()
|
||||
|
||||
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
|
||||
if isinstance(block, ConditionalPipelineBlocks):
|
||||
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
|
||||
selected_block_name = block.select_block(**trigger_kwargs)
|
||||
|
||||
if selected_block_name is None:
|
||||
selected_block_name = block.default_block_name
|
||||
|
||||
if selected_block_name is None:
|
||||
block = block.get_execution_blocks(**active_inputs)
|
||||
if block is None:
|
||||
return result_blocks
|
||||
|
||||
selected_block = block.sub_blocks[selected_block_name]
|
||||
|
||||
if selected_block.sub_blocks:
|
||||
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
|
||||
else:
|
||||
result_blocks[block_name] = selected_block
|
||||
if hasattr(selected_block, "outputs"):
|
||||
for out in selected_block.outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
|
||||
if block.sub_blocks:
|
||||
# Has sub_blocks (SequentialPipelineBlocks/ConditionalPipelineBlocks)
|
||||
if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
|
||||
for sub_block_name, sub_block in block.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
nested_blocks = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
|
||||
nested_blocks = {f"{block_name}.{k}": v for k, v in nested_blocks.items()}
|
||||
result_blocks.update(nested_blocks)
|
||||
else:
|
||||
# Leaf block: single ModularPipelineBlocks or LoopSequentialPipelineBlocks
|
||||
result_blocks[block_name] = block
|
||||
if hasattr(block, "outputs"):
|
||||
for out in block.outputs:
|
||||
# Add outputs to active_inputs so subsequent blocks can use them as triggers
|
||||
if hasattr(block, "intermediate_outputs"):
|
||||
for out in block.intermediate_outputs:
|
||||
active_inputs[out.name] = True
|
||||
|
||||
return result_blocks
|
||||
|
||||
all_blocks = OrderedDict()
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
|
||||
all_blocks.update(blocks_to_update)
|
||||
return all_blocks
|
||||
nested_blocks = fn_recursive_traverse(block, block_name, active_inputs)
|
||||
all_blocks.update(nested_blocks)
|
||||
|
||||
def get_execution_blocks(self, **kwargs):
|
||||
"""
|
||||
Get the blocks that would execute given the specified inputs.
|
||||
|
||||
Args:
|
||||
**kwargs: Input names and values. Only trigger inputs affect block selection.
|
||||
Pass any inputs that would be non-None at runtime.
|
||||
|
||||
Returns:
|
||||
SequentialPipelineBlocks containing only the blocks that would execute
|
||||
|
||||
Example:
|
||||
# Get blocks for inpainting workflow blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask,
|
||||
image=image)
|
||||
|
||||
# Get blocks for text2image workflow blocks = pipeline.get_execution_blocks(prompt="a cat")
|
||||
"""
|
||||
# Filter out None values
|
||||
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
|
||||
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
|
||||
return SequentialPipelineBlocks.from_blocks_dict(all_blocks)
|
||||
|
||||
def __repr__(self):
|
||||
class_name = self.__class__.__name__
|
||||
@@ -1132,18 +1188,23 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
|
||||
)
|
||||
|
||||
if self.trigger_inputs:
|
||||
if self._workflow_map is None and self._get_trigger_inputs():
|
||||
header += "\n"
|
||||
header += " " + "=" * 100 + "\n"
|
||||
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
|
||||
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
|
||||
header += f" Trigger Inputs: {[inp for inp in self._get_trigger_inputs() if inp is not None]}\n"
|
||||
# Get first trigger input as example
|
||||
example_input = next(t for t in self.trigger_inputs if t is not None)
|
||||
example_input = next(t for t in self._get_trigger_inputs() if t is not None)
|
||||
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
|
||||
header += " " + "=" * 100 + "\n\n"
|
||||
|
||||
description = self.description
|
||||
if self._workflow_map is not None:
|
||||
workflow_str = format_workflow(self._workflow_map)
|
||||
description = f"{self.description}\n\n{workflow_str}"
|
||||
|
||||
# Format description with proper indentation
|
||||
desc_lines = self.description.split("\n")
|
||||
desc_lines = description.split("\n")
|
||||
desc = []
|
||||
# First line with "Description:" label
|
||||
desc.append(f" Description: {desc_lines[0]}")
|
||||
@@ -1191,15 +1252,28 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
|
||||
@property
|
||||
def doc(self):
|
||||
description = self.description
|
||||
if self._workflow_map is not None:
|
||||
workflow_str = format_workflow(self._workflow_map)
|
||||
description = f"{self.description}\n\n{workflow_str}"
|
||||
|
||||
return make_doc_string(
|
||||
self.inputs,
|
||||
self.outputs,
|
||||
self.description,
|
||||
description=description,
|
||||
class_name=self.__class__.__name__,
|
||||
expected_components=self.expected_components,
|
||||
expected_configs=self.expected_configs,
|
||||
)
|
||||
|
||||
@property
|
||||
def _requirements(self) -> dict[str, str]:
|
||||
requirements = {}
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if getattr(block, "_requirements", None):
|
||||
requirements[block_name] = block._requirements
|
||||
return requirements
|
||||
|
||||
|
||||
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
"""
|
||||
@@ -1327,7 +1401,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
@property
|
||||
def intermediate_outputs(self) -> list[str]:
|
||||
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
|
||||
combined_outputs = self.combine_outputs(*named_outputs)
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
for output in self.loop_intermediate_outputs:
|
||||
if output.name not in {output.name for output in combined_outputs}:
|
||||
combined_outputs.append(output)
|
||||
@@ -1338,6 +1412,15 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
|
||||
def outputs(self) -> list[str]:
|
||||
return next(reversed(self.sub_blocks.values())).intermediate_outputs
|
||||
|
||||
@property
|
||||
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
|
||||
def _requirements(self) -> dict[str, str]:
|
||||
requirements = {}
|
||||
for block_name, block in self.sub_blocks.items():
|
||||
if getattr(block, "_requirements", None):
|
||||
requirements[block_name] = block._requirements
|
||||
return requirements
|
||||
|
||||
def __init__(self):
|
||||
sub_blocks = InsertableDict()
|
||||
for block_name, block in zip(self.block_names, self.block_classes):
|
||||
@@ -1593,7 +1676,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
blocks_class_name = self.default_blocks_name
|
||||
if blocks_class_name is not None:
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name, None)
|
||||
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
|
||||
# fall back to default_blocks_name
|
||||
if blocks_class is None or not blocks_class.block_classes:
|
||||
blocks_class_name = self.default_blocks_name
|
||||
blocks_class = getattr(diffusers_module, blocks_class_name)
|
||||
|
||||
if blocks_class is not None:
|
||||
blocks = blocks_class()
|
||||
else:
|
||||
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
|
||||
@@ -1653,6 +1743,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
|
||||
)
|
||||
|
||||
self._pretrained_model_name_or_path = pretrained_model_name_or_path
|
||||
|
||||
@property
|
||||
def default_call_parameters(self) -> dict[str, Any]:
|
||||
"""
|
||||
@@ -1779,44 +1871,136 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
)
|
||||
return pipeline
|
||||
|
||||
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: str | os.PathLike,
|
||||
safe_serialization: bool = True,
|
||||
variant: str | None = None,
|
||||
max_shard_size: int | str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Save the pipeline to a directory. It does not save components, you need to save them separately.
|
||||
Save the pipeline and all its components to a directory, so that it can be re-loaded using the
|
||||
[`~ModularPipeline.from_pretrained`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Path to the directory where the pipeline will be saved.
|
||||
push_to_hub (`bool`, optional):
|
||||
Whether to push the pipeline to the huggingface hub.
|
||||
**kwargs: Additional arguments passed to `save_config()` method
|
||||
Directory to save the pipeline to. Will be created if it doesn't exist.
|
||||
safe_serialization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
||||
variant (`str`, *optional*):
|
||||
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
||||
max_shard_size (`int` or `str`, defaults to `None`):
|
||||
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
||||
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
|
||||
If expressed as an integer, the unit is bytes.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether to push the pipeline to the Hugging Face model hub after saving it.
|
||||
**kwargs: Additional keyword arguments:
|
||||
- `overwrite_modular_index` (`bool`, *optional*, defaults to `False`):
|
||||
When saving a Modular Pipeline, its components in `modular_model_index.json` may reference repos
|
||||
different from the destination repo. Setting this to `True` updates all component references in
|
||||
`modular_model_index.json` so they point to the repo specified by `repo_id`.
|
||||
- `repo_id` (`str`, *optional*):
|
||||
The repository ID to push the pipeline to. Defaults to the last component of `save_directory`.
|
||||
- `commit_message` (`str`, *optional*):
|
||||
Commit message for the push to hub operation.
|
||||
- `private` (`bool`, *optional*):
|
||||
Whether the repository should be private.
|
||||
- `create_pr` (`bool`, *optional*, defaults to `False`):
|
||||
Whether to create a pull request instead of pushing directly.
|
||||
- `token` (`str`, *optional*):
|
||||
The Hugging Face token to use for authentication.
|
||||
"""
|
||||
overwrite_modular_index = kwargs.pop("overwrite_modular_index", False)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
|
||||
if push_to_hub:
|
||||
commit_message = kwargs.pop("commit_message", None)
|
||||
private = kwargs.pop("private", None)
|
||||
create_pr = kwargs.pop("create_pr", False)
|
||||
token = kwargs.pop("token", None)
|
||||
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
||||
update_model_card = kwargs.pop("update_model_card", False)
|
||||
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
||||
|
||||
# Generate modular pipeline card content
|
||||
card_content = generate_modular_model_card_content(self.blocks)
|
||||
for component_name, component_spec in self._component_specs.items():
|
||||
if component_spec.default_creation_method != "from_pretrained":
|
||||
continue
|
||||
|
||||
# Create a new empty model card and eventually tag it
|
||||
component = getattr(self, component_name, None)
|
||||
if component is None:
|
||||
continue
|
||||
|
||||
model_cls = component.__class__
|
||||
if is_compiled_module(component):
|
||||
component = _unwrap_model(component)
|
||||
model_cls = component.__class__
|
||||
|
||||
save_method_name = None
|
||||
for library_name, library_classes in LOADABLE_CLASSES.items():
|
||||
if library_name in sys.modules:
|
||||
library = importlib.import_module(library_name)
|
||||
else:
|
||||
logger.info(
|
||||
f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}"
|
||||
)
|
||||
continue
|
||||
|
||||
for base_class, save_load_methods in library_classes.items():
|
||||
class_candidate = getattr(library, base_class, None)
|
||||
if class_candidate is not None and issubclass(model_cls, class_candidate):
|
||||
save_method_name = save_load_methods[0]
|
||||
break
|
||||
if save_method_name is not None:
|
||||
break
|
||||
|
||||
if save_method_name is None:
|
||||
logger.warning(f"self.{component_name}={component} of type {type(component)} cannot be saved.")
|
||||
continue
|
||||
|
||||
save_method = getattr(component, save_method_name)
|
||||
save_method_signature = inspect.signature(save_method)
|
||||
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
||||
save_method_accept_variant = "variant" in save_method_signature.parameters
|
||||
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
||||
|
||||
save_kwargs = {}
|
||||
if save_method_accept_safe:
|
||||
save_kwargs["safe_serialization"] = safe_serialization
|
||||
if save_method_accept_variant:
|
||||
save_kwargs["variant"] = variant
|
||||
if save_method_accept_max_shard_size and max_shard_size is not None:
|
||||
save_kwargs["max_shard_size"] = max_shard_size
|
||||
|
||||
component_save_path = os.path.join(save_directory, component_name)
|
||||
save_method(component_save_path, **save_kwargs)
|
||||
|
||||
if component_name not in self.config:
|
||||
continue
|
||||
|
||||
has_no_load_id = not hasattr(component, "_diffusers_load_id") or component._diffusers_load_id == "null"
|
||||
if overwrite_modular_index or has_no_load_id:
|
||||
library, class_name, component_spec_dict = self.config[component_name]
|
||||
component_spec_dict["pretrained_model_name_or_path"] = repo_id if push_to_hub else save_directory
|
||||
component_spec_dict["subfolder"] = component_name
|
||||
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
|
||||
|
||||
self.save_config(save_directory=save_directory)
|
||||
|
||||
if push_to_hub:
|
||||
card_content = generate_modular_model_card_content(self.blocks)
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id,
|
||||
token=token,
|
||||
is_pipeline=True,
|
||||
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
|
||||
is_modular=True,
|
||||
update_model_card=update_model_card,
|
||||
)
|
||||
model_card = populate_model_card(model_card, tags=card_content["tags"])
|
||||
|
||||
model_card.save(os.path.join(save_directory, "README.md"))
|
||||
|
||||
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
|
||||
self.save_config(save_directory=save_directory)
|
||||
|
||||
if push_to_hub:
|
||||
self._upload_folder(
|
||||
save_directory,
|
||||
repo_id,
|
||||
@@ -2067,58 +2251,30 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
- the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
|
||||
|
||||
Args:
|
||||
**kwargs: Component objects, ComponentSpec objects, or configuration values to update:
|
||||
- Component objects: Only supports components we can extract specs using
|
||||
`ComponentSpec.from_component()` method i.e. components created with ComponentSpec.load() or
|
||||
ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
|
||||
- ComponentSpec objects: Only supports default_creation_method == "from_config", will call create()
|
||||
method to create a new component (e.g., `guider=ComponentSpec(name="guider",
|
||||
type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
|
||||
- Configuration values: Simple values to update configuration settings (e.g.,
|
||||
`requires_safety_checker=False`)
|
||||
|
||||
Raises:
|
||||
ValueError: If a component object is not supported in ComponentSpec.from_component() method:
|
||||
- nn.Module components without a valid `_diffusers_load_id` attribute
|
||||
- Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
|
||||
**kwargs: Component objects or configuration values to update:
|
||||
- Component objects: Models loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()`
|
||||
are automatically tagged with loading information. ConfigMixin objects without weights (e.g.,
|
||||
schedulers, guiders) can be passed directly.
|
||||
- Configuration values: Simple values to update configuration settings
|
||||
(e.g., `requires_safety_checker=False`)
|
||||
|
||||
Examples:
|
||||
```python
|
||||
# Update multiple components at once
|
||||
# Update pre-trained model
|
||||
pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)
|
||||
|
||||
# Update configuration values
|
||||
pipeline.update_components(requires_safety_checker=False)
|
||||
|
||||
# Update both components and configs together
|
||||
pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)
|
||||
|
||||
# Update with ComponentSpec objects (from_config only)
|
||||
pipeline.update_components(
|
||||
guider=ComponentSpec(
|
||||
name="guider",
|
||||
type_hint=ClassifierFreeGuidance,
|
||||
config={"guidance_scale": 5.0},
|
||||
default_creation_method="from_config",
|
||||
)
|
||||
)
|
||||
```
|
||||
|
||||
Notes:
|
||||
- Components with trained weights must be created using ComponentSpec.load(). If the component has not been
|
||||
shared in huggingface hub and you don't have loading specs, you can upload it using `push_to_hub()`
|
||||
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
|
||||
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
|
||||
update_components()
|
||||
- Components loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()` will have
|
||||
loading specs preserved for serialization. Custom or locally loaded components without Hub references will
|
||||
have their `modular_model_index.json` entries updated automatically during `save_pretrained()`.
|
||||
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly.
|
||||
"""
|
||||
|
||||
# extract component_specs_updates & config_specs_updates from `specs`
|
||||
passed_component_specs = {
|
||||
k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)
|
||||
}
|
||||
passed_components = {
|
||||
k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)
|
||||
}
|
||||
passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
|
||||
passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
|
||||
|
||||
for name, component in passed_components.items():
|
||||
@@ -2136,13 +2292,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
new_component_spec = current_component_spec
|
||||
if hasattr(self, name) and getattr(self, name) is not None:
|
||||
logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)")
|
||||
elif current_component_spec.default_creation_method == "from_pretrained" and not (
|
||||
hasattr(component, "_diffusers_load_id") and component._diffusers_load_id is not None
|
||||
elif (
|
||||
current_component_spec.default_creation_method == "from_pretrained"
|
||||
and getattr(component, "_diffusers_load_id", None) is None
|
||||
):
|
||||
logger.warning(
|
||||
f"ModularPipeline.update_components: {name} has no valid _diffusers_load_id. "
|
||||
f"This will result in empty loading spec, use ComponentSpec.load() for proper specs"
|
||||
)
|
||||
new_component_spec = ComponentSpec(name=name, type_hint=type(component))
|
||||
else:
|
||||
new_component_spec = ComponentSpec.from_component(name, component)
|
||||
@@ -2157,33 +2310,14 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
if len(kwargs) > 0:
|
||||
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
|
||||
|
||||
created_components = {}
|
||||
for name, component_spec in passed_component_specs.items():
|
||||
if component_spec.default_creation_method == "from_pretrained":
|
||||
raise ValueError(
|
||||
"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
|
||||
)
|
||||
created_components[name] = component_spec.create()
|
||||
current_component_spec = self._component_specs[name]
|
||||
# warn if type changed
|
||||
if current_component_spec.type_hint is not None and not isinstance(
|
||||
created_components[name], current_component_spec.type_hint
|
||||
):
|
||||
logger.info(
|
||||
f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
|
||||
)
|
||||
# update _component_specs based on the user passed component_spec
|
||||
self._component_specs[name] = component_spec
|
||||
self.register_components(**passed_components, **created_components)
|
||||
self.register_components(**passed_components)
|
||||
|
||||
config_to_register = {}
|
||||
for name, new_value in passed_config_values.items():
|
||||
# e.g. requires_aesthetics_score = False
|
||||
self._config_specs[name].default = new_value
|
||||
config_to_register[name] = new_value
|
||||
self.register_to_config(**config_to_register)
|
||||
|
||||
# YiYi TODO: support map for additional from_pretrained kwargs
|
||||
def load_components(self, names: list[str] | str | None = None, **kwargs):
|
||||
"""
|
||||
Load selected components from specs.
|
||||
@@ -2234,17 +2368,49 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
|
||||
elif "default" in value:
|
||||
# check if the default is specified
|
||||
component_load_kwargs[key] = value["default"]
|
||||
# Only pass trust_remote_code to components from the same repo as the pipeline.
|
||||
# When a user passes trust_remote_code=True, they intend to trust code from the
|
||||
# pipeline's repo, not from external repos referenced in modular_model_index.json.
|
||||
trust_remote_code_stripped = False
|
||||
if (
|
||||
"trust_remote_code" in component_load_kwargs
|
||||
and self._pretrained_model_name_or_path is not None
|
||||
and spec.pretrained_model_name_or_path != self._pretrained_model_name_or_path
|
||||
):
|
||||
component_load_kwargs.pop("trust_remote_code")
|
||||
trust_remote_code_stripped = True
|
||||
|
||||
if not spec.pretrained_model_name_or_path:
|
||||
logger.info(f"Skipping component `{name}`: no pretrained model path specified.")
|
||||
continue
|
||||
|
||||
try:
|
||||
components_to_register[name] = spec.load(**component_load_kwargs)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
f"\nFailed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n"
|
||||
"If this component is not required for your workflow you can safely ignore this message.\n\n"
|
||||
"Traceback:\n"
|
||||
f"{traceback.format_exc()}"
|
||||
)
|
||||
tb = traceback.format_exc()
|
||||
if trust_remote_code_stripped and "trust_remote_code" in tb:
|
||||
warning_msg = (
|
||||
f"Failed to load component `{name}` from external repository "
|
||||
f"`{spec.pretrained_model_name_or_path}`.\n\n"
|
||||
f"`trust_remote_code=True` was not forwarded to `{name}` because it comes from "
|
||||
f"a different repository than the pipeline (`{self._pretrained_model_name_or_path}`). "
|
||||
f"For safety, `trust_remote_code` is only forwarded to components from the same "
|
||||
f"repository as the pipeline.\n\n"
|
||||
f"You need to load this component manually with `trust_remote_code=True` and pass it "
|
||||
f"to the pipeline via `pipe.update_components()`. For example, if it is a custom model:\n\n"
|
||||
f' {name} = AutoModel.from_pretrained("{spec.pretrained_model_name_or_path}", trust_remote_code=True)\n'
|
||||
f" pipe.update_components({name}={name})\n"
|
||||
)
|
||||
else:
|
||||
warning_msg = (
|
||||
f"Failed to create component {name}:\n"
|
||||
f"- Component spec: {spec}\n"
|
||||
f"- load() called with kwargs: {component_load_kwargs}\n"
|
||||
"If this component is not required for your workflow you can safely ignore this message.\n\n"
|
||||
"Traceback:\n"
|
||||
f"{tb}"
|
||||
)
|
||||
logger.warning(warning_msg)
|
||||
|
||||
# Register all components at once
|
||||
self.register_components(**components_to_register)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from types import UnionType
|
||||
@@ -21,10 +22,12 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
|
||||
|
||||
import PIL.Image
|
||||
import torch
|
||||
from packaging.specifiers import InvalidSpecifier, SpecifierSet
|
||||
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..loaders.single_file_utils import _is_single_file_path_or_url
|
||||
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
|
||||
from ..utils.import_utils import _is_package_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -49,11 +52,7 @@ This modular pipeline is composed of the following blocks:
|
||||
|
||||
{components_description} {configs_section}
|
||||
|
||||
## Input/Output Specification
|
||||
|
||||
### Inputs {inputs_description}
|
||||
|
||||
### Outputs {outputs_description}
|
||||
{io_specification_section}
|
||||
"""
|
||||
|
||||
|
||||
@@ -310,6 +309,12 @@ class ComponentSpec:
|
||||
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
|
||||
)
|
||||
|
||||
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
|
||||
# As a result, it gets stored in `init_kwargs`, which are written to the config
|
||||
# during save. This causes JSON serialization to fail when saving the component.
|
||||
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
|
||||
kwargs.pop("torch_dtype", None)
|
||||
|
||||
if self.type_hint is None:
|
||||
try:
|
||||
from diffusers import AutoModel
|
||||
@@ -327,6 +332,12 @@ class ComponentSpec:
|
||||
else getattr(self.type_hint, "from_pretrained")
|
||||
)
|
||||
|
||||
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
|
||||
# As a result, it gets stored in `init_kwargs`, which are written to the config
|
||||
# during save. This causes JSON serialization to fail when saving the component.
|
||||
if not issubclass(self.type_hint, torch.nn.Module):
|
||||
kwargs.pop("torch_dtype", None)
|
||||
|
||||
try:
|
||||
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
|
||||
except Exception as e:
|
||||
@@ -503,6 +514,10 @@ OUTPUT_PARAM_TEMPLATES = {
|
||||
"type_hint": list[PIL.Image.Image],
|
||||
"description": "Generated images.",
|
||||
},
|
||||
"videos": {
|
||||
"type_hint": list[PIL.Image.Image],
|
||||
"description": "The generated videos.",
|
||||
},
|
||||
"latents": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "Denoised latents.",
|
||||
@@ -794,6 +809,46 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
|
||||
return format_params(output_params, "Outputs", indent_level, max_line_length)
|
||||
|
||||
|
||||
def format_params_markdown(params, header="Inputs"):
|
||||
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
|
||||
|
||||
Suitable for model cards rendered on Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
params: list of InputParam or OutputParam objects to format
|
||||
header: Header text (e.g. "Inputs" or "Outputs")
|
||||
|
||||
Returns:
|
||||
A formatted markdown string, or empty string if params is empty.
|
||||
"""
|
||||
if not params:
|
||||
return ""
|
||||
|
||||
def get_type_str(type_hint):
|
||||
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
|
||||
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
|
||||
return " | ".join(type_strs)
|
||||
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
|
||||
|
||||
lines = [f"**{header}:**\n"] if header else []
|
||||
for param in params:
|
||||
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
|
||||
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
|
||||
param_str = f"- `{name}` (`{type_str}`"
|
||||
|
||||
if hasattr(param, "required") and not param.required:
|
||||
param_str += ", *optional*"
|
||||
if param.default is not None:
|
||||
param_str += f", defaults to `{param.default}`"
|
||||
param_str += ")"
|
||||
|
||||
desc = param.description if param.description else "No description provided"
|
||||
param_str += f": {desc}"
|
||||
lines.append(param_str)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
|
||||
"""Format a list of ComponentSpec objects into a readable string representation.
|
||||
|
||||
@@ -887,6 +942,30 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
|
||||
return "\n".join(formatted_configs)
|
||||
|
||||
|
||||
def format_workflow(workflow_map):
|
||||
"""Format a workflow map into a readable string representation.
|
||||
|
||||
Args:
|
||||
workflow_map: Dictionary mapping workflow names to trigger inputs
|
||||
|
||||
Returns:
|
||||
A formatted string representing all workflows
|
||||
"""
|
||||
if workflow_map is None:
|
||||
return ""
|
||||
|
||||
lines = ["Supported workflows:"]
|
||||
for workflow_name, trigger_inputs in workflow_map.items():
|
||||
required_inputs = [k for k, v in trigger_inputs.items() if v]
|
||||
if required_inputs:
|
||||
inputs_str = ", ".join(f"`{t}`" for t in required_inputs)
|
||||
lines.append(f" - `{workflow_name}`: requires {inputs_str}")
|
||||
else:
|
||||
lines.append(f" - `{workflow_name}`: default (no additional inputs required)")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def make_doc_string(
|
||||
inputs,
|
||||
outputs,
|
||||
@@ -943,6 +1022,155 @@ def make_doc_string(
|
||||
return output
|
||||
|
||||
|
||||
def _validate_requirements(reqs):
|
||||
if reqs is None:
|
||||
normalized_reqs = {}
|
||||
else:
|
||||
if not isinstance(reqs, dict):
|
||||
raise ValueError(
|
||||
"Requirements must be provided as a dictionary mapping package names to version specifiers."
|
||||
)
|
||||
normalized_reqs = _normalize_requirements(reqs)
|
||||
|
||||
if not normalized_reqs:
|
||||
return {}
|
||||
|
||||
final: dict[str, str] = {}
|
||||
for req, specified_ver in normalized_reqs.items():
|
||||
req_available, req_actual_ver = _is_package_available(req)
|
||||
if not req_available:
|
||||
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
|
||||
|
||||
if specified_ver:
|
||||
try:
|
||||
specifier = SpecifierSet(specified_ver)
|
||||
except InvalidSpecifier as err:
|
||||
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
|
||||
|
||||
if req_actual_ver == "N/A":
|
||||
logger.warning(
|
||||
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
|
||||
)
|
||||
elif not specifier.contains(req_actual_ver, prereleases=True):
|
||||
logger.warning(
|
||||
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
|
||||
)
|
||||
|
||||
final[req] = specified_ver
|
||||
|
||||
return final
|
||||
|
||||
|
||||
def _normalize_requirements(reqs):
|
||||
if not reqs:
|
||||
return {}
|
||||
|
||||
normalized: "OrderedDict[str, str]" = OrderedDict()
|
||||
|
||||
def _accumulate(mapping: dict[str, Any]):
|
||||
for pkg, spec in mapping.items():
|
||||
if isinstance(spec, dict):
|
||||
# This is recursive because blocks are composable. This way, we can merge requirements
|
||||
# from multiple blocks.
|
||||
_accumulate(spec)
|
||||
continue
|
||||
|
||||
pkg_name = str(pkg).strip()
|
||||
if not pkg_name:
|
||||
raise ValueError("Requirement package name cannot be empty.")
|
||||
|
||||
spec_str = "" if spec is None else str(spec).strip()
|
||||
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
|
||||
spec_str = f"=={spec_str}"
|
||||
|
||||
existing_spec = normalized.get(pkg_name)
|
||||
if existing_spec is not None:
|
||||
if not existing_spec and spec_str:
|
||||
normalized[pkg_name] = spec_str
|
||||
elif existing_spec and spec_str and existing_spec != spec_str:
|
||||
try:
|
||||
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
|
||||
except InvalidSpecifier:
|
||||
logger.warning(
|
||||
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
|
||||
)
|
||||
else:
|
||||
normalized[pkg_name] = str(combined_spec)
|
||||
continue
|
||||
|
||||
normalized[pkg_name] = spec_str
|
||||
|
||||
_accumulate(reqs)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
|
||||
"""
|
||||
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
|
||||
default value is None and new default value is not None. Warns if multiple non-None default values exist for the
|
||||
same input.
|
||||
|
||||
Args:
|
||||
named_input_lists: List of tuples containing (block_name, input_param_list) pairs
|
||||
|
||||
Returns:
|
||||
List[InputParam]: Combined list of unique InputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> InputParam
|
||||
value_sources = {} # name -> block_name
|
||||
|
||||
for block_name, inputs in named_input_lists:
|
||||
for input_param in inputs:
|
||||
if input_param.name is None and input_param.kwargs_type is not None:
|
||||
input_name = "*_" + input_param.kwargs_type
|
||||
else:
|
||||
input_name = input_param.name
|
||||
if input_name in combined_dict:
|
||||
current_param = combined_dict[input_name]
|
||||
if (
|
||||
current_param.default is not None
|
||||
and input_param.default is not None
|
||||
and current_param.default != input_param.default
|
||||
):
|
||||
warnings.warn(
|
||||
f"Multiple different default values found for input '{input_name}': "
|
||||
f"{current_param.default} (from block '{value_sources[input_name]}') and "
|
||||
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
|
||||
)
|
||||
if current_param.default is None and input_param.default is not None:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
else:
|
||||
combined_dict[input_name] = input_param
|
||||
value_sources[input_name] = block_name
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
|
||||
def combine_outputs(*named_output_lists: list[tuple[str, list[OutputParam]]]) -> list[OutputParam]:
|
||||
"""
|
||||
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
|
||||
occurrence of each output name.
|
||||
|
||||
Args:
|
||||
named_output_lists: List of tuples containing (block_name, output_param_list) pairs
|
||||
|
||||
Returns:
|
||||
List[OutputParam]: Combined list of unique OutputParam objects
|
||||
"""
|
||||
combined_dict = {} # name -> OutputParam
|
||||
|
||||
for block_name, outputs in named_output_lists:
|
||||
for output_param in outputs:
|
||||
if (output_param.name not in combined_dict) or (
|
||||
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
|
||||
):
|
||||
combined_dict[output_param.name] = output_param
|
||||
|
||||
return list(combined_dict.values())
|
||||
|
||||
|
||||
def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
"""
|
||||
Generate model card content for a modular pipeline.
|
||||
@@ -960,8 +1188,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
- blocks_description: Detailed architecture of blocks
|
||||
- components_description: List of required components
|
||||
- configs_section: Configuration parameters section
|
||||
- inputs_description: Input parameters specification
|
||||
- outputs_description: Output parameters specification
|
||||
- io_specification_section: Input/Output specification (per-workflow or unified)
|
||||
- trigger_inputs_section: Conditional execution information
|
||||
- tags: List of relevant tags for the model card
|
||||
"""
|
||||
@@ -980,15 +1207,6 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
if block_desc:
|
||||
blocks_desc_parts.append(f" - {block_desc}")
|
||||
|
||||
# add sub-blocks if any
|
||||
if hasattr(block, "sub_blocks") and block.sub_blocks:
|
||||
for sub_name, sub_block in block.sub_blocks.items():
|
||||
sub_class = sub_block.__class__.__name__
|
||||
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
|
||||
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
|
||||
if sub_desc:
|
||||
blocks_desc_parts.append(f" - {sub_desc}")
|
||||
|
||||
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
|
||||
|
||||
components = getattr(blocks, "expected_components", [])
|
||||
@@ -1014,63 +1232,76 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
|
||||
if configs_description:
|
||||
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
|
||||
|
||||
inputs = blocks.inputs
|
||||
outputs = blocks.outputs
|
||||
# Branch on whether workflows are defined
|
||||
has_workflows = getattr(blocks, "_workflow_map", None) is not None
|
||||
|
||||
# format inputs as markdown list
|
||||
inputs_parts = []
|
||||
required_inputs = [inp for inp in inputs if inp.required]
|
||||
optional_inputs = [inp for inp in inputs if not inp.required]
|
||||
if has_workflows:
|
||||
workflow_map = blocks._workflow_map
|
||||
parts = []
|
||||
|
||||
if required_inputs:
|
||||
inputs_parts.append("**Required:**\n")
|
||||
for inp in required_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
|
||||
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
|
||||
# use that as the shared output for all workflows
|
||||
blocks_outputs = blocks.outputs
|
||||
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
|
||||
shared_outputs = (
|
||||
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
|
||||
)
|
||||
|
||||
if optional_inputs:
|
||||
if required_inputs:
|
||||
inputs_parts.append("")
|
||||
inputs_parts.append("**Optional:**\n")
|
||||
for inp in optional_inputs:
|
||||
if hasattr(inp.type_hint, "__name__"):
|
||||
type_str = inp.type_hint.__name__
|
||||
elif inp.type_hint is not None:
|
||||
type_str = str(inp.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = inp.description or "No description provided"
|
||||
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
|
||||
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
|
||||
parts.append("## Workflow Input Specification\n")
|
||||
|
||||
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
|
||||
# Per-workflow details: show trigger inputs with full param descriptions
|
||||
for wf_name, trigger_inputs in workflow_map.items():
|
||||
trigger_input_names = set(trigger_inputs.keys())
|
||||
try:
|
||||
workflow_blocks = blocks.get_workflow(wf_name)
|
||||
except Exception:
|
||||
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
|
||||
parts.append("*Could not resolve workflow blocks.*\n")
|
||||
parts.append("</details>\n")
|
||||
continue
|
||||
|
||||
# format outputs as markdown list
|
||||
outputs_parts = []
|
||||
for out in outputs:
|
||||
if hasattr(out.type_hint, "__name__"):
|
||||
type_str = out.type_hint.__name__
|
||||
elif out.type_hint is not None:
|
||||
type_str = str(out.type_hint).replace("typing.", "")
|
||||
else:
|
||||
type_str = "Any"
|
||||
desc = out.description or "No description provided"
|
||||
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
|
||||
wf_inputs = workflow_blocks.inputs
|
||||
# Show only trigger inputs with full parameter descriptions
|
||||
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
|
||||
|
||||
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
|
||||
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
|
||||
|
||||
trigger_inputs_section = ""
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
|
||||
if trigger_inputs_list:
|
||||
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
|
||||
trigger_inputs_section = f"""
|
||||
inputs_str = format_params_markdown(trigger_params, header=None)
|
||||
parts.append(inputs_str if inputs_str else "No additional inputs required.")
|
||||
parts.append("")
|
||||
|
||||
parts.append("</details>\n")
|
||||
|
||||
# Common Inputs & Outputs section (like non-workflow pipelines)
|
||||
all_inputs = blocks.inputs
|
||||
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
|
||||
|
||||
inputs_str = format_params_markdown(all_inputs, "Inputs")
|
||||
outputs_str = format_params_markdown(all_outputs, "Outputs")
|
||||
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
|
||||
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
|
||||
|
||||
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
|
||||
|
||||
io_specification_section = "\n".join(parts)
|
||||
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
|
||||
trigger_inputs_section = ""
|
||||
else:
|
||||
# Unified I/O section (original behavior)
|
||||
inputs = blocks.inputs
|
||||
outputs = blocks.outputs
|
||||
inputs_str = format_params_markdown(inputs, "Inputs")
|
||||
outputs_str = format_params_markdown(outputs, "Outputs")
|
||||
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
|
||||
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
|
||||
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
|
||||
|
||||
trigger_inputs_section = ""
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
|
||||
if trigger_inputs_list:
|
||||
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
|
||||
trigger_inputs_section = f"""
|
||||
### Conditional Execution
|
||||
|
||||
This pipeline contains blocks that are selected at runtime based on inputs:
|
||||
@@ -1083,7 +1314,18 @@ This pipeline contains blocks that are selected at runtime based on inputs:
|
||||
if hasattr(blocks, "model_name") and blocks.model_name:
|
||||
tags.append(blocks.model_name)
|
||||
|
||||
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
if has_workflows:
|
||||
# Derive tags from workflow names
|
||||
workflow_names = set(blocks._workflow_map.keys())
|
||||
if any("inpainting" in wf for wf in workflow_names):
|
||||
tags.append("inpainting")
|
||||
if any("image2image" in wf for wf in workflow_names):
|
||||
tags.append("image-to-image")
|
||||
if any("controlnet" in wf for wf in workflow_names):
|
||||
tags.append("controlnet")
|
||||
if any("text2image" in wf for wf in workflow_names):
|
||||
tags.append("text-to-image")
|
||||
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
|
||||
triggers = blocks.trigger_inputs
|
||||
if any(t in triggers for t in ["mask", "mask_image"]):
|
||||
tags.append("inpainting")
|
||||
@@ -1111,8 +1353,7 @@ This pipeline uses a {block_count}-block architecture that can be customized and
|
||||
"blocks_description": blocks_description,
|
||||
"components_description": components_description,
|
||||
"configs_section": configs_section,
|
||||
"inputs_description": inputs_description,
|
||||
"outputs_description": outputs_description,
|
||||
"io_specification_section": io_specification_section,
|
||||
"trigger_inputs_section": trigger_inputs_section,
|
||||
"tags": tags,
|
||||
}
|
||||
|
||||
@@ -21,27 +21,15 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["modular_blocks_qwenimage"] = [
|
||||
"AUTO_BLOCKS",
|
||||
"QwenImageAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_edit"] = [
|
||||
"EDIT_AUTO_BLOCKS",
|
||||
"QwenImageEditAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_edit_plus"] = [
|
||||
"EDIT_PLUS_AUTO_BLOCKS",
|
||||
"QwenImageEditPlusAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage_layered"] = [
|
||||
"LAYERED_AUTO_BLOCKS",
|
||||
"QwenImageLayeredAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_qwenimage"] = ["QwenImageAutoBlocks"]
|
||||
_import_structure["modular_blocks_qwenimage_edit"] = ["QwenImageEditAutoBlocks"]
|
||||
_import_structure["modular_blocks_qwenimage_edit_plus"] = ["QwenImageEditPlusAutoBlocks"]
|
||||
_import_structure["modular_blocks_qwenimage_layered"] = ["QwenImageLayeredAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = [
|
||||
"QwenImageEditModularPipeline",
|
||||
"QwenImageEditPlusModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
"QwenImageLayeredModularPipeline",
|
||||
"QwenImageModularPipeline",
|
||||
]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -51,22 +39,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .modular_blocks_qwenimage import (
|
||||
AUTO_BLOCKS,
|
||||
QwenImageAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_edit import (
|
||||
EDIT_AUTO_BLOCKS,
|
||||
QwenImageEditAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_edit_plus import (
|
||||
EDIT_PLUS_AUTO_BLOCKS,
|
||||
QwenImageEditPlusAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage_layered import (
|
||||
LAYERED_AUTO_BLOCKS,
|
||||
QwenImageLayeredAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_qwenimage import QwenImageAutoBlocks
|
||||
from .modular_blocks_qwenimage_edit import QwenImageEditAutoBlocks
|
||||
from .modular_blocks_qwenimage_edit_plus import QwenImageEditPlusAutoBlocks
|
||||
from .modular_blocks_qwenimage_layered import QwenImageLayeredAutoBlocks
|
||||
from .modular_pipeline import (
|
||||
QwenImageEditModularPipeline,
|
||||
QwenImageEditPlusModularPipeline,
|
||||
|
||||
@@ -558,7 +558,7 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
|
||||
Inputs:
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`):
|
||||
The initial random noised latents for the denoising process. Can be generated in prepare latents step.
|
||||
@@ -644,7 +644,7 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
|
||||
Inputs:
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
@@ -725,7 +725,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
|
||||
Inputs:
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
latents (`Tensor`):
|
||||
The latents to use for the denoising process. Can be generated in prepare latents step.
|
||||
@@ -842,7 +842,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
|
||||
Outputs:
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shapes of the images latents, used for RoPE calculation
|
||||
"""
|
||||
|
||||
@@ -917,7 +917,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
|
||||
Outputs:
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shapes of the images latents, used for RoPE calculation
|
||||
"""
|
||||
|
||||
@@ -995,9 +995,9 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
|
||||
batch_size (`int`, *optional*, defaults to 1):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
|
||||
be generated in input step.
|
||||
image_height (`List`):
|
||||
image_height (`list`):
|
||||
The heights of the reference images. Can be generated in input step.
|
||||
image_width (`List`):
|
||||
image_width (`list`):
|
||||
The widths of the reference images. Can be generated in input step.
|
||||
height (`int`):
|
||||
The height in pixels of the generated image.
|
||||
@@ -1009,11 +1009,11 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
|
||||
Outputs:
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shapes of the image latents, used for RoPE calculation
|
||||
txt_seq_lens (`List`):
|
||||
txt_seq_lens (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation
|
||||
negative_txt_seq_lens (`List`):
|
||||
negative_txt_seq_lens (`list`):
|
||||
The sequence lengths of the negative prompt embeds, used for RoPE calculation
|
||||
"""
|
||||
|
||||
@@ -1123,11 +1123,11 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
|
||||
mask for the negative text embeddings. Can be generated from text_encoder step.
|
||||
|
||||
Outputs:
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shapes of the image latents, used for RoPE calculation
|
||||
txt_seq_lens (`List`):
|
||||
txt_seq_lens (`list`):
|
||||
The sequence lengths of the prompt embeds, used for RoPE calculation
|
||||
negative_txt_seq_lens (`List`):
|
||||
negative_txt_seq_lens (`list`):
|
||||
The sequence lengths of the negative prompt embeds, used for RoPE calculation
|
||||
additional_t_cond (`Tensor`):
|
||||
The additional t cond, used for RoPE calculation
|
||||
@@ -1238,7 +1238,7 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
|
||||
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
|
||||
|
||||
Outputs:
|
||||
controlnet_keep (`List`):
|
||||
controlnet_keep (`list`):
|
||||
The controlnet keep values
|
||||
"""
|
||||
|
||||
|
||||
@@ -191,7 +191,7 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
|
||||
step.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
@@ -268,7 +268,7 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
@@ -366,7 +366,7 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
@@ -436,12 +436,12 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
|
||||
the generated image tensor from decoders step
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
mask_overlay_kwargs (`dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
|
||||
@@ -518,11 +518,11 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
The number of denoising steps.
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
|
||||
|
||||
Outputs:
|
||||
@@ -576,11 +576,11 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
The number of denoising steps.
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
|
||||
@@ -645,13 +645,13 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
|
||||
controlnet_keep (`List`):
|
||||
controlnet_keep (`list`):
|
||||
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
|
||||
|
||||
Outputs:
|
||||
@@ -711,13 +711,13 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
|
||||
controlnet_keep (`List`):
|
||||
controlnet_keep (`list`):
|
||||
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
|
||||
@@ -787,11 +787,11 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
|
||||
|
||||
Outputs:
|
||||
@@ -846,11 +846,11 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
|
||||
mask (`Tensor`):
|
||||
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
|
||||
@@ -910,11 +910,11 @@ class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
|
||||
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
|
||||
image_latents (`Tensor`):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
img_shapes (`List`):
|
||||
img_shapes (`list`):
|
||||
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
|
||||
|
||||
Outputs:
|
||||
|
||||
@@ -285,11 +285,11 @@ class QwenImageEditResizeStep(ModularPipelineBlocks):
|
||||
image_resize_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized images
|
||||
"""
|
||||
|
||||
@@ -359,13 +359,13 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks):
|
||||
image_resize_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
The target area to resize the image to, can be 1024 or 640
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized images
|
||||
"""
|
||||
|
||||
@@ -452,13 +452,13 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
|
||||
image_resize_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
Images resized to 1024x1024 target area for VAE encoding
|
||||
resized_cond_image (`List`):
|
||||
resized_cond_image (`list`):
|
||||
Images resized to 384x384 target area for VL text encoding
|
||||
"""
|
||||
|
||||
@@ -1058,7 +1058,7 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
|
||||
Inputs:
|
||||
mask_image (`Image`):
|
||||
Mask image for inpainting.
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
@@ -1072,7 +1072,7 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
|
||||
The processed image
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image
|
||||
mask_overlay_kwargs (`Dict`):
|
||||
mask_overlay_kwargs (`dict`):
|
||||
The kwargs for the postprocess step to apply the mask overlay
|
||||
"""
|
||||
|
||||
@@ -1177,7 +1177,7 @@ class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks):
|
||||
The processed image
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image
|
||||
mask_overlay_kwargs (`Dict`):
|
||||
mask_overlay_kwargs (`dict`):
|
||||
The kwargs for the postprocess step to apply the mask overlay
|
||||
"""
|
||||
|
||||
@@ -1256,7 +1256,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
|
||||
image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
@@ -1340,7 +1340,7 @@ class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks):
|
||||
image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized image. should be generated using a resize step
|
||||
|
||||
Outputs:
|
||||
@@ -1412,7 +1412,7 @@ class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
|
||||
image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized image. should be generated using a resize step
|
||||
|
||||
Outputs:
|
||||
|
||||
@@ -496,9 +496,9 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
|
||||
image latents used to guide the image generation. Can be generated from vae_encoder step.
|
||||
|
||||
Outputs:
|
||||
image_height (`List`):
|
||||
image_height (`list`):
|
||||
The image heights calculated from the image latents dimension
|
||||
image_width (`List`):
|
||||
image_width (`list`):
|
||||
The image widths calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
|
||||
@@ -119,7 +119,7 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
Inputs:
|
||||
mask_image (`Image`):
|
||||
Mask image for inpainting.
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
@@ -135,7 +135,7 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
The processed image
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image
|
||||
mask_overlay_kwargs (`Dict`):
|
||||
mask_overlay_kwargs (`dict`):
|
||||
The kwargs for the postprocess step to apply the mask overlay
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
@@ -164,7 +164,7 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
|
||||
image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
@@ -476,9 +476,9 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -553,11 +553,11 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -632,11 +632,11 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -712,7 +712,7 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
control_guidance_start (`float`, *optional*, defaults to 0.0):
|
||||
When to start applying ControlNet.
|
||||
@@ -720,7 +720,7 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -802,7 +802,7 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
@@ -812,7 +812,7 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -894,7 +894,7 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
@@ -904,7 +904,7 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
When to stop applying ControlNet.
|
||||
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
|
||||
Scale for ControlNet conditioning.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -1032,7 +1032,7 @@ class QwenImageDecodeStep(SequentialPipelineBlocks):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
@@ -1061,12 +1061,12 @@ class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
mask_overlay_kwargs (`dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
@@ -1113,10 +1113,14 @@ AUTO_BLOCKS = InsertableDict(
|
||||
class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
|
||||
- for image-to-image generation, you need to provide `image`
|
||||
- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
|
||||
- to run the controlnet workflow, you need to provide `control_image`
|
||||
- for text-to-image generation, all you need to provide is `prompt`
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image2image`: requires `prompt`, `image`
|
||||
- `inpainting`: requires `prompt`, `mask_image`, `image`
|
||||
- `controlnet_text2image`: requires `prompt`, `control_image`
|
||||
- `controlnet_image2image`: requires `prompt`, `image`, `control_image`
|
||||
- `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image`
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
|
||||
@@ -1134,7 +1138,7 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
Maximum sequence length for prompt encoding.
|
||||
mask_image (`Image`, *optional*):
|
||||
Mask image for inpainting.
|
||||
image (`Union[Image, List]`, *optional*):
|
||||
image (`Image | list`, *optional*):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
height (`int`, *optional*):
|
||||
The height in pixels of the generated image.
|
||||
@@ -1160,9 +1164,9 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
Pre-generated noisy latents for image generation.
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -1183,12 +1187,12 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
Scale for ControlNet conditioning.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
mask_overlay_kwargs (`dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
@@ -1197,15 +1201,23 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = AUTO_BLOCKS.values()
|
||||
block_names = AUTO_BLOCKS.keys()
|
||||
|
||||
# Workflow map defines the trigger conditions for each workflow.
|
||||
# How to define:
|
||||
# - Only include required inputs and trigger inputs (inputs that determine which blocks run)
|
||||
# - currently, only supports `True` means the workflow triggers when the input is not None
|
||||
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"prompt": True, "image": True},
|
||||
"inpainting": {"prompt": True, "mask_image": True, "image": True},
|
||||
"controlnet_text2image": {"prompt": True, "control_image": True},
|
||||
"controlnet_image2image": {"prompt": True, "image": True, "control_image": True},
|
||||
"controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
|
||||
+ "- for image-to-image generation, you need to provide `image`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
|
||||
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
|
||||
@@ -67,7 +67,7 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
@@ -75,7 +75,7 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized images
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings.
|
||||
@@ -115,13 +115,13 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized images
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
@@ -156,7 +156,7 @@ class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
mask_image (`Image`):
|
||||
Mask image for inpainting.
|
||||
@@ -166,13 +166,13 @@ class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized images
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
processed_mask_image (`Tensor`):
|
||||
The processed mask image
|
||||
mask_overlay_kwargs (`Dict`):
|
||||
mask_overlay_kwargs (`dict`):
|
||||
The kwargs for the postprocess step to apply the mask overlay
|
||||
image_latents (`Tensor`):
|
||||
The latent representation of the input image.
|
||||
@@ -450,9 +450,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -526,11 +526,11 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -627,7 +627,7 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
@@ -656,12 +656,12 @@ class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
|
||||
step.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
mask_overlay_kwargs (`dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
@@ -718,6 +718,11 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide
|
||||
`padding_mask_crop`
|
||||
|
||||
|
||||
Supported workflows:
|
||||
- `image_conditioned`: requires `prompt`, `image`
|
||||
- `image_conditioned_inpainting`: requires `prompt`, `mask_image`, `image`
|
||||
|
||||
Components:
|
||||
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
|
||||
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
|
||||
@@ -725,7 +730,7 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
(`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
@@ -751,28 +756,32 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
|
||||
Pre-generated noisy latents for image generation.
|
||||
num_inference_steps (`int`):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
strength (`float`, *optional*, defaults to 0.9):
|
||||
Strength for img2img/inpainting.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
mask_overlay_kwargs (`Dict`, *optional*):
|
||||
mask_overlay_kwargs (`dict`, *optional*):
|
||||
The kwargs for the postprocess step to apply the mask overlay. generated in
|
||||
InpaintProcessImagesInputStep.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
model_name = "qwenimage-edit"
|
||||
block_classes = EDIT_AUTO_BLOCKS.values()
|
||||
block_names = EDIT_AUTO_BLOCKS.keys()
|
||||
_workflow_map = {
|
||||
"image_conditioned": {"prompt": True, "image": True},
|
||||
"image_conditioned_inpainting": {"prompt": True, "mask_image": True, "image": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
|
||||
@@ -58,7 +58,7 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
|
||||
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
@@ -66,9 +66,9 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
|
||||
The prompt or prompts not to guide the image generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
Images resized to 1024x1024 target area for VAE encoding
|
||||
resized_cond_image (`List`):
|
||||
resized_cond_image (`list`):
|
||||
Images resized to 384x384 target area for VL text encoding
|
||||
prompt_embeds (`Tensor`):
|
||||
The prompt embeddings.
|
||||
@@ -108,15 +108,15 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
generator (`Generator`, *optional*):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
Images resized to 1024x1024 target area for VAE encoding
|
||||
resized_cond_image (`List`):
|
||||
resized_cond_image (`list`):
|
||||
Images resized to 384x384 target area for VL text encoding
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
@@ -189,9 +189,9 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
|
||||
The negative prompt embeddings. (batch-expanded)
|
||||
negative_prompt_embeds_mask (`Tensor`):
|
||||
The negative prompt embeddings mask. (batch-expanded)
|
||||
image_height (`List`):
|
||||
image_height (`list`):
|
||||
The image heights calculated from the image latents dimension
|
||||
image_width (`List`):
|
||||
image_width (`list`):
|
||||
The image widths calculated from the image latents dimension
|
||||
height (`int`):
|
||||
if not provided, updated to image height
|
||||
@@ -253,9 +253,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -315,7 +315,7 @@ class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images. (tensor output of the vae decoder.)
|
||||
"""
|
||||
|
||||
@@ -357,7 +357,7 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
|
||||
transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation.
|
||||
@@ -375,9 +375,9 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
|
||||
Pre-generated noisy latents for image generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -385,7 +385,7 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
|
||||
@@ -60,7 +60,7 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
|
||||
(`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
The target area to resize the image to, can be 1024 or 640
|
||||
@@ -74,7 +74,7 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
|
||||
Maximum sequence length for prompt encoding.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized images
|
||||
prompt (`str`):
|
||||
The prompt or prompts to guide image generation. If not provided, updated using image caption
|
||||
@@ -117,7 +117,7 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
|
||||
(`AutoencoderKLQwenImage`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
The target area to resize the image to, can be 1024 or 640
|
||||
@@ -125,7 +125,7 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
|
||||
Outputs:
|
||||
resized_image (`List`):
|
||||
resized_image (`list`):
|
||||
The resized images
|
||||
processed_image (`Tensor`):
|
||||
The processed image
|
||||
@@ -250,9 +250,9 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
Torch generator for deterministic generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -317,7 +317,7 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
|
||||
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Union[Image, List]`):
|
||||
image (`Image | list`):
|
||||
Reference image(s) for denoising. Can be a single image or list of images.
|
||||
resolution (`int`, *optional*, defaults to 640):
|
||||
The target area to resize the image to, can be 1024 or 640
|
||||
@@ -339,9 +339,9 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
|
||||
Number of layers to extract from the image
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps.
|
||||
sigmas (`List`, *optional*):
|
||||
sigmas (`list`, *optional*):
|
||||
Custom sigmas for the denoising process.
|
||||
attention_kwargs (`Dict`, *optional*):
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional kwargs for attention processors.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
@@ -349,7 +349,7 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
|
||||
Output format: 'pil', 'np', 'pt'.
|
||||
|
||||
Outputs:
|
||||
images (`List`):
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
|
||||
@@ -21,21 +21,7 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"AUTO_BLOCKS",
|
||||
"CONTROLNET_BLOCKS",
|
||||
"IMAGE2IMAGE_BLOCKS",
|
||||
"INPAINT_BLOCKS",
|
||||
"IP_ADAPTER_BLOCKS",
|
||||
"TEXT2IMAGE_BLOCKS",
|
||||
"StableDiffusionXLAutoBlocks",
|
||||
"StableDiffusionXLAutoControlnetStep",
|
||||
"StableDiffusionXLAutoDecodeStep",
|
||||
"StableDiffusionXLAutoIPAdapterStep",
|
||||
"StableDiffusionXLAutoVaeEncoderStep",
|
||||
]
|
||||
_import_structure["modular_blocks_stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -45,23 +31,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .encoders import (
|
||||
StableDiffusionXLTextEncoderStep,
|
||||
)
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
AUTO_BLOCKS,
|
||||
CONTROLNET_BLOCKS,
|
||||
IMAGE2IMAGE_BLOCKS,
|
||||
INPAINT_BLOCKS,
|
||||
IP_ADAPTER_BLOCKS,
|
||||
TEXT2IMAGE_BLOCKS,
|
||||
StableDiffusionXLAutoBlocks,
|
||||
StableDiffusionXLAutoControlnetStep,
|
||||
StableDiffusionXLAutoDecodeStep,
|
||||
StableDiffusionXLAutoIPAdapterStep,
|
||||
StableDiffusionXLAutoVaeEncoderStep,
|
||||
)
|
||||
from .modular_blocks_stable_diffusion_xl import StableDiffusionXLAutoBlocks
|
||||
from .modular_pipeline import StableDiffusionXLModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from ..modular_pipeline_utils import OutputParam
|
||||
from .before_denoise import (
|
||||
StableDiffusionXLControlNetInputStep,
|
||||
StableDiffusionXLControlNetUnionInputStep,
|
||||
@@ -277,7 +277,161 @@ class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# ip-adapter, controlnet, text2img, img2img, inpainting
|
||||
# auto_docstring
|
||||
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion
|
||||
XL.
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image2image`: requires `image`, `prompt`
|
||||
- `inpainting`: requires `mask_image`, `image`, `prompt`
|
||||
- `controlnet_text2image`: requires `control_image`, `prompt`
|
||||
- `controlnet_image2image`: requires `control_image`, `image`, `prompt`
|
||||
- `controlnet_inpainting`: requires `control_image`, `mask_image`, `image`, `prompt`
|
||||
- `controlnet_union_text2image`: requires `control_image`, `control_mode`, `prompt`
|
||||
- `controlnet_union_image2image`: requires `control_image`, `control_mode`, `image`, `prompt`
|
||||
- `controlnet_union_inpainting`: requires `control_image`, `control_mode`, `mask_image`, `image`, `prompt`
|
||||
- `ip_adapter_text2image`: requires `ip_adapter_image`, `prompt`
|
||||
- `ip_adapter_image2image`: requires `ip_adapter_image`, `image`, `prompt`
|
||||
- `ip_adapter_inpainting`: requires `ip_adapter_image`, `mask_image`, `image`, `prompt`
|
||||
- `ip_adapter_controlnet_text2image`: requires `ip_adapter_image`, `control_image`, `prompt`
|
||||
- `ip_adapter_controlnet_image2image`: requires `ip_adapter_image`, `control_image`, `image`, `prompt`
|
||||
- `ip_adapter_controlnet_inpainting`: requires `ip_adapter_image`, `control_image`, `mask_image`, `image`,
|
||||
`prompt`
|
||||
- `ip_adapter_controlnet_union_text2image`: requires `ip_adapter_image`, `control_image`, `control_mode`,
|
||||
`prompt`
|
||||
- `ip_adapter_controlnet_union_image2image`: requires `ip_adapter_image`, `control_image`, `control_mode`,
|
||||
`image`, `prompt`
|
||||
- `ip_adapter_controlnet_union_inpainting`: requires `ip_adapter_image`, `control_image`, `control_mode`,
|
||||
`mask_image`, `image`, `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`CLIPTextModel`) text_encoder_2 (`CLIPTextModelWithProjection`) tokenizer (`CLIPTokenizer`)
|
||||
tokenizer_2 (`CLIPTokenizer`) guider (`ClassifierFreeGuidance`) image_encoder
|
||||
(`CLIPVisionModelWithProjection`) feature_extractor (`CLIPImageProcessor`) unet (`UNet2DConditionModel`) vae
|
||||
(`AutoencoderKL`) image_processor (`VaeImageProcessor`) mask_processor (`VaeImageProcessor`) scheduler
|
||||
(`EulerDiscreteScheduler`) controlnet (`ControlNetUnionModel`) control_image_processor (`VaeImageProcessor`)
|
||||
|
||||
Configs:
|
||||
force_zeros_for_empty_prompt (default: True) requires_aesthetics_score (default: False)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
prompt_2 (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_prompt_2 (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
cross_attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
clip_skip (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
ip_adapter_image (`Image | ndarray | Tensor | list | list | list`, *optional*):
|
||||
The image(s) to be used as ip adapter
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
mask_image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
padding_mask_crop (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
dtype (`dtype`, *optional*):
|
||||
The dtype of the model inputs
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
preprocess_kwargs (`dict | NoneType`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under
|
||||
`self.image_processor` in [diffusers.image_processor.VaeImageProcessor]
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
ip_adapter_embeds (`list`, *optional*):
|
||||
Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.
|
||||
negative_ip_adapter_embeds (`list`, *optional*):
|
||||
Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
denoising_end (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.3):
|
||||
TODO: Add description.
|
||||
denoising_start (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`None`):
|
||||
TODO: Add description.
|
||||
image_latents (`Tensor`, *optional*):
|
||||
The latents representing the reference image for image-to-image/inpainting generation. Can be generated
|
||||
in vae_encode step.
|
||||
mask (`Tensor`, *optional*):
|
||||
The mask for the inpainting generation. Can be generated in vae_encode step.
|
||||
masked_image_latents (`Tensor`, *optional*):
|
||||
The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be
|
||||
generated in vae_encode step.
|
||||
original_size (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
target_size (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_original_size (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_target_size (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
crops_coords_top_left (`None`, *optional*, defaults to (0, 0)):
|
||||
TODO: Add description.
|
||||
negative_crops_coords_top_left (`None`, *optional*, defaults to (0, 0)):
|
||||
TODO: Add description.
|
||||
aesthetic_score (`None`, *optional*, defaults to 6.0):
|
||||
TODO: Add description.
|
||||
negative_aesthetic_score (`None`, *optional*, defaults to 2.0):
|
||||
TODO: Add description.
|
||||
control_image (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
control_mode (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
control_guidance_start (`None`, *optional*, defaults to 0.0):
|
||||
TODO: Add description.
|
||||
control_guidance_end (`None`, *optional*, defaults to 1.0):
|
||||
TODO: Add description.
|
||||
controlnet_conditioning_scale (`None`, *optional*, defaults to 1.0):
|
||||
TODO: Add description.
|
||||
guess_mode (`None`, *optional*, defaults to False):
|
||||
TODO: Add description.
|
||||
crops_coords (`tuple | NoneType`, *optional*):
|
||||
The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can
|
||||
be generated in vae_encode step.
|
||||
controlnet_cond (`Tensor`, *optional*):
|
||||
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
|
||||
conditioning_scale (`float`, *optional*):
|
||||
The controlnet conditioning scale value to use for the denoising process. Can be generated in
|
||||
prepare_controlnet_inputs step.
|
||||
controlnet_keep (`list`, *optional*):
|
||||
The controlnet keep values to use for the denoising process. Can be generated in
|
||||
prepare_controlnet_inputs step.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
All conditional model inputs that need to be prepared with guider. It should contain
|
||||
prompt_embeds/negative_prompt_embeds, add_time_ids/negative_add_time_ids,
|
||||
pooled_prompt_embeds/negative_pooled_prompt_embeds, and ip_adapter_embeds/negative_ip_adapter_embeds
|
||||
(optional).please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when
|
||||
they are created and added to the pipeline state
|
||||
eta (`None`, *optional*, defaults to 0.0):
|
||||
TODO: Add description.
|
||||
output_type (`None`, *optional*, defaults to pil):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
block_classes = [
|
||||
StableDiffusionXLTextEncoderStep,
|
||||
StableDiffusionXLAutoIPAdapterStep,
|
||||
@@ -293,103 +447,66 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
|
||||
"decode",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
|
||||
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`\n"
|
||||
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
|
||||
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
|
||||
+ "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
|
||||
+ "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
|
||||
+ "- for text-to-image generation, all you need to provide is `prompt`"
|
||||
)
|
||||
|
||||
|
||||
# controlnet (input + denoise step)
|
||||
class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
StableDiffusionXLAutoControlNetInputStep,
|
||||
StableDiffusionXLAutoControlNetDenoiseStep,
|
||||
]
|
||||
block_names = ["controlnet_input", "controlnet_denoise"]
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"image": True, "prompt": True},
|
||||
"inpainting": {"mask_image": True, "image": True, "prompt": True},
|
||||
"controlnet_text2image": {"control_image": True, "prompt": True},
|
||||
"controlnet_image2image": {"control_image": True, "image": True, "prompt": True},
|
||||
"controlnet_inpainting": {"control_image": True, "mask_image": True, "image": True, "prompt": True},
|
||||
"controlnet_union_text2image": {"control_image": True, "control_mode": True, "prompt": True},
|
||||
"controlnet_union_image2image": {"control_image": True, "control_mode": True, "image": True, "prompt": True},
|
||||
"controlnet_union_inpainting": {
|
||||
"control_image": True,
|
||||
"control_mode": True,
|
||||
"mask_image": True,
|
||||
"image": True,
|
||||
"prompt": True,
|
||||
},
|
||||
"ip_adapter_text2image": {"ip_adapter_image": True, "prompt": True},
|
||||
"ip_adapter_image2image": {"ip_adapter_image": True, "image": True, "prompt": True},
|
||||
"ip_adapter_inpainting": {"ip_adapter_image": True, "mask_image": True, "image": True, "prompt": True},
|
||||
"ip_adapter_controlnet_text2image": {"ip_adapter_image": True, "control_image": True, "prompt": True},
|
||||
"ip_adapter_controlnet_image2image": {
|
||||
"ip_adapter_image": True,
|
||||
"control_image": True,
|
||||
"image": True,
|
||||
"prompt": True,
|
||||
},
|
||||
"ip_adapter_controlnet_inpainting": {
|
||||
"ip_adapter_image": True,
|
||||
"control_image": True,
|
||||
"mask_image": True,
|
||||
"image": True,
|
||||
"prompt": True,
|
||||
},
|
||||
"ip_adapter_controlnet_union_text2image": {
|
||||
"ip_adapter_image": True,
|
||||
"control_image": True,
|
||||
"control_mode": True,
|
||||
"prompt": True,
|
||||
},
|
||||
"ip_adapter_controlnet_union_image2image": {
|
||||
"ip_adapter_image": True,
|
||||
"control_image": True,
|
||||
"control_mode": True,
|
||||
"image": True,
|
||||
"prompt": True,
|
||||
},
|
||||
"ip_adapter_controlnet_union_inpainting": {
|
||||
"ip_adapter_image": True,
|
||||
"control_image": True,
|
||||
"control_mode": True,
|
||||
"mask_image": True,
|
||||
"image": True,
|
||||
"prompt": True,
|
||||
},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Controlnet auto step that prepare the controlnet input and denoise the latents. "
|
||||
+ "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks."
|
||||
+ " (it should be replace at 'denoise' step)"
|
||||
)
|
||||
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL."
|
||||
|
||||
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLDenoiseStep),
|
||||
("decode", StableDiffusionXLDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("vae_encoder", StableDiffusionXLVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLDenoiseStep),
|
||||
("decode", StableDiffusionXLDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
INPAINT_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("vae_encoder", StableDiffusionXLInpaintVaeEncoderStep),
|
||||
("input", StableDiffusionXLInputStep),
|
||||
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
|
||||
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
|
||||
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
|
||||
("denoise", StableDiffusionXLInpaintDenoiseStep),
|
||||
("decode", StableDiffusionXLInpaintDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
CONTROLNET_BLOCKS = InsertableDict(
|
||||
[
|
||||
("denoise", StableDiffusionXLAutoControlnetStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
IP_ADAPTER_BLOCKS = InsertableDict(
|
||||
[
|
||||
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||
]
|
||||
)
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", StableDiffusionXLTextEncoderStep),
|
||||
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
|
||||
("vae_encoder", StableDiffusionXLAutoVaeEncoderStep),
|
||||
("denoise", StableDiffusionXLCoreDenoiseStep),
|
||||
("decode", StableDiffusionXLAutoDecodeStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"text2img": TEXT2IMAGE_BLOCKS,
|
||||
"img2img": IMAGE2IMAGE_BLOCKS,
|
||||
"inpaint": INPAINT_BLOCKS,
|
||||
"controlnet": CONTROLNET_BLOCKS,
|
||||
"ip_adapter": IP_ADAPTER_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
}
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("images")]
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import OutputParam
|
||||
from .before_denoise import (
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
@@ -37,7 +38,45 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
|
||||
# auto_docstring
|
||||
class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
denoise block that takes encoded conditions and runs the denoising process.
|
||||
|
||||
Components:
|
||||
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
num_frames (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
@@ -49,14 +88,11 @@ class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
return "denoise block that takes encoded conditions and runs the denoising process."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("latents")]
|
||||
|
||||
|
||||
# ====================
|
||||
@@ -64,7 +100,51 @@ class WanCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class WanBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Modular pipeline blocks for Wan2.1.
|
||||
|
||||
Components:
|
||||
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer
|
||||
(`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) vae (`AutoencoderKLWan`) video_processor
|
||||
(`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`None`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
num_frames (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
output_type (`str`, *optional*, defaults to np):
|
||||
The output type of the decoded videos
|
||||
|
||||
Outputs:
|
||||
videos (`list`):
|
||||
The generated videos.
|
||||
"""
|
||||
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
@@ -75,9 +155,8 @@ class WanBlocks(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline blocks for Wan2.1.\n"
|
||||
+ "- `WanTextEncoderStep` is used to encode the text\n"
|
||||
+ "- `WanCoreDenoiseStep` is used to denoise the latents\n"
|
||||
+ "- `WanVaeDecoderStep` is used to decode the latents to images"
|
||||
)
|
||||
return "Modular pipeline blocks for Wan2.1."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("videos")]
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import OutputParam
|
||||
from .before_denoise import (
|
||||
WanPrepareLatentsStep,
|
||||
WanSetTimestepsStep,
|
||||
@@ -38,7 +39,50 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
denoise block that takes encoded conditions and runs the denoising process.
|
||||
|
||||
Components:
|
||||
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`)
|
||||
|
||||
Configs:
|
||||
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
|
||||
noise stages.
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
num_frames (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
@@ -50,14 +94,11 @@ class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
return "denoise block that takes encoded conditions and runs the denoising process."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("latents")]
|
||||
|
||||
|
||||
# ====================
|
||||
@@ -65,7 +106,55 @@ class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Wan22Blocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Modular pipeline for text-to-video using Wan2.2.
|
||||
|
||||
Components:
|
||||
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer
|
||||
(`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`)
|
||||
transformer_2 (`WanTransformer3DModel`) vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Configs:
|
||||
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
|
||||
noise stages.
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`None`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
num_frames (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
output_type (`str`, *optional*, defaults to np):
|
||||
The output type of the decoded videos
|
||||
|
||||
Outputs:
|
||||
videos (`list`):
|
||||
The generated videos.
|
||||
"""
|
||||
|
||||
model_name = "wan"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
@@ -80,9 +169,8 @@ class Wan22Blocks(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline for text-to-video using Wan2.2.\n"
|
||||
+ " - `WanTextEncoderStep` encodes the text\n"
|
||||
+ " - `Wan22CoreDenoiseStep` denoes the latents\n"
|
||||
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
|
||||
)
|
||||
return "Modular pipeline for text-to-video using Wan2.2."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("videos")]
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import OutputParam
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareLatentsStep,
|
||||
@@ -40,7 +41,36 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent
|
||||
representation
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
num_frames (`int`, *optional*, defaults to 81):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
resized_image (`Image`):
|
||||
TODO: Add description.
|
||||
first_frame_latents (`Tensor`):
|
||||
video latent representation with the first frame image condition
|
||||
image_condition_latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
|
||||
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
|
||||
@@ -56,7 +86,52 @@ class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
|
||||
# auto_docstring
|
||||
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
denoise block that takes encoded text and image latent conditions and runs the denoising process.
|
||||
|
||||
Components:
|
||||
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`)
|
||||
|
||||
Configs:
|
||||
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
|
||||
noise stages.
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_frames (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_condition_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
@@ -75,15 +150,11 @@ class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
|
||||
)
|
||||
return "denoise block that takes encoded text and image latent conditions and runs the denoising process."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("latents")]
|
||||
|
||||
|
||||
# ====================
|
||||
@@ -91,7 +162,57 @@ class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
# ====================
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Modular pipeline for image-to-video using Wan2.2.
|
||||
|
||||
Components:
|
||||
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae
|
||||
(`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler
|
||||
(`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`)
|
||||
|
||||
Configs:
|
||||
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
|
||||
noise stages.
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`None`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
image (`Image`):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
num_frames (`int`, *optional*, defaults to 81):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
output_type (`str`, *optional*, defaults to np):
|
||||
The output type of the decoded videos
|
||||
|
||||
Outputs:
|
||||
videos (`list`):
|
||||
The generated videos.
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
@@ -108,10 +229,8 @@ class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Modular pipeline for image-to-video using Wan2.2.\n"
|
||||
+ " - `WanTextEncoderStep` encodes the text\n"
|
||||
+ " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
|
||||
+ " - `Wan22Image2VideoCoreDenoiseStep` denoes the latents\n"
|
||||
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
|
||||
)
|
||||
return "Modular pipeline for image-to-video using Wan2.2."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("videos")]
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import OutputParam
|
||||
from .before_denoise import (
|
||||
WanAdditionalInputsStep,
|
||||
WanPrepareLatentsStep,
|
||||
@@ -45,7 +46,29 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
# auto_docstring
|
||||
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings
|
||||
|
||||
Components:
|
||||
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
resized_image (`Image`):
|
||||
TODO: Add description.
|
||||
image_embeds (`Tensor`):
|
||||
The image embeddings
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageEncoderStep]
|
||||
block_names = ["image_resize", "image_encoder"]
|
||||
@@ -56,7 +79,34 @@ class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
|
||||
# auto_docstring
|
||||
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image
|
||||
embeddings
|
||||
|
||||
Components:
|
||||
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
last_image (`Image`):
|
||||
The last frameimage
|
||||
|
||||
Outputs:
|
||||
resized_image (`Image`):
|
||||
TODO: Add description.
|
||||
resized_last_image (`Image`):
|
||||
TODO: Add description.
|
||||
image_embeds (`Tensor`):
|
||||
The image embeddings
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
|
||||
block_names = ["image_resize", "last_image_resize", "image_encoder"]
|
||||
@@ -67,7 +117,36 @@ class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# wan2.1 Auto Image Encoder
|
||||
# auto_docstring
|
||||
class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Image Encoder step that encode the image to generate the image embeddingsThis is an auto pipeline block that works
|
||||
for image2video tasks. - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided. -
|
||||
`WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is
|
||||
not provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
last_image (`Image`, *optional*):
|
||||
The last frameimage
|
||||
|
||||
Outputs:
|
||||
resized_image (`Image`):
|
||||
TODO: Add description.
|
||||
resized_last_image (`Image`):
|
||||
TODO: Add description.
|
||||
image_embeds (`Tensor`):
|
||||
The image embeddings
|
||||
"""
|
||||
|
||||
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
|
||||
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
|
||||
block_trigger_inputs = ["last_image", "image"]
|
||||
@@ -90,7 +169,36 @@ class WanAutoImageEncoderStep(AutoPipelineBlocks):
|
||||
|
||||
|
||||
# wan2.1 I2V (first frame only)
|
||||
# auto_docstring
|
||||
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent
|
||||
representation
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
num_frames (`int`, *optional*, defaults to 81):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
resized_image (`Image`):
|
||||
TODO: Add description.
|
||||
first_frame_latents (`Tensor`):
|
||||
video latent representation with the first frame image condition
|
||||
image_condition_latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
|
||||
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
|
||||
@@ -101,7 +209,40 @@ class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# wan2.1 FLF2V (first and last frame)
|
||||
# auto_docstring
|
||||
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the
|
||||
latent conditions
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
last_image (`Image`):
|
||||
The last frameimage
|
||||
num_frames (`int`, *optional*, defaults to 81):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
resized_image (`Image`):
|
||||
TODO: Add description.
|
||||
resized_last_image (`Image`):
|
||||
TODO: Add description.
|
||||
first_last_frame_latents (`Tensor`):
|
||||
video latent representation with the first and last frame images condition
|
||||
image_condition_latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanImageResizeStep,
|
||||
@@ -117,7 +258,44 @@ class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# wan2.1 Auto Vae Encoder
|
||||
# auto_docstring
|
||||
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Vae Image Encoder step that encode the image to generate the image latentsThis is an auto pipeline block that works
|
||||
for image2video tasks. - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided. -
|
||||
`WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is not
|
||||
provided, step will be skipped.
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
last_image (`Image`, *optional*):
|
||||
The last frameimage
|
||||
num_frames (`int`, *optional*, defaults to 81):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
resized_image (`Image`):
|
||||
TODO: Add description.
|
||||
resized_last_image (`Image`):
|
||||
TODO: Add description.
|
||||
first_last_frame_latents (`Tensor`):
|
||||
video latent representation with the first and last frame images condition
|
||||
image_condition_latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
first_frame_latents (`Tensor`):
|
||||
video latent representation with the first frame image condition
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
|
||||
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
|
||||
@@ -141,7 +319,53 @@ class WanAutoVaeEncoderStep(AutoPipelineBlocks):
|
||||
|
||||
# wan2.1 I2V core denoise (support both I2V and FLF2V)
|
||||
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
|
||||
# auto_docstring
|
||||
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
denoise block that takes encoded text and image latent conditions and runs the denoising process.
|
||||
|
||||
Components:
|
||||
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`Tensor`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_frames (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_condition_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_embeds (`Tensor`):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
batch_size (`int`):
|
||||
Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt
|
||||
dtype (`dtype`):
|
||||
Data type of model tensor inputs (determined by `transformer.dtype`)
|
||||
latents (`Tensor`):
|
||||
The initial latents to use for the denoising process
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextInputStep,
|
||||
@@ -160,15 +384,7 @@ class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
return "denoise block that takes encoded text and image latent conditions and runs the denoising process."
|
||||
|
||||
|
||||
# ====================
|
||||
@@ -177,7 +393,64 @@ class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
|
||||
|
||||
# wan2.1 Image2Video Auto Blocks
|
||||
# auto_docstring
|
||||
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for image-to-video using Wan.
|
||||
|
||||
Supported workflows:
|
||||
- `image2video`: requires `image`, `prompt`
|
||||
- `flf2v`: requires `last_image`, `image`, `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`)
|
||||
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) vae (`AutoencoderKLWan`)
|
||||
video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler
|
||||
(`UniPCMultistepScheduler`)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`None`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
image (`Image`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`int`, *optional*, defaults to 480):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*, defaults to 832):
|
||||
TODO: Add description.
|
||||
last_image (`Image`, *optional*):
|
||||
The last frameimage
|
||||
num_frames (`int`, *optional*, defaults to 81):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_videos_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
image_condition_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 50):
|
||||
TODO: Add description.
|
||||
timesteps (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
attention_kwargs (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_embeds (`Tensor`):
|
||||
TODO: Add description.
|
||||
output_type (`str`, *optional*, defaults to np):
|
||||
The output type of the decoded videos
|
||||
|
||||
Outputs:
|
||||
videos (`list`):
|
||||
The generated videos.
|
||||
"""
|
||||
|
||||
model_name = "wan-i2v"
|
||||
block_classes = [
|
||||
WanTextEncoderStep,
|
||||
@@ -194,10 +467,15 @@ class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
|
||||
"decode",
|
||||
]
|
||||
|
||||
_workflow_map = {
|
||||
"image2video": {"image": True, "prompt": True},
|
||||
"flf2v": {"last_image": True, "image": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"Auto Modular pipeline for image-to-video using Wan.\n"
|
||||
+ "- for I2V workflow, all you need to provide is `image`"
|
||||
+ "- for FLF2V workflow, all you need to provide is `last_image` and `image`"
|
||||
)
|
||||
return "Auto Modular pipeline for image-to-video using Wan."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("videos")]
|
||||
|
||||
@@ -21,12 +21,7 @@ except OptionalDependencyNotAvailable:
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["decoders"] = ["ZImageVaeDecoderStep"]
|
||||
_import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"]
|
||||
_import_structure["modular_blocks"] = [
|
||||
"ALL_BLOCKS",
|
||||
"ZImageAutoBlocks",
|
||||
]
|
||||
_import_structure["modular_blocks_z_image"] = ["ZImageAutoBlocks"]
|
||||
_import_structure["modular_pipeline"] = ["ZImageModularPipeline"]
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
@@ -36,12 +31,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .decoders import ZImageVaeDecoderStep
|
||||
from .encoders import ZImageTextEncoderStep
|
||||
from .modular_blocks import (
|
||||
ALL_BLOCKS,
|
||||
ZImageAutoBlocks,
|
||||
)
|
||||
from .modular_blocks_z_image import ZImageAutoBlocks
|
||||
from .modular_pipeline import ZImageModularPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
@@ -1,191 +0,0 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import InsertableDict
|
||||
from .before_denoise import (
|
||||
ZImageAdditionalInputsStep,
|
||||
ZImagePrepareLatentsStep,
|
||||
ZImagePrepareLatentswithImageStep,
|
||||
ZImageSetTimestepsStep,
|
||||
ZImageSetTimestepsWithStrengthStep,
|
||||
ZImageTextInputStep,
|
||||
)
|
||||
from .decoders import ZImageVaeDecoderStep
|
||||
from .denoise import (
|
||||
ZImageDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
ZImageTextEncoderStep,
|
||||
ZImageVaeImageEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# z-image
|
||||
# text2image
|
||||
class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
ZImageTextInputStep,
|
||||
ZImagePrepareLatentsStep,
|
||||
ZImageSetTimestepsStep,
|
||||
ZImageDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "prepare_latents", "set_timesteps", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `ZImageDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
# z-image: image2image
|
||||
## denoise
|
||||
class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
ZImageTextInputStep,
|
||||
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
ZImagePrepareLatentsStep,
|
||||
ZImageSetTimestepsStep,
|
||||
ZImageSetTimestepsWithStrengthStep,
|
||||
ZImagePrepareLatentswithImageStep,
|
||||
ZImageDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"set_timesteps_with_strength",
|
||||
"prepare_latents_with_image",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return (
|
||||
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
|
||||
+ "This is a sequential pipeline blocks:\n"
|
||||
+ " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
|
||||
+ " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
|
||||
+ " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
|
||||
+ " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
|
||||
+ " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n"
|
||||
+ " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with image\n"
|
||||
+ " - `ZImageDenoiseStep` is used to denoise the latents\n"
|
||||
)
|
||||
|
||||
|
||||
## auto blocks
|
||||
class ZImageAutoDenoiseStep(AutoPipelineBlocks):
|
||||
block_classes = [
|
||||
ZImageImage2ImageCoreDenoiseStep,
|
||||
ZImageCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["image2image", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2image and image2image tasks."
|
||||
" - `ZImageCoreDenoiseStep` (text2image) for text2image tasks."
|
||||
" - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks."
|
||||
+ " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n"
|
||||
+ " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
block_classes = [ZImageVaeImageEncoderStep]
|
||||
block_names = ["vae_encoder"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+"This is an auto pipeline block that works for image2image tasks."
|
||||
+" - `ZImageVaeImageEncoderStep` is used when `image` is provided."
|
||||
+" - if `image` is not provided, step will be skipped."
|
||||
|
||||
|
||||
class ZImageAutoBlocks(SequentialPipelineBlocks):
|
||||
block_classes = [
|
||||
ZImageTextEncoderStep,
|
||||
ZImageAutoVaeImageEncoderStep,
|
||||
ZImageAutoDenoiseStep,
|
||||
ZImageVaeDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Auto Modular pipeline for text-to-image and image-to-image using ZImage.\n"
|
||||
+" - for text-to-image generation, all you need to provide is `prompt`\n"
|
||||
+" - for image-to-image generation, you need to provide `image`\n"
|
||||
+" - if `image` is not provided, step will be skipped."
|
||||
|
||||
|
||||
# presets
|
||||
TEXT2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", ZImageTextEncoderStep),
|
||||
("input", ZImageTextInputStep),
|
||||
("prepare_latents", ZImagePrepareLatentsStep),
|
||||
("set_timesteps", ZImageSetTimestepsStep),
|
||||
("denoise", ZImageDenoiseStep),
|
||||
("decode", ZImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
IMAGE2IMAGE_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", ZImageTextEncoderStep),
|
||||
("vae_encoder", ZImageVaeImageEncoderStep),
|
||||
("input", ZImageTextInputStep),
|
||||
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
|
||||
("prepare_latents", ZImagePrepareLatentsStep),
|
||||
("set_timesteps", ZImageSetTimestepsStep),
|
||||
("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep),
|
||||
("prepare_latents_with_image", ZImagePrepareLatentswithImageStep),
|
||||
("denoise", ZImageDenoiseStep),
|
||||
("decode", ZImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
AUTO_BLOCKS = InsertableDict(
|
||||
[
|
||||
("text_encoder", ZImageTextEncoderStep),
|
||||
("vae_encoder", ZImageAutoVaeImageEncoderStep),
|
||||
("denoise", ZImageAutoDenoiseStep),
|
||||
("decode", ZImageVaeDecoderStep),
|
||||
]
|
||||
)
|
||||
|
||||
ALL_BLOCKS = {
|
||||
"text2image": TEXT2IMAGE_BLOCKS,
|
||||
"image2image": IMAGE2IMAGE_BLOCKS,
|
||||
"auto": AUTO_BLOCKS,
|
||||
}
|
||||
@@ -0,0 +1,334 @@
|
||||
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ...utils import logging
|
||||
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
from ..modular_pipeline_utils import OutputParam
|
||||
from .before_denoise import (
|
||||
ZImageAdditionalInputsStep,
|
||||
ZImagePrepareLatentsStep,
|
||||
ZImagePrepareLatentswithImageStep,
|
||||
ZImageSetTimestepsStep,
|
||||
ZImageSetTimestepsWithStrengthStep,
|
||||
ZImageTextInputStep,
|
||||
)
|
||||
from .decoders import ZImageVaeDecoderStep
|
||||
from .denoise import (
|
||||
ZImageDenoiseStep,
|
||||
)
|
||||
from .encoders import (
|
||||
ZImageTextEncoderStep,
|
||||
ZImageVaeImageEncoderStep,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
# ====================
|
||||
# 1. DENOISE
|
||||
# ====================
|
||||
|
||||
|
||||
# text2image: inputs(text) -> set_timesteps -> prepare_latents -> denoise
|
||||
# auto_docstring
|
||||
class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
denoise block that takes encoded conditions and runs the denoising process.
|
||||
|
||||
Components:
|
||||
transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`list`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`list`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`int`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 9):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
block_classes = [
|
||||
ZImageTextInputStep,
|
||||
ZImagePrepareLatentsStep,
|
||||
ZImageSetTimestepsStep,
|
||||
ZImageDenoiseStep,
|
||||
]
|
||||
block_names = ["input", "prepare_latents", "set_timesteps", "denoise"]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "denoise block that takes encoded conditions and runs the denoising process."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("latents")]
|
||||
|
||||
|
||||
# image2image: inputs(text + image_latents) -> prepare_latents -> set_timesteps -> set_timesteps_with_strength -> prepare_latents_with_image -> denoise
|
||||
# auto_docstring
|
||||
class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks):
|
||||
"""
|
||||
denoise block that takes encoded text and image latent conditions and runs the denoising process.
|
||||
|
||||
Components:
|
||||
transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`list`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`list`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`, *optional*, defaults to 9):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.6):
|
||||
TODO: Add description.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
block_classes = [
|
||||
ZImageTextInputStep,
|
||||
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
|
||||
ZImagePrepareLatentsStep,
|
||||
ZImageSetTimestepsStep,
|
||||
ZImageSetTimestepsWithStrengthStep,
|
||||
ZImagePrepareLatentswithImageStep,
|
||||
ZImageDenoiseStep,
|
||||
]
|
||||
block_names = [
|
||||
"input",
|
||||
"additional_inputs",
|
||||
"prepare_latents",
|
||||
"set_timesteps",
|
||||
"set_timesteps_with_strength",
|
||||
"prepare_latents_with_image",
|
||||
"denoise",
|
||||
]
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return "denoise block that takes encoded text and image latent conditions and runs the denoising process."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("latents")]
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class ZImageAutoDenoiseStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Denoise step that iteratively denoise the latents. This is a auto pipeline block that works for text2image and
|
||||
image2image tasks. - `ZImageCoreDenoiseStep` (text2image) for text2image tasks. -
|
||||
`ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks. - if `image_latents` is provided,
|
||||
`ZImageImage2ImageCoreDenoiseStep` will be used.
|
||||
- if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.
|
||||
|
||||
Components:
|
||||
transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
|
||||
(`ClassifierFreeGuidance`)
|
||||
|
||||
Inputs:
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
prompt_embeds (`list`):
|
||||
Pre-generated text embeddings. Can be generated from text_encoder step.
|
||||
negative_prompt_embeds (`list`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be generated from text_encoder step.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.6):
|
||||
TODO: Add description.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
|
||||
Outputs:
|
||||
latents (`Tensor`):
|
||||
Denoised latents.
|
||||
"""
|
||||
|
||||
block_classes = [
|
||||
ZImageImage2ImageCoreDenoiseStep,
|
||||
ZImageCoreDenoiseStep,
|
||||
]
|
||||
block_names = ["image2image", "text2image"]
|
||||
block_trigger_inputs = ["image_latents", None]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Denoise step that iteratively denoise the latents. "
|
||||
"This is a auto pipeline block that works for text2image and image2image tasks."
|
||||
" - `ZImageCoreDenoiseStep` (text2image) for text2image tasks."
|
||||
" - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks."
|
||||
+ " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n"
|
||||
+ " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n"
|
||||
)
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
|
||||
"""
|
||||
Vae Image Encoder step that encode the image to generate the image latents
|
||||
|
||||
Components:
|
||||
vae (`AutoencoderKL`) image_processor (`VaeImageProcessor`)
|
||||
|
||||
Inputs:
|
||||
image (`Image`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
|
||||
Outputs:
|
||||
image_latents (`Tensor`):
|
||||
video latent representation with the first frame image condition
|
||||
"""
|
||||
|
||||
block_classes = [ZImageVaeImageEncoderStep]
|
||||
block_names = ["vae_encoder"]
|
||||
block_trigger_inputs = ["image"]
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Vae Image Encoder step that encode the image to generate the image latents"
|
||||
+"This is an auto pipeline block that works for image2image tasks."
|
||||
+" - `ZImageVaeImageEncoderStep` is used when `image` is provided."
|
||||
+" - if `image` is not provided, step will be skipped."
|
||||
|
||||
|
||||
# auto_docstring
|
||||
class ZImageAutoBlocks(SequentialPipelineBlocks):
|
||||
"""
|
||||
Auto Modular pipeline for text-to-image and image-to-image using ZImage.
|
||||
|
||||
Supported workflows:
|
||||
- `text2image`: requires `prompt`
|
||||
- `image2image`: requires `image`, `prompt`
|
||||
|
||||
Components:
|
||||
text_encoder (`Qwen3Model`) tokenizer (`Qwen2Tokenizer`) guider (`ClassifierFreeGuidance`) vae
|
||||
(`AutoencoderKL`) image_processor (`VaeImageProcessor`) transformer (`ZImageTransformer2DModel`) scheduler
|
||||
(`FlowMatchEulerDiscreteScheduler`)
|
||||
|
||||
Inputs:
|
||||
prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
negative_prompt (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
max_sequence_length (`None`, *optional*, defaults to 512):
|
||||
TODO: Add description.
|
||||
image (`Image`, *optional*):
|
||||
TODO: Add description.
|
||||
height (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
width (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
generator (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
num_images_per_prompt (`None`, *optional*, defaults to 1):
|
||||
TODO: Add description.
|
||||
image_latents (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
latents (`Tensor | NoneType`):
|
||||
TODO: Add description.
|
||||
num_inference_steps (`None`):
|
||||
TODO: Add description.
|
||||
sigmas (`None`, *optional*):
|
||||
TODO: Add description.
|
||||
strength (`None`, *optional*, defaults to 0.6):
|
||||
TODO: Add description.
|
||||
**denoiser_input_fields (`None`, *optional*):
|
||||
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
|
||||
output_type (`str`, *optional*, defaults to pil):
|
||||
The type of the output images, can be 'pil', 'np', 'pt'
|
||||
|
||||
Outputs:
|
||||
images (`list`):
|
||||
Generated images.
|
||||
"""
|
||||
|
||||
block_classes = [
|
||||
ZImageTextEncoderStep,
|
||||
ZImageAutoVaeImageEncoderStep,
|
||||
ZImageAutoDenoiseStep,
|
||||
ZImageVaeDecoderStep,
|
||||
]
|
||||
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
|
||||
_workflow_map = {
|
||||
"text2image": {"prompt": True},
|
||||
"image2image": {"image": True, "prompt": True},
|
||||
}
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return "Auto Modular pipeline for text-to-image and image-to-image using ZImage."
|
||||
|
||||
@property
|
||||
def outputs(self):
|
||||
return [OutputParam.template("images")]
|
||||
@@ -6,7 +6,6 @@ from ..utils import (
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_flax_available,
|
||||
is_k_diffusion_available,
|
||||
is_librosa_available,
|
||||
is_note_seq_available,
|
||||
is_onnx_available,
|
||||
@@ -238,6 +237,7 @@ else:
|
||||
"EasyAnimateInpaintPipeline",
|
||||
"EasyAnimateControlPipeline",
|
||||
]
|
||||
_import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"]
|
||||
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
|
||||
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
|
||||
_import_structure["hunyuan_video"] = [
|
||||
@@ -466,21 +466,6 @@ else:
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils import (
|
||||
dummy_torch_and_transformers_and_k_diffusion_objects,
|
||||
)
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
|
||||
else:
|
||||
_import_structure["stable_diffusion_k_diffusion"] = [
|
||||
"StableDiffusionKDiffusionPipeline",
|
||||
"StableDiffusionXLKDiffusionPipeline",
|
||||
]
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
@@ -683,6 +668,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hidream_image import HiDreamImagePipeline
|
||||
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
|
||||
from .hunyuan_video import (
|
||||
@@ -901,17 +887,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
StableDiffusionOnnxPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
|
||||
else:
|
||||
from .stable_diffusion_k_diffusion import (
|
||||
StableDiffusionKDiffusionPipeline,
|
||||
StableDiffusionXLKDiffusionPipeline,
|
||||
)
|
||||
|
||||
try:
|
||||
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
@@ -502,6 +502,10 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
text_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
# Extract the pooler output if it's a BaseModelOutputWithPooling (Transformers v5+)
|
||||
# otherwise use it directly (Transformers v4)
|
||||
if hasattr(prompt_embeds, "pooler_output"):
|
||||
prompt_embeds = prompt_embeds.pooler_output
|
||||
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
|
||||
prompt_embeds = prompt_embeds[:, None, :]
|
||||
# make sure that we attend to this single hidden-state
|
||||
@@ -610,6 +614,10 @@ class AudioLDM2Pipeline(DiffusionPipeline):
|
||||
uncond_input_ids,
|
||||
attention_mask=negative_attention_mask,
|
||||
)
|
||||
# Extract the pooler output if it's a BaseModelOutputWithPooling (Transformers v5+)
|
||||
# otherwise use it directly (Transformers v4)
|
||||
if hasattr(negative_prompt_embeds, "pooler_output"):
|
||||
negative_prompt_embeds = negative_prompt_embeds.pooler_output
|
||||
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
|
||||
negative_prompt_embeds = negative_prompt_embeds[:, None, :]
|
||||
# make sure that we attend to this single hidden-state
|
||||
|
||||
@@ -54,6 +54,7 @@ from .flux import (
|
||||
)
|
||||
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
|
||||
from .glm_image import GlmImagePipeline
|
||||
from .helios import HeliosPipeline, HeliosPyramidPipeline
|
||||
from .hunyuandit import HunyuanDiTPipeline
|
||||
from .kandinsky import (
|
||||
KandinskyCombinedPipeline,
|
||||
@@ -174,6 +175,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
|
||||
("cogview3", CogView3PlusPipeline),
|
||||
("cogview4", CogView4Pipeline),
|
||||
("glm_image", GlmImagePipeline),
|
||||
("helios", HeliosPipeline),
|
||||
("helios-pyramid", HeliosPyramidPipeline),
|
||||
("cogview4-control", CogView4ControlPipeline),
|
||||
("qwenimage", QwenImagePipeline),
|
||||
("qwenimage-controlnet", QwenImageControlNetPipeline),
|
||||
|
||||
@@ -287,6 +287,9 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = (
|
||||
input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids
|
||||
)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_ids_batch.append(input_ids)
|
||||
|
||||
|
||||
@@ -17,9 +17,6 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torchvision
|
||||
import torchvision.transforms
|
||||
import torchvision.transforms.functional
|
||||
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
@@ -54,11 +51,13 @@ else:
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
|
||||
def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
|
||||
n_pad_frames = num_frames - video.shape[2]
|
||||
if n_pad_frames > 0:
|
||||
last_frame = video[:, :, -1:, :, :]
|
||||
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
|
||||
elif num_frames < video.shape[2]:
|
||||
video = video[:, :, :num_frames, :, :]
|
||||
return video
|
||||
|
||||
|
||||
@@ -134,8 +133,8 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
|
||||
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
|
||||
|
||||
>>> # Transfer inference with controls.
|
||||
>>> video = pipe(
|
||||
... video=input_video[:num_frames],
|
||||
... controls=controls,
|
||||
... controls_conditioning_scale=1.0,
|
||||
... prompt=prompt,
|
||||
@@ -149,7 +148,7 @@ EXAMPLE_DOC_STRING = """
|
||||
|
||||
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for Cosmos Transfer2.5 base model.
|
||||
Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
|
||||
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
|
||||
@@ -166,12 +165,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
vae ([`AutoencoderKLWan`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
|
||||
controlnet ([`CosmosControlNetModel`]):
|
||||
ControlNet used to condition generation on control inputs.
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
|
||||
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
|
||||
_optional_components = ["safety_checker", "controlnet"]
|
||||
_optional_components = ["safety_checker"]
|
||||
_exclude_from_cpu_offload = ["safety_checker"]
|
||||
|
||||
def __init__(
|
||||
@@ -181,8 +182,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
transformer: CosmosTransformer3DModel,
|
||||
vae: AutoencoderKLWan,
|
||||
scheduler: UniPCMultistepScheduler,
|
||||
controlnet: Optional[CosmosControlNetModel],
|
||||
safety_checker: CosmosSafetyChecker = None,
|
||||
controlnet: CosmosControlNetModel,
|
||||
safety_checker: Optional[CosmosSafetyChecker] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -262,6 +263,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
truncation=True,
|
||||
padding="max_length",
|
||||
)
|
||||
input_ids = (
|
||||
input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids
|
||||
)
|
||||
input_ids = torch.LongTensor(input_ids)
|
||||
input_ids_batch.append(input_ids)
|
||||
|
||||
@@ -381,10 +385,11 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
num_frames_in: int = 93,
|
||||
num_frames_out: int = 93,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
dtype: torch.dtype | None = None,
|
||||
device: torch.device | None = None,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
device: Optional[torch.device] = None,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
num_cond_latent_frames: int = 0,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(generator, list) and len(generator) != batch_size:
|
||||
raise ValueError(
|
||||
@@ -399,10 +404,14 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
W = width // self.vae_scale_factor_spatial
|
||||
shape = (B, C, T, H, W)
|
||||
|
||||
if num_frames_in == 0:
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
if latents is not None:
|
||||
if latents.shape[1:] != shape[1:]:
|
||||
raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
else:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
|
||||
if num_frames_in == 0:
|
||||
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
|
||||
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
|
||||
|
||||
@@ -432,16 +441,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
cond_latents = (cond_latents - latents_mean) / latents_std
|
||||
|
||||
if latents is None:
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device=device, dtype=dtype)
|
||||
|
||||
padding_shape = (B, 1, T, H, W)
|
||||
ones_padding = latents.new_ones(padding_shape)
|
||||
zeros_padding = latents.new_zeros(padding_shape)
|
||||
|
||||
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
|
||||
cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
|
||||
cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
|
||||
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
|
||||
|
||||
return (
|
||||
@@ -451,34 +456,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
cond_indicator,
|
||||
)
|
||||
|
||||
def _encode_controls(
|
||||
self,
|
||||
controls: Optional[torch.Tensor],
|
||||
height: int,
|
||||
width: int,
|
||||
num_frames: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: torch.Generator | list[torch.Generator] | None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if controls is None:
|
||||
return None
|
||||
|
||||
control_video = self.video_processor.preprocess_video(controls, height, width)
|
||||
control_video = _maybe_pad_video(control_video, num_frames)
|
||||
|
||||
control_video = control_video.to(device=device, dtype=self.vae.dtype)
|
||||
control_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
|
||||
]
|
||||
control_latents = torch.cat(control_latents, dim=0).to(dtype)
|
||||
|
||||
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
|
||||
latents_std = self.latents_std.to(device=device, dtype=dtype)
|
||||
control_latents = (control_latents - latents_mean) / latents_std
|
||||
return control_latents
|
||||
|
||||
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt,
|
||||
@@ -486,9 +464,25 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
width,
|
||||
prompt_embeds=None,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
num_ar_conditional_frames=None,
|
||||
num_ar_latent_conditional_frames=None,
|
||||
num_frames_per_chunk=None,
|
||||
num_frames=None,
|
||||
conditional_frame_timestep=0.1,
|
||||
):
|
||||
if height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
|
||||
if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}."
|
||||
)
|
||||
|
||||
if num_frames is not None and num_frames <= 0:
|
||||
raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
|
||||
|
||||
if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
|
||||
raise ValueError(
|
||||
"`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
|
||||
f"{conditional_frame_timestep}."
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
@@ -509,6 +503,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
|
||||
raise ValueError(
|
||||
"Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
|
||||
)
|
||||
if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
|
||||
raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
|
||||
if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
|
||||
raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
|
||||
if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
|
||||
raise ValueError("`num_ar_conditional_frames` must be >= 0.")
|
||||
|
||||
if num_ar_latent_conditional_frames is not None:
|
||||
num_ar_conditional_frames = max(
|
||||
0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
|
||||
)
|
||||
|
||||
min_chunk_len = self.vae_scale_factor_temporal + 1
|
||||
if num_frames_per_chunk < min_chunk_len:
|
||||
logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len")
|
||||
num_frames_per_chunk = min_chunk_len
|
||||
|
||||
max_frames_by_rope = None
|
||||
if getattr(self.transformer.config, "max_size", None) is not None:
|
||||
max_frames_by_rope = max(
|
||||
size // patch
|
||||
for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
|
||||
)
|
||||
if num_frames_per_chunk > max_frames_by_rope:
|
||||
raise ValueError(
|
||||
f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
|
||||
"Please reduce `num_frames_per_chunk`."
|
||||
)
|
||||
|
||||
if num_ar_conditional_frames >= num_frames_per_chunk:
|
||||
raise ValueError(
|
||||
f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
|
||||
)
|
||||
|
||||
return num_frames_per_chunk
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
@@ -533,23 +567,22 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
image: PipelineImageInput | None = None,
|
||||
video: List[PipelineImageInput] | None = None,
|
||||
controls: PipelineImageInput | List[PipelineImageInput],
|
||||
controls_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
prompt: Union[str, List[str]] | None = None,
|
||||
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
|
||||
height: int = 704,
|
||||
width: int | None = None,
|
||||
num_frames: int = 93,
|
||||
width: Optional[int] = None,
|
||||
num_frames: Optional[int] = None,
|
||||
num_frames_per_chunk: int = 93,
|
||||
num_inference_steps: int = 36,
|
||||
guidance_scale: float = 3.0,
|
||||
num_videos_per_prompt: Optional[int] = 1,
|
||||
generator: torch.Generator | list[torch.Generator] | None = None,
|
||||
latents: torch.Tensor | None = None,
|
||||
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
|
||||
controls_conditioning_scale: float | list[float] = 1.0,
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
negative_prompt_embeds: torch.Tensor | None = None,
|
||||
output_type: str = "pil",
|
||||
num_videos_per_prompt: int = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Optional[
|
||||
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
|
||||
@@ -557,24 +590,26 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
max_sequence_length: int = 512,
|
||||
conditional_frame_timestep: float = 0.1,
|
||||
num_ar_conditional_frames: Optional[int] = 1,
|
||||
num_ar_latent_conditional_frames: Optional[int] = None,
|
||||
):
|
||||
r"""
|
||||
The call function to the pipeline for generation. Supports three modes:
|
||||
`controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps
|
||||
are pre-computed.
|
||||
|
||||
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
|
||||
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
|
||||
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
|
||||
Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None
|
||||
(default) then the number of output frames will match the input `controls`.
|
||||
|
||||
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (the
|
||||
above in "*2Image mode").
|
||||
|
||||
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
|
||||
Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per
|
||||
denoising loop. In addition, when auto-regressive inference is performed, the previous
|
||||
`num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising
|
||||
inference loops.
|
||||
|
||||
Args:
|
||||
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
|
||||
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
|
||||
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
|
||||
controls (`PipelineImageInput`, `List[PipelineImageInput]`):
|
||||
Control image or video input used by the ControlNet.
|
||||
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
|
||||
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
|
||||
height (`int`, defaults to `704`):
|
||||
@@ -582,9 +617,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
width (`int`, *optional*):
|
||||
The width in pixels of the generated image. If not provided, this will be determined based on the
|
||||
aspect ratio of the input and the provided height.
|
||||
num_frames (`int`, defaults to `93`):
|
||||
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
|
||||
num_inference_steps (`int`, defaults to `35`):
|
||||
num_frames (`int`, *optional*):
|
||||
Number of output frames. Defaults to `None` to output the same number of frames as the input
|
||||
`controls`.
|
||||
num_inference_steps (`int`, defaults to `36`):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, defaults to `3.0`):
|
||||
@@ -598,13 +634,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
||||
generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor is generated by sampling using the supplied random `generator`.
|
||||
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
|
||||
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
|
||||
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
|
||||
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
|
||||
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
|
||||
tweak the same generation with different prompts. If not provided, a latents tensor is generated by
|
||||
sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
@@ -627,7 +659,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
max_sequence_length (`int`, defaults to `512`):
|
||||
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
|
||||
the prompt is shorter than this length, it will be padded.
|
||||
num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
|
||||
Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the
|
||||
second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
|
||||
|
||||
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
|
||||
controls is > num_frames_per_chunk
|
||||
num_ar_latent_conditional_frames (`int`, *optional*):
|
||||
Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e.
|
||||
for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
|
||||
|
||||
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
|
||||
controls is > num_frames_per_chunk
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
@@ -647,21 +690,40 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
|
||||
if width is None:
|
||||
frame = image or video[0] if image or video else None
|
||||
if frame is None and controls is not None:
|
||||
frame = controls[0] if isinstance(controls, list) else controls
|
||||
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
|
||||
frame = controls[0]
|
||||
frame = controls[0] if isinstance(controls, list) else controls
|
||||
if isinstance(frame, list):
|
||||
frame = frame[0]
|
||||
if isinstance(frame, (torch.Tensor, np.ndarray)):
|
||||
if frame.ndim == 5:
|
||||
frame = frame[0, 0]
|
||||
elif frame.ndim == 4:
|
||||
frame = frame[0]
|
||||
|
||||
if frame is None:
|
||||
width = int((height + 16) * (1280 / 720))
|
||||
elif isinstance(frame, PIL.Image.Image):
|
||||
if isinstance(frame, PIL.Image.Image):
|
||||
width = int((height + 16) * (frame.width / frame.height))
|
||||
else:
|
||||
if frame.ndim != 3:
|
||||
raise ValueError("`controls` must contain 3D frames in CHW format.")
|
||||
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
|
||||
|
||||
# Check inputs. Raise error if not correct
|
||||
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
|
||||
num_frames_per_chunk = self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
num_ar_conditional_frames,
|
||||
num_ar_latent_conditional_frames,
|
||||
num_frames_per_chunk,
|
||||
num_frames,
|
||||
conditional_frame_timestep,
|
||||
)
|
||||
|
||||
if num_ar_latent_conditional_frames is not None:
|
||||
num_cond_latent_frames = num_ar_latent_conditional_frames
|
||||
num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
|
||||
else:
|
||||
num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
self._current_timestep = None
|
||||
@@ -706,102 +768,137 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
vae_dtype = self.vae.dtype
|
||||
transformer_dtype = self.transformer.dtype
|
||||
|
||||
img_context = torch.zeros(
|
||||
batch_size,
|
||||
self.transformer.config.img_context_num_tokens,
|
||||
self.transformer.config.img_context_dim_in,
|
||||
device=prompt_embeds.device,
|
||||
dtype=transformer_dtype,
|
||||
)
|
||||
encoder_hidden_states = (prompt_embeds, img_context)
|
||||
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
|
||||
|
||||
num_frames_in = None
|
||||
if image is not None:
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
|
||||
|
||||
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
|
||||
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
|
||||
video = video.unsqueeze(0)
|
||||
num_frames_in = 1
|
||||
elif video is None:
|
||||
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
|
||||
num_frames_in = 0
|
||||
else:
|
||||
num_frames_in = len(video)
|
||||
|
||||
if batch_size != 1:
|
||||
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
|
||||
|
||||
assert video is not None
|
||||
video = self.video_processor.preprocess_video(video, height, width)
|
||||
|
||||
# pad with last frame (for video2world)
|
||||
num_frames_out = num_frames
|
||||
video = _maybe_pad_video(video, num_frames_out)
|
||||
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
|
||||
|
||||
video = video.to(device=device, dtype=vae_dtype)
|
||||
|
||||
num_channels_latents = self.transformer.config.in_channels - 1
|
||||
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||
video=video,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=num_channels_latents,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames_in=num_frames_in,
|
||||
num_frames_out=num_frames,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
latents=latents,
|
||||
)
|
||||
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
|
||||
controls_latents = None
|
||||
if controls is not None:
|
||||
controls_latents = self._encode_controls(
|
||||
controls,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames=num_frames,
|
||||
if getattr(self.transformer.config, "img_context_dim_in", None):
|
||||
img_context = torch.zeros(
|
||||
batch_size,
|
||||
self.transformer.config.img_context_num_tokens,
|
||||
self.transformer.config.img_context_dim_in,
|
||||
device=prompt_embeds.device,
|
||||
dtype=transformer_dtype,
|
||||
device=device,
|
||||
generator=generator,
|
||||
)
|
||||
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
if num_videos_per_prompt > 1:
|
||||
img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
|
||||
|
||||
# Denoising loop
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
encoder_hidden_states = (prompt_embeds, img_context)
|
||||
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
|
||||
else:
|
||||
encoder_hidden_states = prompt_embeds
|
||||
neg_encoder_hidden_states = negative_prompt_embeds
|
||||
|
||||
gt_velocity = (latents - cond_latent) * cond_mask
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t.cpu().item()
|
||||
|
||||
# NOTE: assumes sigma(t) \in [0, 1]
|
||||
sigma_t = (
|
||||
torch.tensor(self.scheduler.sigmas[i].item())
|
||||
.unsqueeze(0)
|
||||
.to(device=device, dtype=transformer_dtype)
|
||||
control_video = self.video_processor.preprocess_video(controls, height, width)
|
||||
if control_video.shape[0] != batch_size:
|
||||
if control_video.shape[0] == 1:
|
||||
control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
|
||||
)
|
||||
|
||||
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||
in_latents = in_latents.to(transformer_dtype)
|
||||
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||
control_blocks = None
|
||||
if controls_latents is not None and self.controlnet is not None:
|
||||
num_frames_out = control_video.shape[2]
|
||||
if num_frames is not None:
|
||||
num_frames_out = min(num_frames_out, num_frames)
|
||||
|
||||
control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
|
||||
|
||||
# chunk information
|
||||
num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
|
||||
chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
|
||||
chunk_idxs = [
|
||||
(start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
|
||||
for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
|
||||
]
|
||||
|
||||
video_chunks = []
|
||||
latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
|
||||
latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
|
||||
|
||||
def decode_latents(latents):
|
||||
latents = latents * latents_std + latents_mean
|
||||
video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
|
||||
return video
|
||||
|
||||
latents_arg = latents
|
||||
initial_num_cond_latent_frames = 0
|
||||
latent_chunks = []
|
||||
num_chunks = len(chunk_idxs)
|
||||
total_steps = num_inference_steps * num_chunks
|
||||
with self.progress_bar(total=total_steps) as progress_bar:
|
||||
for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
|
||||
if chunk_idx == 0:
|
||||
prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
|
||||
prev_output = self.video_processor.preprocess_video(prev_output, height, width)
|
||||
else:
|
||||
prev_output = video_chunks[-1].clone()
|
||||
if num_ar_conditional_frames > 0:
|
||||
prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
|
||||
prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space
|
||||
else:
|
||||
prev_output.fill_(-1)
|
||||
|
||||
chunk_video = prev_output.to(device=device, dtype=vae_dtype)
|
||||
chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
|
||||
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
|
||||
video=chunk_video,
|
||||
batch_size=batch_size * num_videos_per_prompt,
|
||||
num_channels_latents=self.transformer.config.in_channels - 1,
|
||||
height=height,
|
||||
width=width,
|
||||
num_frames_in=chunk_video.shape[2],
|
||||
num_frames_out=num_frames_per_chunk,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
dtype=torch.float32,
|
||||
device=device,
|
||||
generator=generator,
|
||||
num_cond_latent_frames=initial_num_cond_latent_frames
|
||||
if chunk_idx == 0
|
||||
else num_cond_latent_frames,
|
||||
latents=latents_arg,
|
||||
)
|
||||
cond_mask = cond_mask.to(transformer_dtype)
|
||||
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
|
||||
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
|
||||
|
||||
chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
|
||||
device=device, dtype=self.vae.dtype
|
||||
)
|
||||
chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
|
||||
if isinstance(generator, list):
|
||||
controls_latents = [
|
||||
retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
|
||||
for i in range(chunk_control_video.shape[0])
|
||||
]
|
||||
else:
|
||||
controls_latents = [
|
||||
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
|
||||
for vid in chunk_control_video
|
||||
]
|
||||
controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
|
||||
|
||||
controls_latents = (controls_latents - latents_mean) / latents_std
|
||||
|
||||
# Denoising loop
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
self._num_timesteps = len(timesteps)
|
||||
|
||||
gt_velocity = (latents - cond_latent) * cond_mask
|
||||
for i, t in enumerate(timesteps):
|
||||
if self.interrupt:
|
||||
continue
|
||||
|
||||
self._current_timestep = t.cpu().item()
|
||||
|
||||
# NOTE: assumes sigma(t) \in [0, 1]
|
||||
sigma_t = (
|
||||
torch.tensor(self.scheduler.sigmas[i].item())
|
||||
.unsqueeze(0)
|
||||
.to(device=device, dtype=transformer_dtype)
|
||||
)
|
||||
|
||||
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
|
||||
in_latents = in_latents.to(transformer_dtype)
|
||||
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
|
||||
control_output = self.controlnet(
|
||||
controls_latents=controls_latents,
|
||||
latents=in_latents,
|
||||
@@ -814,20 +911,18 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
)
|
||||
control_blocks = control_output[0]
|
||||
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
|
||||
|
||||
if self.do_classifier_free_guidance:
|
||||
control_blocks = None
|
||||
if controls_latents is not None and self.controlnet is not None:
|
||||
if self.do_classifier_free_guidance:
|
||||
control_output = self.controlnet(
|
||||
controls_latents=controls_latents,
|
||||
latents=in_latents,
|
||||
@@ -840,46 +935,50 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
)
|
||||
control_blocks = control_output[0]
|
||||
|
||||
noise_pred_neg = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||
noise_pred_neg = self.transformer(
|
||||
hidden_states=in_latents,
|
||||
timestep=in_timestep,
|
||||
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
|
||||
block_controlnet_hidden_states=control_blocks,
|
||||
condition_mask=cond_mask,
|
||||
padding_mask=padding_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
|
||||
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
|
||||
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
|
||||
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
# call the callback, if provided
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
latents = callback_outputs.pop("latents", latents)
|
||||
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
|
||||
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
video_chunks.append(decode_latents(latents).detach().cpu())
|
||||
latent_chunks.append(latents.detach().cpu())
|
||||
|
||||
self._current_timestep = None
|
||||
|
||||
if not output_type == "latent":
|
||||
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
|
||||
latents_std = self.latents_std.to(latents.device, latents.dtype)
|
||||
latents = latents * latents_std + latents_mean
|
||||
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
|
||||
video = self._match_num_frames(video, num_frames)
|
||||
video_chunks = [
|
||||
chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
|
||||
for chunk_idx, chunk in enumerate(video_chunks)
|
||||
]
|
||||
video = torch.cat(video_chunks, dim=2)
|
||||
video = video[:, :, :num_frames_out, ...]
|
||||
|
||||
assert self.safety_checker is not None
|
||||
self.safety_checker.to(device)
|
||||
@@ -896,7 +995,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
|
||||
video = self.video_processor.postprocess_video(video, output_type=output_type)
|
||||
else:
|
||||
video = latents
|
||||
latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
|
||||
latent_chunks = [
|
||||
chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
|
||||
for chunk_idx, chunk in enumerate(latent_chunks)
|
||||
]
|
||||
video = torch.cat(latent_chunks, dim=2)
|
||||
video = video[:, :, :latent_T, ...]
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
@@ -905,19 +1010,3 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
|
||||
return (video,)
|
||||
|
||||
return CosmosPipelineOutput(frames=video)
|
||||
|
||||
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
|
||||
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
|
||||
return video
|
||||
|
||||
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
|
||||
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
|
||||
|
||||
current_frames = video.shape[2]
|
||||
if current_frames < target_num_frames:
|
||||
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
|
||||
video = torch.cat([video, pad], dim=2)
|
||||
elif current_frames > target_num_frames:
|
||||
video = video[:, :, :target_num_frames]
|
||||
|
||||
return video
|
||||
|
||||
48
src/diffusers/pipelines/helios/__init__.py
Normal file
48
src/diffusers/pipelines/helios/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_import_structure = {}
|
||||
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_helios"] = ["HeliosPipeline"]
|
||||
_import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"]
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import *
|
||||
else:
|
||||
from .pipeline_helios import HeliosPipeline
|
||||
from .pipeline_helios_pyramid import HeliosPyramidPipeline
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user