Compare commits


1 commit

Author: yiyixuxu | SHA1: 1fbaaf3b64 | Message: up | Date: 2026-02-14 20:14:54 +01:00
247 changed files with 4737 additions and 15858 deletions

View File

@@ -62,6 +62,20 @@ jobs:
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}
# TODO: enable this once the connection problem has been resolved.
- name: Update benchmarking results to DB
env:
PGDATABASE: metrics
PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
PGUSER: transformers_benchmarks
PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
run: |
git config --global --add safe.directory /__w/diffusers/diffusers
commit_id=$GITHUB_SHA
commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
- name: Report success status
if: ${{ success() }}

View File

@@ -92,6 +92,7 @@ jobs:
runner: aws-general-8-plus
image: diffusers/diffusers-pytorch-cpu
report: torch_example_cpu
name: ${{ matrix.config.name }}
runs-on:
@@ -114,7 +115,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
- name: Environment
@@ -216,6 +218,8 @@ jobs:
run_lora_tests:
needs: [check_code_quality, check_repository_consistency]
strategy:
fail-fast: false
name: LoRA tests with PEFT main
@@ -243,8 +247,9 @@ jobs:
uv pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
uv pip install -U tokenizers
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -270,6 +275,6 @@ jobs:
if: ${{ always() }}
uses: actions/upload-artifact@v6
with:
name: pr_lora_test_reports
name: pr_main_test_reports
path: reports

View File

@@ -131,7 +131,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -198,10 +199,16 @@ jobs:
- name: Install dependencies
run: |
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
uv pip install pip==25.2 setuptools==80.10.2
uv pip install --no-build-isolation k-diffusion==0.0.12
uv pip install --upgrade pip setuptools
# Install the rest as normal
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -262,7 +269,8 @@ jobs:
nvidia-smi
- name: Install dependencies
run: |
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
uv pip install -e ".[quality,training]"
- name: Environment

View File

@@ -76,7 +76,8 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py
@@ -125,10 +126,16 @@ jobs:
- name: Install dependencies
run: |
# Install pkgs which depend on setuptools<81 for pkg_resources first with no build isolation
uv pip install pip==25.2 setuptools==80.10.2
uv pip install --no-build-isolation k-diffusion==0.0.12
uv pip install --upgrade pip setuptools
# Install the rest as normal
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
@@ -180,7 +187,8 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
#uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
uv pip uninstall transformers huggingface_hub && uv pip install transformers==4.57.1
- name: Environment
run: |
python utils/print_env.py

View File

@@ -41,7 +41,7 @@ jobs:
shell: arch -arch arm64 bash {0}
run: |
${CONDA_RUN} python -m pip install --upgrade pip uv
${CONDA_RUN} python -m uv pip install -e ".[quality]"
${CONDA_RUN} python -m uv pip install -e ".[quality,test]"
${CONDA_RUN} python -m uv pip install torch torchvision torchaudio
${CONDA_RUN} python -m uv pip install accelerate@git+https://github.com/huggingface/accelerate.git
${CONDA_RUN} python -m uv pip install transformers --upgrade

View File

@@ -54,6 +54,7 @@ jobs:
python -m pip install --upgrade pip
pip install -U setuptools wheel twine
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
pip install -U transformers
- name: Build the dist files
run: python setup.py bdist_wheel && python setup.py sdist
@@ -68,8 +69,6 @@ jobs:
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://test.pypi.org/simple/ diffusers
pip install -U transformers
python utils/print_env.py
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"

View File

@@ -0,0 +1,166 @@
import argparse
import os
import sys
import gpustat
import pandas as pd
import psycopg2
import psycopg2.extras
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
register_adapter(dict, Json)
FINAL_CSV_FILENAME = "collated_results.csv"
# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
BENCHMARKS_TABLE_NAME = "benchmarks"
MEASUREMENTS_TABLE_NAME = "model_measurements"
def _init_benchmark(conn, branch, commit_id, commit_msg):
gpu_stats = gpustat.GPUStatCollection.new_query()
metadata = {"gpu_name": gpu_stats[0]["name"]}
repository = "huggingface/diffusers"
with conn.cursor() as cur:
cur.execute(
f"INSERT INTO {BENCHMARKS_TABLE_NAME} (repository, branch, commit_id, commit_message, metadata) VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id",
(repository, branch, commit_id, commit_msg, metadata),
)
benchmark_id = cur.fetchone()[0]
print(f"Initialised benchmark #{benchmark_id}")
return benchmark_id
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"branch",
type=str,
help="The branch name on which the benchmarking is performed.",
)
parser.add_argument(
"commit_id",
type=str,
help="The commit hash on which the benchmarking is performed.",
)
parser.add_argument(
"commit_msg",
type=str,
help="The commit message associated with the commit, truncated to 70 characters.",
)
args = parser.parse_args()
return args
if __name__ == "__main__":
args = parse_args()
try:
conn = psycopg2.connect(
host=os.getenv("PGHOST"),
database=os.getenv("PGDATABASE"),
user=os.getenv("PGUSER"),
password=os.getenv("PGPASSWORD"),
)
print("DB connection established successfully.")
except Exception as e:
print(f"Problem during DB init: {e}")
sys.exit(1)
try:
benchmark_id = _init_benchmark(
conn=conn,
branch=args.branch,
commit_id=args.commit_id,
commit_msg=args.commit_msg,
)
except Exception as e:
print(f"Problem during initializing benchmark: {e}")
sys.exit(1)
cur = conn.cursor()
df = pd.read_csv(FINAL_CSV_FILENAME)
# Helper to cast values (or None) given a dtype
def _cast_value(val, dtype: str):
if pd.isna(val):
return None
if dtype == "text":
return str(val).strip()
if dtype == "float":
try:
return float(val)
except ValueError:
return None
if dtype == "bool":
s = str(val).strip().lower()
if s in ("true", "t", "yes", "1"):
return True
if s in ("false", "f", "no", "0"):
return False
if val in (1, 1.0):
return True
if val in (0, 0.0):
return False
return None
return val
try:
rows_to_insert = []
for _, row in df.iterrows():
scenario = _cast_value(row.get("scenario"), "text")
model_cls = _cast_value(row.get("model_cls"), "text")
num_params_B = _cast_value(row.get("num_params_B"), "float")
flops_G = _cast_value(row.get("flops_G"), "float")
time_plain_s = _cast_value(row.get("time_plain_s"), "float")
mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
time_compile_s = _cast_value(row.get("time_compile_s"), "float")
mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
fullgraph = _cast_value(row.get("fullgraph"), "bool")
mode = _cast_value(row.get("mode"), "text")
# If "github_sha" column exists in the CSV, cast it; else default to None
if "github_sha" in df.columns:
github_sha = _cast_value(row.get("github_sha"), "text")
else:
github_sha = None
measurements = {
"scenario": scenario,
"model_cls": model_cls,
"num_params_B": num_params_B,
"flops_G": flops_G,
"time_plain_s": time_plain_s,
"mem_plain_GB": mem_plain_GB,
"time_compile_s": time_compile_s,
"mem_compile_GB": mem_compile_GB,
"fullgraph": fullgraph,
"mode": mode,
"github_sha": github_sha,
}
rows_to_insert.append((benchmark_id, measurements))
# Batch-insert all rows
insert_sql = f"""
INSERT INTO {MEASUREMENTS_TABLE_NAME} (
benchmark_id,
measurements
)
VALUES (%s, %s);
"""
psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
conn.commit()
cur.close()
conn.close()
except Exception as e:
print(f"Exception: {e}")
sys.exit(1)
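For context, the script above expects a `collated_results.csv` file to already exist in the working directory, with the columns referenced by the `_cast_value` calls. A minimal sketch of a matching CSV follows; the row values and the `FluxTransformer2DModel` scenario are made up for illustration:

```python
import pandas as pd

# Hypothetical single-row example; the column names mirror the casts above,
# the values themselves are placeholders.
row = {
    "scenario": "flux-dev-t2i",
    "model_cls": "FluxTransformer2DModel",
    "num_params_B": 12.0,
    "flops_G": 985.0,
    "time_plain_s": 12.3,
    "mem_plain_GB": 23.4,
    "time_compile_s": 10.1,
    "mem_compile_GB": 23.0,
    "fullgraph": True,
    "mode": "max-autotune",
    "github_sha": "0123abcd",  # optional column
}
pd.DataFrame([row]).to_csv("collated_results.csv", index=False)
```

The benchmarking job above then invokes the script as `python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"` with the `PGHOST`/`PGDATABASE`/`PGUSER`/`PGPASSWORD` environment variables set.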

View File

@@ -194,8 +194,6 @@
title: Model accelerators and hardware
- isExpanded: false
sections:
- local: using-diffusers/helios
title: Helios
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/sdxl
@@ -352,8 +350,6 @@
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageTransformer2DModel
- local: api/models/helios_transformer3d
title: HeliosTransformer3DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -460,8 +456,6 @@
title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
- local: api/models/autoencoder_rae
title: AutoencoderRAE
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
@@ -631,6 +625,8 @@
title: Image-to-image
- local: api/pipelines/stable_diffusion/inpaint
title: Inpainting
- local: api/pipelines/stable_diffusion/k_diffusion
title: K-Diffusion
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
@@ -679,8 +675,6 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/helios
title: Helios
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/hunyuan_video15
@@ -752,10 +746,6 @@
title: FlowMatchEulerDiscreteScheduler
- local: api/schedulers/flow_match_heun_discrete
title: FlowMatchHeunDiscreteScheduler
- local: api/schedulers/helios_dmd
title: HeliosDMDScheduler
- local: api/schedulers/helios
title: HeliosScheduler
- local: api/schedulers/heun
title: HeunDiscreteScheduler
- local: api/schedulers/ipndm

View File

@@ -23,7 +23,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
- [`HeliosLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
@@ -87,10 +86,6 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
## HeliosLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HeliosLoraLoaderMixin
## HunyuanVideoLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin

View File

@@ -1,89 +0,0 @@
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoencoderRAE
The Representation Autoencoder (RAE) model was introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, and Saining Xie from NYU VISIONx.
RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).
The following RAE models are released and supported in Diffusers:
| Model | Encoder | Latent shape (224px input) |
|:------|:--------|:---------------------------|
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |
## Loading a pretrained model
```python
from diffusers import AutoencoderRAE
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
```
## Encoding and decoding a real image
```python
import torch
from diffusers import AutoencoderRAE
from diffusers.utils import load_image
from torchvision.transforms.functional import to_tensor, to_pil_image
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
image = image.convert("RGB").resize((224, 224))
x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1]
with torch.no_grad():
latents = model.encode(x).latent # (1, 768, 16, 16)
recon = model.decode(latents).sample # (1, 3, 256, 256)
recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
recon_image.save("recon.png")
```
## Latent normalization
Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.
```python
import torch
from diffusers import AutoencoderRAE

model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
# Latent normalization is handled automatically inside encode/decode
# when the checkpoint config includes latents_mean/latents_std.
with torch.no_grad():
latents = model.encode(x).latent # normalized latents
recon = model.decode(latents).sample
```
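For intuition, the normalization is a per-channel standardization of the latents. The sketch below assumes the conventional `(x - mean) / std` form with made-up statistics, so treat it as illustrative rather than the exact implementation:

```python
import torch

# Illustrative only: assumed per-channel standardization with placeholder statistics.
latents_mean = torch.zeros(1, 768, 1, 1)
latents_std = torch.ones(1, 768, 1, 1)

raw = torch.randn(1, 768, 16, 16)                    # stand-in for the encoder output
normalized = (raw - latents_mean) / latents_std      # what encode() would return
recovered = normalized * latents_std + latents_mean  # what decode() would undo first
```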
## AutoencoderRAE
[[autodoc]] AutoencoderRAE
- encode
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -1,35 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HeliosTransformer3DModel
A 14B Real-Time Autoregressive Diffusion Transformer model (supporting T2V, I2V, and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University, ByteDance, and others.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import HeliosTransformer3DModel
# Best Quality
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="transformer", torch_dtype=torch.bfloat16)
# Intermediate Weight
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="transformer", torch_dtype=torch.bfloat16)
# Best Efficiency
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HeliosTransformer3DModel
[[autodoc]] HeliosTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput

View File

@@ -14,8 +14,4 @@
## AutoPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks
## ConditionalPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.ConditionalPipelineBlocks
[[autodoc]] diffusers.modular_pipelines.modular_pipeline.AutoPipelineBlocks

View File

@@ -46,20 +46,6 @@ output = pipe(
output.save("output.png")
```
## Cosmos2_5_TransferPipeline
[[autodoc]] Cosmos2_5_TransferPipeline
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosTextToWorldPipeline
[[autodoc]] CosmosTextToWorldPipeline
@@ -84,6 +70,12 @@ output.save("output.png")
- all
- __call__
## Cosmos2_5_PredictBasePipeline
[[autodoc]] Cosmos2_5_PredictBasePipeline
- all
- __call__
## CosmosPipelineOutput
[[autodoc]] pipelines.cosmos.pipeline_output.CosmosPipelineOutput

View File

@@ -1,465 +0,0 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</a>
</div>
</div>
# Helios
[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University, ByteDance, and others, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, and Li Yuan.
* <u>We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality.</u> We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page).
The following Helios models are supported in Diffusers:
- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler.
- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler.
- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosDMDScheduler.
> [!TIP]
> Click on the Helios models in the right sidebar for more examples of video generation.
### Optimizing Memory and Inference Speed
The example below demonstrates how to generate a video from text optimized for memory or inference speed.
<hfoptions id="optimization">
<hfoption id="memory">
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
The Helios model below requires ~19GB of VRAM.
```py
import torch
from diffusers import AutoModel, HeliosPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
# group-offloading
pipeline = HeliosPipeline.from_pretrained(
"BestWishYsh/Helios-Base",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.enable_group_offload(
onload_device=torch.device("cuda"),
offload_device=torch.device("cpu"),
offload_type="block_level",
num_blocks_per_group=1,
use_stream=True,
record_stream=True,
)
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
```
</hfoption>
<hfoption id="inference speed">
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
```py
import torch
from diffusers import AutoModel, HeliosPipeline
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPipeline.from_pretrained(
"BestWishYsh/Helios-Base",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
# attention backend
# pipeline.transformer.set_attention_backend("flash")
pipeline.transformer.set_attention_backend("_flash_3_hub") # For Hopper GPUs
# torch.compile
torch.backends.cudnn.benchmark = True
pipeline.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
### Generation with Helios-Base
The example below demonstrates how to use Helios-Base to generate video based on text, image or video.
<hfoptions id="Helios-Base usage">
<hfoption id="usage">
```python
import torch
from diffusers import AutoModel, HeliosPipeline
from diffusers.utils import export_to_video, load_video, load_image
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPipeline.from_pretrained(
"BestWishYsh/Helios-Base",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
# For Text-to-Video
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
# For Image-to-Video
prompt = """
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
illuminating the intricate textures and deep green hues within the waves body. A thick spray erupts from the breaking crest,
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the oceans untamed beauty and
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
respect for natures might.
"""
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=load_image(image_path).resize((640, 384)),
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_i2v_output.mp4", fps=24)
# For Video-to-Video
prompt = """
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
"""
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
video=load_video(video_path),
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_v2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
### Generation with Helios-Mid
The example below demonstrates how to use Helios-Mid to generate video based on text, image or video.
<hfoptions id="Helios-Mid usage">
<hfoption id="usage">
```python
import torch
from diffusers import AutoModel, HeliosPyramidPipeline
from diffusers.utils import export_to_video, load_video, load_image
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPyramidPipeline.from_pretrained(
"BestWishYsh/Helios-Mid",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
# For Text-to-Video
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
pyramid_num_inference_steps_list=[20, 20, 20],
guidance_scale=5.0,
use_zero_init=True,
zero_steps=1,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24)
# For Image-to-Video
prompt = """
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
illuminating the intricate textures and deep green hues within the waves body. A thick spray erupts from the breaking crest,
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the oceans untamed beauty and
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
respect for natures might.
"""
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=load_image(image_path).resize((640, 384)),
num_frames=99,
pyramid_num_inference_steps_list=[20, 20, 20],
guidance_scale=5.0,
use_zero_init=True,
zero_steps=1,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_pyramid_i2v_output.mp4", fps=24)
# For Video-to-Video
prompt = """
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
"""
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
video=load_video(video_path),
num_frames=99,
pyramid_num_inference_steps_list=[20, 20, 20],
guidance_scale=5.0,
use_zero_init=True,
zero_steps=1,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
### Generation with Helios-Distilled
The example below demonstrates how to use Helios-Distilled to generate video based on text, image or video.
<hfoptions id="Helios-Distilled usage">
<hfoption id="usage">
```python
import torch
from diffusers import AutoModel, HeliosPyramidPipeline
from diffusers.utils import export_to_video, load_video, load_image
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPyramidPipeline.from_pretrained(
"BestWishYsh/Helios-Distilled",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
# For Text-to-Video
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=240,
pyramid_num_inference_steps_list=[2, 2, 2],
guidance_scale=1.0,
is_amplify_first_chunk=True,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24)
# For Image-to-Video
prompt = """
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
illuminating the intricate textures and deep green hues within the waves body. A thick spray erupts from the breaking crest,
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the oceans untamed beauty and
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
respect for natures might.
"""
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=load_image(image_path).resize((640, 384)),
num_frames=240,
pyramid_num_inference_steps_list=[2, 2, 2],
guidance_scale=1.0,
is_amplify_first_chunk=True,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24)
# For Video-to-Video
prompt = """
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
"""
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
video=load_video(video_path),
num_frames=240,
pyramid_num_inference_steps_list=[2, 2, 2],
guidance_scale=1.0,
is_amplify_first_chunk=True,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
## HeliosPipeline
[[autodoc]] HeliosPipeline
- all
- __call__
## HeliosPyramidPipeline
[[autodoc]] HeliosPyramidPipeline
- all
- __call__
## HeliosPipelineOutput
[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput

View File

@@ -29,7 +29,7 @@ Qwen-Image comes in the following variants:
| Qwen-Image-Edit Plus | [Qwen/Qwen-Image-Edit-2509](https://huggingface.co/Qwen/Qwen-Image-Edit-2509) |
> [!TIP]
> See the [Caching](../../optimization/cache) guide to speed up inference by storing and reusing intermediate outputs.
> [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
## LoRA for faster inference
@@ -190,12 +190,6 @@ For detailed benchmark scripts and results, see [this gist](https://gist.github.
- all
- __call__
## QwenImageLayeredPipeline
[[autodoc]] QwenImageLayeredPipeline
- all
- __call__
## QwenImagePipelineOutput
[[autodoc]] pipelines.qwenimage.pipeline_output.QwenImagePipelineOutput

View File

@@ -0,0 +1,30 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
> [!WARNING]
> This pipeline is deprecated but it can still be used. However, we won't test the pipeline anymore and won't accept any changes to it. If you run into any issues, reinstall the last Diffusers version that supported this model.
# K-Diffusion
[k-diffusion](https://github.com/crowsonkb/k-diffusion) is a popular library created by [Katherine Crowson](https://github.com/crowsonkb/). We provide `StableDiffusionKDiffusionPipeline` and `StableDiffusionXLKDiffusionPipeline` that allow you to run Stable Diffusion with samplers from k-diffusion.
Note that most of the samplers from k-diffusion are implemented in Diffusers and we recommend using the existing schedulers. You can find a mapping between k-diffusion samplers and schedulers in Diffusers [here](https://huggingface.co/docs/diffusers/api/schedulers/overview).
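For a quick start, a hedged sketch of typical usage follows. It assumes the `set_scheduler` helper accepts k-diffusion sampler names (for example `sample_dpmpp_2m`) and uses a Stable Diffusion v1-5 checkpoint for illustration; check the pipeline reference below for the exact API.

```python
import torch
from diffusers import StableDiffusionKDiffusionPipeline

# Sketch only: requires the k-diffusion package to be installed.
pipe = StableDiffusionKDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Sampler names follow k-diffusion's naming convention.
pipe.set_scheduler("sample_dpmpp_2m")

image = pipe("an astronaut riding a horse on mars", num_inference_steps=25).images[0]
image.save("k_diffusion_out.png")
```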
## StableDiffusionKDiffusionPipeline
[[autodoc]] StableDiffusionKDiffusionPipeline
## StableDiffusionXLKDiffusionPipeline
[[autodoc]] StableDiffusionXLKDiffusionPipeline

View File

@@ -1,20 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# HeliosScheduler
`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
## HeliosScheduler
[[autodoc]] HeliosScheduler
scheduling_helios

View File

@@ -1,20 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# HeliosDMDScheduler
`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
## HeliosDMDScheduler
[[autodoc]] HeliosDMDScheduler
scheduling_helios_dmd

View File

@@ -121,7 +121,7 @@ from diffusers.modular_pipelines import AutoPipelineBlocks
class AutoImageBlocks(AutoPipelineBlocks):
# List of sub-block classes to choose from
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_classes = [block_inpaint_cls, block_i2i_cls, block_t2i_cls]
# Names for each block in the same order
block_names = ["inpaint", "img2img", "text2img"]
# Trigger inputs that determine which block to run
@@ -129,8 +129,8 @@ class AutoImageBlocks(AutoPipelineBlocks):
# - "image" triggers img2img workflow (but only if mask is not provided)
# - if none of above, runs the text2img workflow (default)
block_trigger_inputs = ["mask", "image", None]
# Description is extremely important for AutoPipelineBlocks
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
@@ -141,7 +141,7 @@ class AutoImageBlocks(AutoPipelineBlocks):
)
```
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, its conditional logic may be difficult to figure out if it isn't properly explained.
It is **very** important to include a `description` to avoid any confusion over how to run a block and what inputs are required. While [`~modular_pipelines.AutoPipelineBlocks`] are convenient, it's conditional logic may be difficult to figure out if it isn't properly explained.
Create an instance of `AutoImageBlocks`.
@@ -152,74 +152,5 @@ auto_blocks = AutoImageBlocks()
For more complex compositions, such as nested [`~modular_pipelines.AutoPipelineBlocks`] blocks when they're used as sub-blocks in larger pipelines, use the [`~modular_pipelines.SequentialPipelineBlocks.get_execution_blocks`] method to extract the block that is actually run based on your input.
```py
auto_blocks.get_execution_blocks(mask=True)
```
## ConditionalPipelineBlocks
[`~modular_pipelines.AutoPipelineBlocks`] is a special case of [`~modular_pipelines.ConditionalPipelineBlocks`]. While [`~modular_pipelines.AutoPipelineBlocks`] selects blocks based on whether a trigger input is provided or not, [`~modular_pipelines.ConditionalPipelineBlocks`] is able to select a block based on custom selection logic provided in the `select_block` method.
Here is the same example written using [`~modular_pipelines.ConditionalPipelineBlocks`] directly:
```py
from diffusers.modular_pipelines import ConditionalPipelineBlocks
class AutoImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return (
"Pipeline generates images given different types of conditions!\n"
+ "This is an auto pipeline block that works for text2img, img2img and inpainting tasks.\n"
+ " - inpaint workflow is run when `mask` is provided.\n"
+ " - img2img workflow is run when `image` is provided (but only when `mask` is not provided).\n"
+ " - text2img workflow is run when neither `image` nor `mask` is provided.\n"
)
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name ("text2img")
```
The inputs listed in `block_trigger_inputs` are passed as keyword arguments to `select_block()`. When `select_block` returns `None`, it falls back to `default_block_name`. If `default_block_name` is also `None`, the entire conditional block is skipped — this is useful for optional processing steps that should only run when specific inputs are provided.
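To make the skip behavior concrete, here is a minimal hedged sketch of an optional step. The `OptionalRefineBlocks` class, its `RefineBlock` sub-block, and the `refine_image` trigger input are made up for illustration; because `default_block_name` is `None`, the whole conditional block is skipped whenever `select_block` returns `None`.

```py
from diffusers.modular_pipelines import ConditionalPipelineBlocks

class OptionalRefineBlocks(ConditionalPipelineBlocks):
    # Hypothetical sub-block and trigger names, for illustration only.
    block_classes = [RefineBlock]
    block_names = ["refine"]
    block_trigger_inputs = ["refine_image"]
    default_block_name = None  # no fallback block: skip this step entirely

    @property
    def description(self):
        return "Optionally refines the image when `refine_image` is provided."

    def select_block(self, refine_image=None) -> str | None:
        if refine_image is not None:
            return "refine"
        return None  # with default_block_name=None, the step is skipped
```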
## Workflows
Pipelines that contain conditional blocks ([`~modular_pipelines.AutoPipelineBlocks`] or [`~modular_pipelines.ConditionalPipelineBlocks`]) can support multiple workflows — for example, our SDXL modular pipeline supports a dozen workflows all in one pipeline. But this also means it can be confusing for users to know what workflows are supported and how to run them. For pipeline builders, it's useful to be able to extract only the blocks relevant to a specific workflow.
We recommend defining a `_workflow_map` to give each workflow a name and explicitly list the inputs it requires.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
class MyPipelineBlocks(SequentialPipelineBlocks):
block_classes = [TextEncoderBlock, AutoImageBlocks, DecodeBlock]
block_names = ["text_encoder", "auto_image", "decode"]
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"image": True, "prompt": True},
"inpaint": {"mask": True, "image": True, "prompt": True},
}
```
All of our built-in modular pipelines come with pre-defined workflows. The `available_workflows` property lists all supported workflows:
```py
pipeline_blocks = MyPipelineBlocks()
pipeline_blocks.available_workflows
# ['text2image', 'image2image', 'inpaint']
```
Retrieve a specific workflow with `get_workflow` to inspect and debug a specific block that executes the workflow.
```py
pipeline_blocks.get_workflow("inpaint")
auto_blocks.get_execution_blocks("mask")
```

View File

@@ -332,49 +332,4 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
</hfoption>
</hfoptions>
## Dependencies
Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and returns a warning if a package is missing or incompatible.
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
```py
from diffusers.modular_pipelines import PipelineBlock
class MyCustomBlock(PipelineBlock):
_requirements = {
"transformers": ">=4.44.0",
"sentencepiece": ">=0.2.0"
}
```
When there are blocks with different requirements, Diffusers merges their requirements.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
class BlockA(PipelineBlock):
_requirements = {"transformers": ">=4.44.0"}
# ...
class BlockB(PipelineBlock):
_requirements = {"sentencepiece": ">=0.2.0"}
# ...
pipe = SequentialPipelineBlocks.from_blocks_dict({
"block_a": BlockA,
"block_b": BlockB,
})
```
When this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers returns the following warning.
```md
# missing package
xyz-package was specified in the requirements but wasn't found in the current environment.
# version mismatch
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.
```
</hfoptions>

View File

@@ -89,8 +89,10 @@ t2i_pipeline.guider
## Changing guider parameters
The guider parameters can be adjusted with the [`~ComponentSpec.create`] method and [`~ModularPipeline.update_components`]. The example below changes the `guidance_scale` value.
The guider parameters can be adjusted with either the [`~ComponentSpec.create`] method or with [`~ModularPipeline.update_components`]. The example below changes the `guidance_scale` value.
<hfoptions id="switch">
<hfoption id="create">
```py
guider_spec = t2i_pipeline.get_component_spec("guider")
@@ -98,6 +100,18 @@ guider = guider_spec.create(guidance_scale=10)
t2i_pipeline.update_components(guider=guider)
```
</hfoption>
<hfoption id="update_components">
```py
guider_spec = t2i_pipeline.get_component_spec("guider")
guider_spec.config["guidance_scale"] = 10
t2i_pipeline.update_components(guider=guider_spec)
```
</hfoption>
</hfoptions>
## Uploading custom guiders
Call the [`~utils.PushToHubMixin.push_to_hub`] method on a custom guider to share it to the Hub.

View File

@@ -25,7 +25,9 @@ This guide explains how states work and how they connect blocks.
The [`~modular_pipelines.PipelineState`] is a global state container for all blocks. It maintains the complete runtime state of the pipeline and provides a structured way for blocks to read from and write to shared data.
[`~modular_pipelines.PipelineState`] stores all data in a `values` dict, which is a **mutable** state containing user provided input values and intermediate output values generated by blocks. If a block modifies an `input`, it will be reflected in the `values` dict after calling `set_block_state`.
There are two dicts in [`~modular_pipelines.PipelineState`] for structuring data.
- The `values` dict is a **mutable** state containing a copy of user provided input values and intermediate output values generated by blocks. If a block modifies an `input`, it will be reflected in the `values` dict after calling `set_block_state`.
```py
PipelineState(

View File

@@ -12,28 +12,27 @@ specific language governing permissions and limitations under the License.
# ModularPipeline
[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`] into an executable pipeline that loads models and performs the computation steps defined in the blocks. It is the main interface for running a pipeline and the API is very similar to [`DiffusionPipeline`] but with a few key differences.
[`ModularPipeline`] converts [`~modular_pipelines.ModularPipelineBlocks`]'s into an executable pipeline that loads models and performs the computation steps defined in the block. It is the main interface for running a pipeline and it is very similar to the [`DiffusionPipeline`] API.
- **Loading is lazy.** With [`DiffusionPipeline`], [`~DiffusionPipeline.from_pretrained`] creates the pipeline and loads all models at the same time. With [`ModularPipeline`], creating and loading are two separate steps: [`~ModularPipeline.from_pretrained`] reads the configuration and knows where to load each component from, but doesn't actually load the model weights. You load the models later with [`~ModularPipeline.load_components`], which is where you pass loading arguments like `torch_dtype` and `quantization_config`.
- **Two ways to create a pipeline.** You can use [`~ModularPipeline.from_pretrained`] with an existing diffusers model repository — it automatically maps to the default pipeline blocks and then converts to a [`ModularPipeline`] with no extra setup. You can check the [modular_pipelines_directory](https://github.com/huggingface/diffusers/tree/main/src/diffusers/modular_pipelines) to see which models are currently supported. You can also assemble your own pipeline from [`ModularPipelineBlocks`] and convert it with the [`~ModularPipelineBlocks.init_pipeline`] method (see [Creating a pipeline](#creating-a-pipeline) for more details).
- **Running the pipeline is the same.** Once loaded, you call the pipeline with the same arguments you're used to. A single [`ModularPipeline`] can support multiple workflows (text-to-image, image-to-image, inpainting, etc.) when the pipeline blocks use [`AutoPipelineBlocks`](./auto_pipeline_blocks) to automatically select the workflow based on your inputs.
Below are complete examples for text-to-image, image-to-image, and inpainting with SDXL.
The main difference is to include an expected `output` argument in the pipeline.
<hfoptions id="example">
<hfoption id="text-to-image">
```py
import torch
from diffusers import ModularPipeline
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k").images[0]
image = pipeline(prompt="Astronaut in a jungle, cold color palette, muted colors, detailed, 8k", output="images")[0]
image.save("modular_t2i_out.png")
```
@@ -42,17 +41,21 @@ image.save("modular_t2i_out.png")
```py
import torch
from diffusers import ModularPipeline
from diffusers.utils import load_image
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS
blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/sdxl-text2img.png"
init_image = load_image(url)
prompt = "a dog catching a frisbee in the jungle"
image = pipeline(prompt=prompt, image=init_image, strength=0.8).images[0]
image = pipeline(prompt=prompt, image=init_image, strength=0.8, output="images")[0]
image.save("modular_i2i_out.png")
```
@@ -61,10 +64,15 @@ image.save("modular_i2i_out.png")
```py
import torch
from diffusers import ModularPipeline
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import INPAINT_BLOCKS
from diffusers.utils import load_image
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
blocks = SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS)
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
pipeline = blocks.init_pipeline(modular_repo_id)
pipeline.load_components(torch_dtype=torch.float16)
pipeline.to("cuda")
@@ -75,353 +83,276 @@ init_image = load_image(img_url)
mask_image = load_image(mask_url)
prompt = "A deep sea diver floating"
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85).images[0]
image.save("modular_inpaint_out.png")
image = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, strength=0.85, output="images")[0]
image.save("moduar_inpaint_out.png")
```
</hfoption>
</hfoptions>
This guide will show you how to create a [`ModularPipeline`], manage its components, and run the pipeline.
This guide will show you how to create a [`ModularPipeline`] and manage the components in it.
## Adding blocks
Blocks are [`InsertableDict`] objects that can be inserted at specific positions, providing a flexible way to mix-and-match blocks.
Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.insert`] on either the block class or `sub_blocks` attribute to add a block.
```py
# BLOCKS is a dict of block classes, so insert the class itself
BLOCKS.insert("block_name", BlockClass, index)
# the sub_blocks attribute holds block instances, so insert an instance
t2i_blocks.sub_blocks.insert("block_name", block_instance, index)
```
Use [`~modular_pipelines.modular_pipeline_utils.InsertableDict.pop`] on either the block class or `sub_blocks` attribute to remove a block.
```py
# remove a block class from preset
BLOCKS.pop("text_encoder")
# split out a block instance on its own
text_encoder_block = t2i_blocks.sub_blocks.pop("text_encoder")
```
Swap blocks by setting the existing block to the new block.
```py
# Replace block class in preset
BLOCKS["prepare_latents"] = CustomPrepareLatents
# Replace in the sub_blocks attribute using a block instance
t2i_blocks.sub_blocks["prepare_latents"] = CustomPrepareLatents()
```
## Creating a pipeline
There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] with [`~ModularPipelineBlocks.init_pipeline`], or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
There are two ways to create a [`ModularPipeline`]. Assemble and create a pipeline from [`ModularPipelineBlocks`] or load an existing pipeline with [`~ModularPipeline.from_pretrained`].
You can also initialize a [`ComponentsManager`](./components_manager) to handle device placement and memory management. If you don't need automatic offloading, you can skip this and move the pipeline to your device manually with `pipeline.to("cuda")`.
You should also initialize a [`ComponentsManager`] to handle device placement and memory and component management.
> [!TIP]
> Refer to the [ComponentsManager](./components_manager) doc for more details about how it can help manage components across different workflows.
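For example, a minimal sketch of creating a manager with automatic offloading enabled; the manager is then passed through the `components_manager` argument when the pipeline is created:
```py
from diffusers import ComponentsManager

components = ComponentsManager()
# models are moved between CPU and GPU on demand during inference
components.enable_auto_cpu_offload(device="cuda")
```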
### init_pipeline
<hfoptions id="create">
<hfoption id="ModularPipelineBlocks">
[`~ModularPipelineBlocks.init_pipeline`] converts any [`ModularPipelineBlocks`] into a [`ModularPipeline`].
Let's define a minimal block to see how it works:
Use the [`~ModularPipelineBlocks.init_pipeline`] method to create a [`ModularPipeline`] from the component and configuration specifications. This method loads the *specifications* from a `modular_model_index.json` file, but it doesn't load the *models* yet.
```py
from transformers import CLIPTextModel
from diffusers.modular_pipelines import (
ComponentSpec,
ModularPipelineBlocks,
PipelineState,
)
from diffusers import ComponentsManager
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS
class MyBlock(ModularPipelineBlocks):
@property
def expected_components(self):
return [
ComponentSpec(
name="text_encoder",
type_hint=CLIPTextModel,
pretrained_model_name_or_path="openai/clip-vit-large-patch14",
),
]
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
def __call__(self, components, state: PipelineState) -> PipelineState:
return components, state
modular_repo_id = "YiYiXu/modular-loader-t2i-0704"
components = ComponentsManager()
t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id, components_manager=components)
```
Call [`~ModularPipelineBlocks.init_pipeline`] to convert it into a pipeline. The `blocks` attribute on the pipeline is the blocks it was created from — it determines the expected inputs, outputs, and computation logic.
</hfoption>
<hfoption id="from_pretrained">
```py
block = MyBlock()
pipe = block.init_pipeline()
pipe.blocks
```
```
MyBlock {
"_class_name": "MyBlock",
"_diffusers_version": "0.37.0.dev0"
}
```
> [!WARNING]
> Blocks are mutable — you can freely add, remove, or swap blocks before creating a pipeline. However, once a pipeline is created, modifying `pipeline.blocks` won't affect the pipeline because it returns a copy. If you want a different block structure, create a new pipeline after modifying the blocks.
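For example, a minimal sketch, assuming the SDXL text-to-image preset blocks used elsewhere in this guide:
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS

blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
pipe = blocks.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")

# this only edits the copy returned by pipe.blocks and does NOT change the pipeline
pipe.blocks.sub_blocks.pop("text_encoder")

# to actually change the structure, modify the blocks and create a new pipeline
blocks.sub_blocks.pop("text_encoder")
pipe = blocks.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")
```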
When you call [`~ModularPipelineBlocks.init_pipeline`] without a repository, it uses the `pretrained_model_name_or_path` defined in the block's [`ComponentSpec`] to determine where to load each component from. Printing the pipeline shows the component loading configuration.
```py
pipe
ModularPipeline {
"_blocks_class_name": "MyBlock",
"_class_name": "ModularPipeline",
"_diffusers_version": "0.37.0.dev0",
"text_encoder": [
null,
null,
{
"pretrained_model_name_or_path": "openai/clip-vit-large-patch14",
"revision": null,
"subfolder": "",
"type_hint": [
"transformers",
"CLIPTextModel"
],
"variant": null
}
]
}
```
If you pass a repository to [`~ModularPipelineBlocks.init_pipeline`], it overrides the loading path by matching your block's components against the pipeline config in that repository (`model_index.json` or `modular_model_index.json`).
In the example below, the `pretrained_model_name_or_path` will be updated to `"stabilityai/stable-diffusion-xl-base-1.0"`.
```py
pipe = block.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")
pipe
ModularPipeline {
"_blocks_class_name": "MyBlock",
"_class_name": "ModularPipeline",
"_diffusers_version": "0.37.0.dev0",
"text_encoder": [
null,
null,
{
"pretrained_model_name_or_path": "stabilityai/stable-diffusion-xl-base-1.0",
"revision": null,
"subfolder": "text_encoder",
"type_hint": [
"transformers",
"CLIPTextModel"
],
"variant": null
}
]
}
```
If a component in your block doesn't exist in the repository, it remains `null` and is skipped during [`~ModularPipeline.load_components`].
### from_pretrained
[`~ModularPipeline.from_pretrained`] is a convenient way to create a [`ModularPipeline`] without defining blocks yourself.
It works with three types of repositories.
**A regular diffusers repository.** Pass any supported model repository and it automatically maps to the default pipeline blocks. Currently supported models include SDXL, Wan, Qwen, Z-Image, Flux, and Flux2.
The [`~ModularPipeline.from_pretrained`] method creates a [`ModularPipeline`] from a modular repository on the Hub.
```py
from diffusers import ModularPipeline, ComponentsManager
components = ComponentsManager()
pipeline = ModularPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0", components_manager=components
)
pipeline = ModularPipeline.from_pretrained("YiYiXu/modular-loader-t2i-0704", components_manager=components)
```
**A modular repository.** These repositories contain a `modular_model_index.json` that specifies where to load each component from — the components can come from different repositories and the modular repository itself may not contain any model weights. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from one repository and the remaining components from another. See [Modular repository](#modular-repository) for more details on the format.
Add the `trust_remote_code` argument to load a custom [`ModularPipeline`].
```py
from diffusers import ModularPipeline, ComponentsManager
components = ComponentsManager()
pipeline = ModularPipeline.from_pretrained(
"diffusers/flux2-bnb-4bit-modular", components_manager=components
)
modular_repo_id = "YiYiXu/modular-diffdiff-0704"
diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True, components_manager=components)
```
**A modular repository with custom code.** Some repositories include custom pipeline blocks alongside the loading configuration. Add `trust_remote_code=True` to load them. See [Custom blocks](./custom_blocks) for how to create your own.
```py
from diffusers import ModularPipeline, ComponentsManager
components = ComponentsManager()
pipeline = ModularPipeline.from_pretrained(
"diffusers/Florence2-image-Annotator", trust_remote_code=True, components_manager=components
)
```
</hfoption>
</hfoptions>
## Loading components
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load components with [`~ModularPipeline.load_components`].
A [`ModularPipeline`] doesn't automatically instantiate with components. It only loads the configuration and component specifications. You can load all components with [`~ModularPipeline.load_components`] or load only specific components by passing their names to it.
This will load all the components that have a valid loading spec.
<hfoptions id="load">
<hfoption id="load_components">
```py
import torch
pipeline.load_components(torch_dtype=torch.float16)
t2i_pipeline.load_components(torch_dtype=torch.float16)
t2i_pipeline.to("cuda")
```
You can also load specific components by name. The example below only loads the `text_encoder`.
</hfoption>
<hfoption id="load_components">
The example below only loads the UNet and VAE.
```py
pipeline.load_components(names=["text_encoder"], torch_dtype=torch.float16)
import torch
t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
```
After loading, printing the pipeline shows which components are loaded — the first two fields change from `null` to the component's library and class.
</hfoption>
</hfoptions>
Print the pipeline to inspect the loaded pretrained components.
```py
pipeline
t2i_pipeline
```
```
# text_encoder is loaded - shows library and class
"text_encoder": [
"transformers",
"CLIPTextModel",
{ ... }
]
This should match the `modular_model_index.json` file from the modular repository a pipeline is initialized from. If a pipeline doesn't need a component, it won't be included even if it exists in the modular repository.
# unet is not loaded yet - still null
To modify where components are loaded from, edit the `modular_model_index.json` file in the repository and change it to your desired loading path. The example below loads a UNet from a different repository.
```json
# original
"unet": [
null,
null,
{ ... }
null, null,
{
"repo": "stabilityai/stable-diffusion-xl-base-1.0",
"subfolder": "unet",
"variant": "fp16"
}
]
# modified
"unet": [
null, null,
{
"repo": "RunDiffusion/Juggernaut-XL-v9",
"subfolder": "unet",
"variant": "fp16"
}
]
```
Loading keyword arguments like `torch_dtype`, `variant`, `revision`, and `quantization_config` are passed through to `from_pretrained()` for each component. You can pass a single value to apply to all components, or a dict to set per-component values.
### Component loading status
The pipeline properties below provide more information about which components are loaded.
Use `component_names` to return all expected components.
```py
# apply bfloat16 to all components
pipeline.load_components(torch_dtype=torch.bfloat16)
# different dtypes per component
pipeline.load_components(torch_dtype={"transformer": torch.bfloat16, "default": torch.float32})
t2i_pipeline.component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor']
```
[`~ModularPipeline.load_components`] only loads components that haven't been loaded yet and have a valid loading spec. This means if you've already set a component on the pipeline, calling [`~ModularPipeline.load_components`] again won't reload it.
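For example, a minimal sketch where the text encoder is loaded first and then skipped on the second call:
```py
# load only the text encoder
pipeline.load_components(names=["text_encoder"], torch_dtype=torch.float16)

# loads the remaining components; the already-loaded text_encoder is skipped
pipeline.load_components(torch_dtype=torch.float16)
```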
Use `null_component_names` to return components that aren't loaded yet. Load these components with [`~ModularPipeline.load_components`].
```py
t2i_pipeline.null_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler']
```
Use `pretrained_component_names` to return components that will be loaded from pretrained models.
```py
t2i_pipeline.pretrained_component_names
['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae']
```
Use `config_component_names` to return components that are created with the default config (not loaded from a modular repository). Components from a config aren't included because they are already initialized during pipeline creation. This is why they aren't listed in `null_component_names`.
```py
t2i_pipeline.config_component_names
['guider', 'image_processor']
```
## Updating components
[`~ModularPipeline.update_components`] replaces a component on the pipeline with a new one. When a component is updated, the loading specifications are also updated in the pipeline config and [`~ModularPipeline.load_components`] will skip it on subsequent calls.
Components may be updated depending on whether it is a *pretrained component* or a *config component*.
### From AutoModel
> [!WARNING]
> A component may change from pretrained to config when updating a component. The component type is initially defined in a block's `expected_components` field.
You can pass a model object loaded with `AutoModel.from_pretrained()`. Models loaded this way are automatically tagged with their loading information.
A pretrained component is updated with [`ComponentSpec`] whereas a config component is updated by either passing the object directly or with [`ComponentSpec`].
The [`ComponentSpec`] shows `default_creation_method="from_pretrained"` for a pretrained component and `default_creation_method="from_config"` for a config component.
To update a pretrained component, create a [`ComponentSpec`] with the name of the component and where to load it from. Use the [`~ComponentSpec.load`] method to load the component.
```py
from diffusers import AutoModel
from diffusers import ComponentSpec, UNet2DConditionModel
unet = AutoModel.from_pretrained(
"RunDiffusion/Juggernaut-XL-v9", subfolder="unet", variant="fp16", torch_dtype=torch.float16
)
pipeline.update_components(unet=unet)
unet_spec = ComponentSpec(name="unet", type_hint=UNet2DConditionModel, repo="stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", variant="fp16")
unet = unet_spec.load(torch_dtype=torch.float16)
```
### From ComponentSpec
Use [`~ModularPipeline.get_component_spec`] to get a copy of the current component specification, modify it, and load a new component.
The [`~ModularPipeline.update_components`] method replaces the component with a new one.
```py
unet_spec = pipeline.get_component_spec("unet")
t2i_pipeline.update_components(unet=unet2)
```
When a component is updated, the loading specifications are also updated in the pipeline config.
### Component extraction and modification
When you use [`~ComponentSpec.load`], the new component maintains its loading specifications. This makes it possible to extract the specification and recreate the component.
```py
spec = ComponentSpec.from_component("unet", unet2)
spec
ComponentSpec(name='unet', type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>, description=None, config=None, repo='stabilityai/stable-diffusion-xl-base-1.0', subfolder='unet', variant='fp16', revision=None, default_creation_method='from_pretrained')
unet2_recreated = spec.load(torch_dtype=torch.float16)
```
The [`~ModularPipeline.get_component_spec`] method gets a copy of the current component specification to modify or update.
```py
unet_spec = t2i_pipeline.get_component_spec("unet")
unet_spec
ComponentSpec(
name='unet',
type_hint=<class 'diffusers.models.unets.unet_2d_condition.UNet2DConditionModel'>,
pretrained_model_name_or_path='RunDiffusion/Juggernaut-XL-v9',
subfolder='unet',
variant='fp16',
default_creation_method='from_pretrained'
)
# modify to load from a different repository
unet_spec.pretrained_model_name_or_path = "RunDiffusion/Juggernaut-XL-v9"
unet_spec.pretrained_model_name_or_path = "stabilityai/stable-diffusion-xl-base-1.0"
# load and update
# load component with modified spec
unet = unet_spec.load(torch_dtype=torch.float16)
pipeline.update_components(unet=unet)
```
You can also create a [`ComponentSpec`] from scratch.
Not all components are loaded from pretrained weights — some are created from a config (listed under `pipeline.config_component_names`). For these, use [`~ComponentSpec.create`] instead of [`~ComponentSpec.load`].
```py
guider_spec = pipeline.get_component_spec("guider")
guider_spec.config = {"guidance_scale": 5.0}
guider = guider_spec.create()
pipeline.update_components(guider=guider)
```
Or simply pass the object directly.
```py
from diffusers.guiders import ClassifierFreeGuidance
guider = ClassifierFreeGuidance(guidance_scale=5.0)
pipeline.update_components(guider=guider)
```
See the [Guiders](./guiders) guide for more details on available guiders and how to configure them.
## Splitting a pipeline into stages
Since blocks are composable, you can take a pipeline apart and reconstruct it into separate pipelines for each stage. The example below shows how we can separate the text encoder block from the rest of the pipeline, so you can encode the prompt independently and pass the embeddings to the main pipeline.
```py
from diffusers import ModularPipeline, ComponentsManager
import torch
device = "cuda"
dtype = torch.bfloat16
repo_id = "black-forest-labs/FLUX.2-klein-4B"
# get the blocks and separate out the text encoder
blocks = ModularPipeline.from_pretrained(repo_id).blocks
text_block = blocks.sub_blocks.pop("text_encoder")
# use ComponentsManager to handle offloading across multiple pipelines
manager = ComponentsManager()
manager.enable_auto_cpu_offload(device=device)
# create separate pipelines for each stage
text_encoder_pipeline = text_block.init_pipeline(repo_id, components_manager=manager)
pipeline = blocks.init_pipeline(repo_id, components_manager=manager)
# encode text
text_encoder_pipeline.load_components(torch_dtype=dtype)
text_embeddings = text_encoder_pipeline(prompt="a cat").get_by_kwargs("denoiser_input_fields")
# denoise and decode
pipeline.load_components(torch_dtype=dtype)
output = pipeline(
**text_embeddings,
num_inference_steps=4,
).images[0]
```
[`ComponentsManager`] handles memory across multiple pipelines. Unlike the offloading strategies in [`DiffusionPipeline`] that follow a fixed order, [`ComponentsManager`] makes offloading decisions dynamically each time a model forward pass runs, based on the current memory situation. This means it works regardless of how many pipelines you create or what order you run them in. See the [ComponentsManager](./components_manager) guide for more details.
If pipeline stages share components (e.g., the same VAE used for encoding and decoding), you can use [`~ModularPipeline.update_components`] to pass an already-loaded component to another pipeline instead of loading it again.
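A minimal sketch, assuming two hypothetical stage pipelines that both declare a `vae` component:
```py
# hypothetical stage pipelines created from separate blocks
encode_pipeline.load_components(names=["vae"], torch_dtype=torch.bfloat16)

# reuse the already-loaded component instead of loading it a second time
decode_pipeline.update_components(vae=encode_pipeline.vae)
```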
## Modular repository
A repository is required if the pipeline blocks use *pretrained components*. The repository supplies loading specifications and metadata.
[`ModularPipeline`] works with regular diffusers repositories out of the box. However, you can also create a *modular repository* for more flexibility. A modular repository contains a `modular_model_index.json` file containing the following 3 elements.
[`ModularPipeline`] specifically requires *modular repositories* (see [example repository](https://huggingface.co/YiYiXu/modular-diffdiff)) which are more flexible than a typical repository. It contains a `modular_model_index.json` file containing the following 3 elements.
- `library` and `class` shows which library the component was loaded from and its class. If `null`, the component hasn't been loaded yet.
- `library` and `class` shows which library the component was loaded from and it's class. If `null`, the component hasn't been loaded yet.
- `loading_specs_dict` contains the information required to load the component such as the repository and subfolder it is loaded from.
The key advantage of a modular repository is that components can be loaded from different repositories. For example, [diffusers/flux2-bnb-4bit-modular](https://huggingface.co/diffusers/flux2-bnb-4bit-modular) loads a quantized transformer from `diffusers/FLUX.2-dev-bnb-4bit` while loading the remaining components from `black-forest-labs/FLUX.2-dev`.
Unlike standard repositories, a modular repository can fetch components from different repositories based on the `loading_specs_dict`. Components don't need to exist in the same repository.
To convert a regular diffusers repository into a modular one, create the pipeline using the regular repository, and then push to the Hub. The saved repository will contain a `modular_model_index.json` with all the loading specifications.
```py
from diffusers import ModularPipeline
# load from a regular repo
pipeline = ModularPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
# push as a modular repository
pipeline.save_pretrained("local/path", repo_id="my-username/sdxl-modular", push_to_hub=True)
```
A modular repository can also include custom pipeline blocks as Python code. This allows you to share specialized blocks that aren't native to Diffusers. For example, [diffusers/Florence2-image-Annotator](https://huggingface.co/diffusers/Florence2-image-Annotator) contains custom blocks alongside the loading configuration:
A modular repository may contain custom code for loading a [`ModularPipeline`]. This allows you to use specialized blocks that aren't native to Diffusers.
```
Florence2-image-Annotator/
modular-diffdiff-0704/
├── block.py # Custom pipeline blocks implementation
├── config.json # Pipeline configuration and auto_map
├── mellon_config.json # UI configuration for Mellon
└── modular_model_index.json # Component loading specifications
```
The `config.json` file contains an `auto_map` key that tells [`ModularPipeline`] where to find the custom blocks:
The [config.json](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file contains an `auto_map` key that points to where a custom block is defined in `block.py`.
```json
{
"_class_name": "Florence2AnnotatorBlocks",
"_class_name": "DiffDiffBlocks",
"auto_map": {
"ModularPipelineBlocks": "block.Florence2AnnotatorBlocks"
"ModularPipelineBlocks": "block.DiffDiffBlocks"
}
}
```
Load custom code repositories with `trust_remote_code=True` as shown in [from_pretrained](#from_pretrained). See [Custom blocks](./custom_blocks) for how to create and share your own.

View File

@@ -25,42 +25,56 @@ This guide will show you how to create a [`~modular_pipelines.ModularPipelineBlo
A [`~modular_pipelines.ModularPipelineBlocks`] requires `inputs` and `intermediate_outputs`.
- `inputs` are values a block reads from the [`~modular_pipelines.PipelineState`] to perform its computation. These can be values provided by a user (like a prompt or image) or values produced by a previous block (like encoded `image_latents`).
- `inputs` are values provided by a user and retrieved from the [`~modular_pipelines.PipelineState`]. This is useful because some workflows resize an image, but the original image is still required. The [`~modular_pipelines.PipelineState`] maintains the original image.
Use `InputParam` to define `inputs`.
```py
class ImageEncodeStep(ModularPipelineBlocks):
...
```py
from diffusers.modular_pipelines import InputParam
@property
def inputs(self):
return [
InputParam(name="image", type_hint="PIL.Image", required=True, description="raw input image to process"),
]
...
```
user_inputs = [
InputParam(name="image", type_hint="PIL.Image", description="raw input image to process")
]
```
- `intermediate_outputs` are new values created by a block and added to the [`~modular_pipelines.PipelineState`]. The `intermediate_outputs` are available as `inputs` for subsequent blocks or available as the final output from running the pipeline.
Use `OutputParam` to define `intermediate_outputs`.
```py
class ImageEncodeStep(ModularPipelineBlocks):
...
```py
from diffusers.modular_pipelines import OutputParam
@property
def intermediate_outputs(self):
return [
OutputParam(name="image_latents", description="latents representing the image"),
]
...
```
user_intermediate_outputs = [
OutputParam(name="image_latents", description="latents representing the image")
]
```
The intermediate inputs and outputs share data to connect blocks. They are accessible at any point, allowing you to track the workflow's progress.
## Components and configs
## Computation logic
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`
2. Implement the computation logic on the `inputs`.
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
4. Return the components and state which becomes available to the next block.
```py
def __call__(self, components, state):
# Get a local view of the state variables this block needs
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs
# Access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
self.set_block_state(state, block_state)
return components, state
```
### Components and configs
The components and pipeline-level configs a block needs are specified in [`ComponentSpec`] and [`~modular_pipelines.ConfigSpec`].
@@ -68,108 +82,24 @@ The components and pipeline-level configs a block needs are specified in [`Compo
- [`~modular_pipelines.ConfigSpec`] contains pipeline-level settings that control behavior across all blocks.
```py
class ImageEncodeStep(ModularPipelineBlocks):
...
from diffusers import ComponentSpec, ConfigSpec
@property
def expected_components(self):
return [
ComponentSpec(name="vae", type_hint=AutoencoderKL),
]
expected_components = [
ComponentSpec(name="unet", type_hint=UNet2DConditionModel),
ComponentSpec(name="scheduler", type_hint=EulerDiscreteScheduler)
]
@property
def expected_configs(self):
return [
ConfigSpec("force_zeros_for_empty_prompt", True),
]
...
expected_config = [
ConfigSpec("force_zeros_for_empty_prompt", True)
]
```
When the blocks are converted into a pipeline, the components become available to the block as the first argument in `__call__`.
## Computation logic
The computation a block performs is defined in the `__call__` method and it follows a specific structure.
1. Retrieve the [`~modular_pipelines.BlockState`] to get a local view of the `inputs`.
2. Implement the computation logic on the `inputs`.
3. Update [`~modular_pipelines.PipelineState`] to push changes from the local [`~modular_pipelines.BlockState`] back to the global [`~modular_pipelines.PipelineState`].
4. Return the components and state which becomes available to the next block.
```py
class ImageEncodeStep(ModularPipelineBlocks):
def __call__(self, components, state):
# Get a local view of the state variables this block needs
block_state = self.get_block_state(state)
# Your computation logic here
# block_state contains all your inputs
# Access them like: block_state.image, block_state.processed_image
# Update the pipeline state with your updated block_states
self.set_block_state(state, block_state)
return components, state
def __call__(self, components, state):
# Access components using dot notation
unet = components.unet
vae = components.vae
scheduler = components.scheduler
```
## Putting it all together
Here is the complete block with all the pieces connected.
```py
from diffusers import ComponentSpec, AutoencoderKL
from diffusers.modular_pipelines import InputParam, ModularPipelineBlocks, OutputParam
class ImageEncodeStep(ModularPipelineBlocks):
@property
def description(self):
return "Encode an image into latent space."
@property
def expected_components(self):
return [
ComponentSpec(name="vae", type_hint=AutoencoderKL),
]
@property
def inputs(self):
return [
InputParam(name="image", type_hint="PIL.Image", required=True, description="raw input image to process"),
]
@property
def intermediate_outputs(self):
return [
OutputParam(name="image_latents", type_hint="torch.Tensor", description="latents representing the image"),
]
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.image_latents = components.vae.encode(block_state.image)
self.set_block_state(state, block_state)
return components, state
```
Every block has a `doc` property that is automatically generated from the properties you defined above. It provides a summary of the block's description, components, inputs, and outputs.
```py
block = ImageEncodeStep()
print(block.doc)
class ImageEncodeStep
Encode an image into latent space.
Components:
vae (`AutoencoderKL`)
Inputs:
image (`PIL.Image`):
raw input image to process
Outputs:
image_latents (`torch.Tensor`):
latents representing the image
```

View File

@@ -39,44 +39,17 @@ image
[`~ModularPipeline.from_pretrained`] uses lazy loading - it reads the configuration to learn where to load each component from, but doesn't actually load the model weights until you call [`~ModularPipeline.load_components`]. This gives you control over when and how components are loaded.
> [!TIP]
> `ComponentsManager` with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide.
>
> If you don't need offloading, remove the `components_manager` argument and move the pipeline to your device manually with `to("cuda")`.
> [`ComponentsManager`] with `enable_auto_cpu_offload` automatically moves models between CPU and GPU as needed, reducing memory usage for large models like Qwen-Image. Learn more in the [ComponentsManager](./components_manager) guide.
Learn more about creating and loading pipelines in the [Creating a pipeline](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#creating-a-pipeline) and [Loading components](https://huggingface.co/docs/diffusers/modular_diffusers/modular_pipeline#loading-components) guides.
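For example, a minimal sketch of the two-step flow, reusing the Qwen-Image repository from this guide:
```py
import torch
from diffusers import ModularPipeline

# reads the loading configuration only; no model weights are loaded yet
pipe = ModularPipeline.from_pretrained("Qwen/Qwen-Image")

# the weights are loaded here, with the loading arguments you choose
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")
```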
## Understand the structure
A [`ModularPipeline`] has two parts: a **definition** (the blocks) and a **state** (the loaded components and configs).
A [`ModularPipeline`] has two parts:
- **State**: the loaded components (models, schedulers, processors) and configuration
- **Definition**: the [`ModularPipelineBlocks`] that specify inputs, outputs, expected components and computation logic
Print the pipeline to see its state — the components and their loading status and configuration.
```py
print(pipe)
```
```
QwenImageModularPipeline {
"_blocks_class_name": "QwenImageAutoBlocks",
"_class_name": "QwenImageModularPipeline",
"_diffusers_version": "0.37.0.dev0",
"transformer": [
"diffusers",
"QwenImageTransformer2DModel",
{
"pretrained_model_name_or_path": "Qwen/Qwen-Image",
"revision": null,
"subfolder": "transformer",
"type_hint": [
"diffusers",
"QwenImageTransformer2DModel"
],
"variant": null
}
],
...
}
```
Access the definition through `pipe.blocks` — this is the [`~modular_pipelines.ModularPipelineBlocks`] that defines the pipeline's workflows, inputs, outputs, and computation logic.
The blocks define *what* the pipeline does. Access them through `pipe.blocks`.
```py
print(pipe.blocks)
```
@@ -114,8 +87,7 @@ The output returns:
### Workflows
This pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Learn more about how this works under the hood in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.
`QwenImageAutoBlocks` is a [`ConditionalPipelineBlocks`], so this pipeline supports multiple workflows and adapts its behavior based on the inputs you provide. For example, if you pass `image` to the pipeline, it runs an image-to-image workflow instead of text-to-image. Let's see this in action with an example.
```py
from diffusers.utils import load_image
@@ -127,21 +99,20 @@ image = pipe(
).images[0]
```
Use `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `"image2image"`, `"inpainting"`, `"controlnet_text2image"`) to get only the blocks relevant to that workflow. This is useful when you want to customize or debug a specific workflow. You can check `pipe.blocks.available_workflows` to see all available workflows.
Use `get_workflow()` to extract the blocks for a specific workflow. Pass the workflow name (e.g., `"image2image"`, `"inpainting"`, `"controlnet_text2image"`) to get only the blocks relevant to that workflow.
```py
img2img_blocks = pipe.blocks.get_workflow("image2image")
```
Conditional blocks are convenient for users, but their conditional logic adds complexity when customizing or debugging. Extracting a workflow gives you the specific blocks relevant to your workflow, making it easier to work with. Learn more in the [AutoPipelineBlocks](https://huggingface.co/docs/diffusers/modular_diffusers/auto_pipeline_blocks) guide.
### Sub-blocks
Blocks can contain other blocks. `pipe.blocks` gives you the top-level block definition (here, `QwenImageAutoBlocks`), while `sub_blocks` lets you access the smaller blocks inside it.
`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`.
`QwenImageAutoBlocks` is composed of: `text_encoder`, `vae_encoder`, `controlnet_vae_encoder`, `denoise`, and `decode`. Access them through the `sub_blocks` property.
These sub-blocks run one after another and data flows linearly from one block to the next — each block's `intermediate_outputs` become available as `inputs` to the next block. This is how [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) work.
You can access them through the `sub_blocks` property. The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.
The `doc` property is useful for seeing the full documentation of any block, including its inputs, outputs, and components.
```py
vae_encoder_block = pipe.blocks.sub_blocks["vae_encoder"]
print(vae_encoder_block.doc)
@@ -194,7 +165,7 @@ class CannyBlock
Canny map for input image
```
Use `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
Use `get_workflow` to extract the ControlNet workflow from [`QwenImageAutoBlocks`].
```py
# Get the controlnet workflow that we want to work with
blocks = pipe.blocks.get_workflow("controlnet_text2image")
@@ -211,8 +182,9 @@ class SequentialPipelineBlocks
...
```
The extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) - a multi-block type where blocks run one after another and data flows linearly from one block to the next. Each block's `intermediate_outputs` become available as `inputs` to subsequent blocks.
The extracted workflow is a [`SequentialPipelineBlocks`](./sequential_pipeline_blocks) and it currently requires `control_image` as input. Insert the canny block at the beginning so the pipeline accepts a regular image instead.
Currently this workflow requires `control_image` as input. Let's insert the canny block at the beginning so the pipeline accepts a regular image instead.
```py
# Insert canny at the beginning
blocks.sub_blocks.insert("canny", canny_block, 0)
@@ -239,7 +211,7 @@ class SequentialPipelineBlocks
Now the pipeline takes `image` as input instead of `control_image`. Because blocks in a sequence share data automatically, the canny block's output (`control_image`) flows to the denoise block that needs it, and the canny block's input (`image`) becomes a pipeline input since no earlier block provides it.
Create a pipeline from the modified blocks and load a ControlNet model. The ControlNet isn't part of the original model repository, so load it separately and add it with [`~ModularPipeline.update_components`].
Create a pipeline from the modified blocks and load a ControlNet model.
```py
pipeline = blocks.init_pipeline("Qwen/Qwen-Image", components_manager=manager)
@@ -269,16 +241,6 @@ output
## Next steps
<hfoptions id="next">
<hfoption id="Learn the basics">
Understand the core building blocks of Modular Diffusers:
- [ModularPipelineBlocks](./pipeline_block): The basic unit for defining a step in a pipeline.
- [SequentialPipelineBlocks](./sequential_pipeline_blocks): Chain blocks to run in sequence.
- [AutoPipelineBlocks](./auto_pipeline_blocks): Create pipelines that support multiple workflows.
- [States](./modular_diffusers_states): How data is shared between blocks.
</hfoption>
<hfoption id="Build custom blocks">
Learn how to create your own blocks with custom logic in the [Building Custom Blocks](./custom_blocks) guide.

View File

@@ -91,42 +91,23 @@ class ImageEncoderBlock(ModularPipelineBlocks):
</hfoption>
</hfoptions>
Connect the two blocks by defining a [`~modular_pipelines.SequentialPipelineBlocks`]. List the block instances in `block_classes` and their corresponding names in `block_names`. The blocks are executed in the order they appear in `block_classes`, and data flows from one block to the next through [`~modular_pipelines.PipelineState`].
Connect the two blocks by defining an [`InsertableDict`] to map the block names to the block instances. Blocks are executed in the order they're registered in `blocks_dict`.
Use [`~modular_pipelines.SequentialPipelineBlocks.from_blocks_dict`] to create a [`~modular_pipelines.SequentialPipelineBlocks`].
```py
class ImageProcessingStep(SequentialPipelineBlocks):
"""
# auto_docstring
"""
model_name = "my_model"
block_classes = [InputBlock(), ImageEncoderBlock()]
block_names = ["input", "image_encoder"]
from diffusers.modular_pipelines import SequentialPipelineBlocks, InsertableDict
@property
def description(self):
return (
"Process text prompts and images for the pipeline. It:\n"
" - Determines the batch size from the prompts.\n"
" - Encodes the image into latent space."
)
blocks_dict = InsertableDict()
blocks_dict["input"] = input_block
blocks_dict["image_encoder"] = image_encoder_block
blocks = SequentialPipelineBlocks.from_blocks_dict(blocks_dict)
```
When you create a [`~modular_pipelines.SequentialPipelineBlocks`], properties like `inputs`, `intermediate_outputs`, and `expected_components` are automatically aggregated from the sub-blocks, so there is no need to define them again.
There are a few properties you should set:
- `description`: We recommend adding a description for the assembled block to explain what the combined step does.
- `model_name`: This is automatically derived from the sub-blocks but isn't always correct, so you may need to override it.
- `outputs`: By default this is the same as `intermediate_outputs`, but you can manually set it to control which values appear in the doc. This is useful for showing only the final outputs instead of all intermediate values.
These properties, together with the aggregated `inputs`, `intermediate_outputs`, and `expected_components`, are used to automatically generate the `doc` property.
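A minimal sketch of setting these properties, following the subclassing pattern with `block_classes` and `block_names` and reusing the `InputBlock` and `ImageEncoderBlock` defined above:
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks

class ImageProcessingStep(SequentialPipelineBlocks):
    # override if the value derived from the sub-blocks isn't correct
    model_name = "my_model"
    block_classes = [InputBlock(), ImageEncoderBlock()]
    block_names = ["input", "image_encoder"]

    @property
    def description(self):
        return "Process text prompts and images for the pipeline."
```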
Print the `ImageProcessingStep` block to inspect its sub-blocks, and use `doc` for a full summary of the block's inputs, outputs, and components.
Inspect the sub-blocks in [`~modular_pipelines.SequentialPipelineBlocks`] by calling `blocks`, and for more details about the inputs and outputs, access the `docs` attribute.
```py
blocks = ImageProcessingStep()
print(blocks)
print(blocks.doc)
```
```

View File

@@ -111,7 +111,7 @@ if __name__ == "__main__":
Call `torchrun` to run the inference script and use the `--nproc_per_node` argument to set the number of GPUs to use.
```bash
torchrun --nproc_per_node=2 run_distributed.py
torchrun --nproc_per_node=2 run_distributed.py
```
## device_map

View File

@@ -97,32 +97,5 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
> )
> ```
### Saving custom models
Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.
```py
# my_model.py
from diffusers import ModelMixin, ConfigMixin
class MyCustomModel(ModelMixin, ConfigMixin):
...
MyCustomModel.register_for_auto_class("AutoModel")
model = MyCustomModel(...)
model.save_pretrained("./my_model")
```
The saved `config.json` will include the `auto_map` field.
```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

View File

@@ -60,7 +60,7 @@ export_to_video(video.frames[0], "output.mp4", fps=8)
<tr>
<th style="text-align: center;">Face Image</th>
<th style="text-align: center;">Video</th>
<th style="text-align: center;">Description</th>
<th style="text-align: center;">Description</th>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_0.png?download=true" style="height: auto; width: 600px;"></td>

View File

@@ -1,133 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Helios
[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are:
- Without commonly used anti-drifting strategies (e.g., self-forcing, error-banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence.
- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU.
- Helios introduces optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models.
This guide will walk you through common Helios use cases.
## Load Model Checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
```python
import torch
from diffusers import HeliosPipeline, HeliosPyramidPipeline
from huggingface_hub import snapshot_download
# For Best Quality
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Intermediate Weight
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# For Best Efficiency
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```
## Text-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Image-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Image</th>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Interactive-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Resources
Learn more about Helios with the following resources.
- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.
- Read the research paper, [Helios: Real-Time Long Video Generation Model](https://huggingface.co/papers/), for more details.

View File

@@ -132,8 +132,6 @@
sections:
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/helios
title: Helios
- title: Resources
isExpanded: false

View File

@@ -26,14 +26,6 @@ http://www.apache.org/licenses/LICENSE-2.0
<th>Project Name</th>
<th>Description</th>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/PKU-YuanGroup/Helios"> helios </a></td>
<td>Helios, a 14B real-time long video generation model with lower overhead, higher speed, and stronger quality than 1.3B models</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/PKU-YuanGroup/ConsisID"> consisid </a></td>
<td>ConsisID, a zero-shot identity-preserving text-to-video generation model</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
<td>Stable Diffusion built into Blender</td>

View File

@@ -86,7 +86,10 @@ t2i_pipeline.guider
## Changing guider parameters
The guider parameters can be adjusted with the [`~ComponentSpec.create`] method and the [`~ModularPipeline.update_components`] method. The example below changes the `guidance_scale` value.
The guider parameters can be adjusted with either the [`~ComponentSpec.create`] method or the [`~ModularPipeline.update_components`] method. The example below changes the `guidance_scale` value.
<hfoptions id="switch">
<hfoption id="create">
```py
guider_spec = t2i_pipeline.get_component_spec("guider")
@@ -94,6 +97,18 @@ guider = guider_spec.create(guidance_scale=10)
t2i_pipeline.update_components(guider=guider)
```
</hfoption>
<hfoption id="update_components">
```py
guider_spec = t2i_pipeline.get_component_spec("guider")
guider_spec.config["guidance_scale"] = 10
t2i_pipeline.update_components(guider=guider_spec)
```
</hfoption>
</hfoptions>
## Uploading custom guiders
Call the [`~utils.PushToHubMixin.push_to_hub`] method on a custom guider to share it to the Hub.

View File

@@ -1,134 +0,0 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Helios
[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU. It supports minute-scale video generation while matching the quality of strong baseline models, and natively integrates text-to-video (T2V), image-to-video (I2V), and video-to-video (V2V) tasks within a unified architecture. The main features of Helios include:
- Without commonly used anti-drifting strategies (e.g., self-forcing, error-banks, keyframe sampling, or inverted sampling), the model generates high-quality and highly coherent minute-scale videos.
- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), this 14B video generation model reaches 19.5 FPS in end-to-end inference on a single H100 GPU.
- Introduces optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models.
This guide will walk you through using Helios in different scenarios.
## Load Model Checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case you should use the [`~DiffusionPipeline.from_pretrained`] method.
```python
import torch
from diffusers import HeliosPipeline, HeliosPyramidPipeline
from huggingface_hub import snapshot_download
# For Best Quality
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Intermediate Weight
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# For Best Efficiency
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```
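Once a pipeline is loaded, generation is expected to follow the usual diffusers pattern. The sketch below is an assumption rather than part of this guide: the exact call arguments of `HeliosPipeline` are not documented here, and only `export_to_video` is the standard diffusers helper.
```python
from diffusers.utils import export_to_video

# Hypothetical usage sketch: the `prompt` argument and `.frames[0]` output layout are
# assumptions based on other diffusers video pipelines, not confirmed for HeliosPipeline.
prompt = "A documentary-style shot of a crab scurrying into its burrow on a sandy beach."
output = pipe(prompt=prompt)
export_to_video(output.frames[0], "helios_t2v.mp4", fps=24)
```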
## Text-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Image-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Image</th>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Interactive-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Resources
Learn more about Helios with the following resources.
- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.
- Read the research paper, [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/), for more details.

View File

@@ -1232,49 +1232,22 @@ def main(args):
id_token=args.id_token,
)
def encode_video(video):
def encode_video(video, bar):
bar.update(1)
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist
# Distribute video encoding across processes: each process only encodes its own shard
num_videos = len(train_dataset.instance_videos)
num_procs = accelerator.num_processes
local_rank = accelerator.process_index
local_count = len(range(local_rank, num_videos, num_procs))
progress_encode_bar = tqdm(
range(local_count),
desc="Encoding videos",
disable=not accelerator.is_local_main_process,
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
)
encoded_videos = [None] * num_videos
for i, video in enumerate(train_dataset.instance_videos):
if i % num_procs == local_rank:
encoded_videos[i] = encode_video(video)
progress_encode_bar.update(1)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
progress_encode_bar.close()
# Broadcast encoded latent distributions so every process has the full set
if num_procs > 1:
import torch.distributed as dist
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
ref_params = next(v for v in encoded_videos if v is not None).parameters
for i in range(num_videos):
src = i % num_procs
if encoded_videos[i] is not None:
params = encoded_videos[i].parameters.contiguous()
else:
params = torch.empty_like(ref_params)
dist.broadcast(params, src=src)
encoded_videos[i] = DiagonalGaussianDistribution(params)
train_dataset.instance_videos = encoded_videos
def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
prompts = [example["instance_prompt"] for example in examples]

View File

@@ -17,9 +17,6 @@ import logging
import os
import sys
import tempfile
import unittest
from diffusers.utils import is_transformers_version
sys.path.append("..")
@@ -33,7 +30,6 @@ stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@unittest.skipIf(is_transformers_version(">=", "4.57.5"), "Size mismatch")
class CustomDiffusion(ExamplesTestsAccelerate):
def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:

View File

@@ -1,66 +0,0 @@
# Training AutoencoderRAE
This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen.
It follows the same high-level training recipe as the official RAE stage-1 setup (a minimal loss sketch follows the list):
- frozen encoder
- train decoder
- pixel reconstruction loss
- optional encoder feature consistency loss
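A minimal sketch of that objective, simplified from `compute_losses` in the accompanying `train_autoencoder_rae.py` (`model`, `pixel_values`, `use_encoder_loss`, and `encoder_loss_weight` are assumed to be provided by the training loop):
```python
import torch
import torch.nn.functional as F

# Simplified from compute_losses() in train_autoencoder_rae.py.
decoded = model(pixel_values).sample                                     # trainable decoder output
reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())   # pixel reconstruction

encoder_loss = torch.zeros_like(reconstruction_loss)
if use_encoder_loss and encoder_loss_weight > 0:
    # Optional consistency term: the frozen encoder should produce similar features
    # for the reconstruction and for the original image.
    with torch.no_grad():
        target_tokens = model._encoder_forward_fn(
            images=model._resize_and_normalize(pixel_values), model=model.encoder
        )
    reconstructed_tokens = model._encoder_forward_fn(
        images=model._resize_and_normalize(decoded), model=model.encoder
    )
    encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())

loss = reconstruction_loss + encoder_loss_weight * encoder_loss
```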
## Quickstart
### Resume or finetune from pretrained weights
```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```
### Train from scratch with a pretrained encoder
The following command launches RAE training with "facebook/dinov2-with-registers-base" as the base.
```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--encoder_type dinov2 \
--encoder_name_or_path facebook/dinov2-with-registers-base \
--encoder_input_size 224 \
--patch_size 16 \
--image_size 256 \
--decoder_hidden_size 1152 \
--decoder_num_hidden_layers 28 \
--decoder_num_attention_heads 16 \
--decoder_intermediate_size 4096 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```
Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`.
Dataset format is expected to be `ImageFolder`-compatible:
```text
train_data_dir/
class_a/
img_0001.jpg
class_b/
img_0002.jpg
```

View File

@@ -1,405 +0,0 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
from diffusers import AutoencoderRAE
from diffusers.optimization import get_scheduler
logger = get_logger(__name__)
def parse_args():
parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.")
parser.add_argument(
"--train_data_dir",
type=str,
required=True,
help="Path to an ImageFolder-style dataset root.",
)
parser.add_argument(
"--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model."
)
parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--resolution", type=int, default=256)
parser.add_argument("--center_crop", action="store_true")
parser.add_argument("--random_flip", action="store_true")
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--dataloader_num_workers", type=int, default=4)
parser.add_argument("--num_train_epochs", type=int, default=10)
parser.add_argument("--max_train_steps", type=int, default=None)
parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
parser.add_argument("--max_grad_norm", type=float, default=1.0)
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--adam_beta1", type=float, default=0.9)
parser.add_argument("--adam_beta2", type=float, default=0.999)
parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
parser.add_argument("--adam_epsilon", type=float, default=1e-8)
parser.add_argument("--lr_scheduler", type=str, default="cosine")
parser.add_argument("--lr_warmup_steps", type=int, default=500)
parser.add_argument("--checkpointing_steps", type=int, default=1000)
parser.add_argument("--validation_steps", type=int, default=500)
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
help="Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.",
)
parser.add_argument(
"--encoder_name_or_path",
type=str,
default=None,
help=(
"HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). "
"When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path "
"into a freshly constructed AutoencoderRAE. Ignored when --pretrained_model_name_or_path is set."
),
)
parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2")
parser.add_argument("--encoder_hidden_size", type=int, default=768)
parser.add_argument("--encoder_patch_size", type=int, default=14)
parser.add_argument("--encoder_num_hidden_layers", type=int, default=12)
parser.add_argument("--encoder_input_size", type=int, default=224)
parser.add_argument("--patch_size", type=int, default=16)
parser.add_argument("--image_size", type=int, default=256)
parser.add_argument("--num_channels", type=int, default=3)
parser.add_argument("--decoder_hidden_size", type=int, default=1152)
parser.add_argument("--decoder_num_hidden_layers", type=int, default=28)
parser.add_argument("--decoder_num_attention_heads", type=int, default=16)
parser.add_argument("--decoder_intermediate_size", type=int, default=4096)
parser.add_argument("--noise_tau", type=float, default=0.0)
parser.add_argument("--scaling_factor", type=float, default=1.0)
parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument(
"--reconstruction_loss_type",
type=str,
choices=["l1", "mse"],
default="l1",
help="Pixel reconstruction loss.",
)
parser.add_argument(
"--encoder_loss_weight",
type=float,
default=0.0,
help="Weight for encoder feature consistency loss in the training loop.",
)
parser.add_argument(
"--use_encoder_loss",
action="store_true",
help="Enable encoder feature consistency loss term in the training loop.",
)
parser.add_argument("--report_to", type=str, default="tensorboard")
return parser.parse_args()
def build_transforms(args):
image_transforms = [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
]
if args.random_flip:
image_transforms.append(transforms.RandomHorizontalFlip())
image_transforms.append(transforms.ToTensor())
return transforms.Compose(image_transforms)
def compute_losses(
model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
):
decoded = model(pixel_values).sample
if decoded.shape[-2:] != pixel_values.shape[-2:]:
raise ValueError(
"Training requires matching reconstruction and target sizes, got "
f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}."
)
if reconstruction_loss_type == "l1":
reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())
else:
reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float())
encoder_loss = torch.zeros_like(reconstruction_loss)
if use_encoder_loss and encoder_loss_weight > 0:
base_model = model.module if hasattr(model, "module") else model
target_encoder_input = base_model._resize_and_normalize(pixel_values)
reconstructed_encoder_input = base_model._resize_and_normalize(decoded)
encoder_forward_kwargs = {"model": base_model.encoder}
if base_model.config.encoder_type == "mae":
encoder_forward_kwargs["patch_size"] = base_model.config.encoder_patch_size
with torch.no_grad():
target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs)
reconstructed_tokens = base_model._encoder_forward_fn(
images=reconstructed_encoder_input, **encoder_forward_kwargs
)
encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())
loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss
return decoded, loss, reconstruction_loss, encoder_loss
def _strip_final_layernorm_affine(state_dict, prefix=""):
"""Remove final layernorm weight/bias so the model keeps its default init (identity)."""
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
def _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path):
"""Load pretrained HF transformers encoder weights into the model's encoder."""
if encoder_type == "dinov2":
from transformers import Dinov2WithRegistersModel
hf_encoder = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
state_dict = hf_encoder.state_dict()
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
elif encoder_type == "siglip2":
from transformers import SiglipModel
hf_encoder = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
state_dict = {f"vision_model.{k}": v for k, v in hf_encoder.state_dict().items()}
state_dict = _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.")
elif encoder_type == "mae":
from transformers import ViTMAEForPreTraining
hf_encoder = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
state_dict = hf_encoder.state_dict()
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
else:
raise ValueError(f"Unknown encoder_type: {encoder_type}")
model.encoder.load_state_dict(state_dict, strict=False)
def main():
args = parse_args()
if args.resolution != args.image_size:
raise ValueError(
f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) "
"for stage-1 reconstruction loss."
)
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
project_config=accelerator_project_config,
log_with=args.report_to,
)
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if args.seed is not None:
set_seed(args.seed)
if accelerator.is_main_process:
os.makedirs(args.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args))
def collate_fn(examples):
pixel_values = torch.stack([example[0] for example in examples]).float()
return {"pixel_values": pixel_values}
train_dataloader = DataLoader(
dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
pin_memory=True,
drop_last=True,
)
if args.pretrained_model_name_or_path is not None:
model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path)
logger.info(f"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}")
else:
model = AutoencoderRAE(
encoder_type=args.encoder_type,
encoder_hidden_size=args.encoder_hidden_size,
encoder_patch_size=args.encoder_patch_size,
encoder_num_hidden_layers=args.encoder_num_hidden_layers,
decoder_hidden_size=args.decoder_hidden_size,
decoder_num_hidden_layers=args.decoder_num_hidden_layers,
decoder_num_attention_heads=args.decoder_num_attention_heads,
decoder_intermediate_size=args.decoder_intermediate_size,
patch_size=args.patch_size,
encoder_input_size=args.encoder_input_size,
image_size=args.image_size,
num_channels=args.num_channels,
noise_tau=args.noise_tau,
reshape_to_2d=args.reshape_to_2d,
use_encoder_loss=args.use_encoder_loss,
scaling_factor=args.scaling_factor,
)
if args.encoder_name_or_path is not None:
_load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path)
logger.info(f"Loaded pretrained encoder weights from {args.encoder_name_or_path}")
model.encoder.requires_grad_(False)
model.decoder.requires_grad_(True)
model.train()
optimizer = torch.optim.AdamW(
(p for p in model.parameters() if p.requires_grad),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
if overrode_max_train_steps:
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
if accelerator.is_main_process:
accelerator.init_trackers("train_autoencoder_rae", config=vars(args))
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(dataset)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
global_step = 0
for epoch in range(args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(model):
pixel_values = batch["pixel_values"]
_, loss, reconstruction_loss, encoder_loss = compute_losses(
model,
pixel_values,
reconstruction_loss_type=args.reconstruction_loss_type,
use_encoder_loss=args.use_encoder_loss,
encoder_loss_weight=args.encoder_loss_weight,
)
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
logs = {
"loss": loss.detach().item(),
"reconstruction_loss": reconstruction_loss.detach().item(),
"encoder_loss": encoder_loss.detach().item(),
"lr": lr_scheduler.get_last_lr()[0],
}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step % args.validation_steps == 0:
with torch.no_grad():
_, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses(
model,
pixel_values,
reconstruction_loss_type=args.reconstruction_loss_type,
use_encoder_loss=args.use_encoder_loss,
encoder_loss_weight=args.encoder_loss_weight,
)
accelerator.log(
{
"val/loss": val_loss.detach().item(),
"val/reconstruction_loss": val_reconstruction_loss.detach().item(),
"val/encoder_loss": val_encoder_loss.detach().item(),
},
step=global_step,
)
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(save_path)
logger.info(f"Saved checkpoint to {save_path}")
if global_step >= args.max_train_steps:
break
if global_step >= args.max_train_steps:
break
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(args.output_dir)
accelerator.end_training()
if __name__ == "__main__":
main()

View File

@@ -94,15 +94,9 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth/pipeline \
--output_path converted/transfer/2b/general/depth \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/depth/models
# edge
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
@@ -126,15 +120,9 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur/pipeline \
--output_path converted/transfer/2b/general/blur \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/blur/models
# seg
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
@@ -142,14 +130,8 @@ python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg/pipeline \
--output_path converted/transfer/2b/general/seg \
--save_pipeline
python scripts/convert_cosmos_to_diffusers.py \
--transformer_type Cosmos-2.5-Transfer-General-2B \
--transformer_ckpt_path $transformer_ckpt_path \
--vae_type wan2.1 \
--output_path converted/transfer/2b/general/seg/models
```
"""

View File

@@ -1,403 +0,0 @@
import argparse
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi, hf_hub_download
from diffusers import AutoencoderRAE
DECODER_CONFIGS = {
"ViTB": {
"decoder_hidden_size": 768,
"decoder_intermediate_size": 3072,
"decoder_num_attention_heads": 12,
"decoder_num_hidden_layers": 12,
},
"ViTL": {
"decoder_hidden_size": 1024,
"decoder_intermediate_size": 4096,
"decoder_num_attention_heads": 16,
"decoder_num_hidden_layers": 24,
},
"ViTXL": {
"decoder_hidden_size": 1152,
"decoder_intermediate_size": 4096,
"decoder_num_attention_heads": 16,
"decoder_num_hidden_layers": 28,
},
}
ENCODER_DEFAULT_NAME_OR_PATH = {
"dinov2": "facebook/dinov2-with-registers-base",
"siglip2": "google/siglip2-base-patch16-256",
"mae": "facebook/vit-mae-base",
}
ENCODER_HIDDEN_SIZE = {
"dinov2": 768,
"siglip2": 768,
"mae": 768,
}
ENCODER_PATCH_SIZE = {
"dinov2": 14,
"siglip2": 16,
"mae": 16,
}
DEFAULT_DECODER_SUBDIR = {
"dinov2": "decoders/dinov2/wReg_base",
"mae": "decoders/mae/base_p16",
"siglip2": "decoders/siglip2/base_p16_i256",
}
DEFAULT_STATS_SUBDIR = {
"dinov2": "stats/dinov2/wReg_base",
"mae": "stats/mae/base_p16",
"siglip2": "stats/siglip2/base_p16_i256",
}
DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt")
STATS_FILE_CANDIDATES = ("stat.pt",)
def dataset_case_candidates(name: str) -> tuple[str, ...]:
return (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k")
class RepoAccessor:
def __init__(self, repo_or_path: str, cache_dir: str | None = None):
self.repo_or_path = repo_or_path
self.cache_dir = cache_dir
self.local_root: Path | None = None
self.repo_id: str | None = None
self.repo_files: set[str] | None = None
root = Path(repo_or_path)
if root.exists() and root.is_dir():
self.local_root = root
else:
self.repo_id = repo_or_path
self.repo_files = set(HfApi().list_repo_files(repo_or_path))
def exists(self, relative_path: str) -> bool:
relative_path = relative_path.replace("\\", "/")
if self.local_root is not None:
return (self.local_root / relative_path).is_file()
return relative_path in self.repo_files
def fetch(self, relative_path: str) -> Path:
relative_path = relative_path.replace("\\", "/")
if self.local_root is not None:
return self.local_root / relative_path
downloaded = hf_hub_download(repo_id=self.repo_id, filename=relative_path, cache_dir=self.cache_dir)
return Path(downloaded)
def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]:
state_dict = maybe_wrapped
for k in ("model", "module", "state_dict"):
if isinstance(state_dict, dict) and k in state_dict and isinstance(state_dict[k], dict):
state_dict = state_dict[k]
out = dict(state_dict)
if len(out) > 0 and all(key.startswith("module.") for key in out):
out = {key[len("module.") :]: value for key, value in out.items()}
if len(out) > 0 and all(key.startswith("decoder.") for key in out):
out = {key[len("decoder.") :]: value for key, value in out.items()}
return out
def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]:
"""
Map official RAE decoder attention key layout to diffusers Attention layout used by AutoencoderRAE decoder.
Example mappings:
- `...attention.attention.query.*` -> `...attention.to_q.*`
- `...attention.attention.key.*` -> `...attention.to_k.*`
- `...attention.attention.value.*` -> `...attention.to_v.*`
- `...attention.output.dense.*` -> `...attention.to_out.0.*`
"""
remapped: dict[str, Any] = {}
for key, value in state_dict.items():
new_key = key
new_key = new_key.replace(".attention.attention.query.", ".attention.to_q.")
new_key = new_key.replace(".attention.attention.key.", ".attention.to_k.")
new_key = new_key.replace(".attention.attention.value.", ".attention.to_v.")
new_key = new_key.replace(".attention.output.dense.", ".attention.to_out.0.")
remapped[new_key] = value
return remapped
def resolve_decoder_file(
accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None
) -> str:
if decoder_checkpoint is not None:
if accessor.exists(decoder_checkpoint):
return decoder_checkpoint
raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}")
base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}"
for name in DECODER_FILE_CANDIDATES:
candidate = f"{base}/{name}"
if accessor.exists(candidate):
return candidate
raise FileNotFoundError(
f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}"
)
def resolve_stats_file(
accessor: RepoAccessor,
encoder_type: str,
dataset_name: str,
stats_checkpoint: str | None,
) -> str | None:
if stats_checkpoint is not None:
if accessor.exists(stats_checkpoint):
return stats_checkpoint
raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}")
base = DEFAULT_STATS_SUBDIR[encoder_type]
for dataset in dataset_case_candidates(dataset_name):
for name in STATS_FILE_CANDIDATES:
candidate = f"{base}/{dataset}/{name}"
if accessor.exists(candidate):
return candidate
return None
def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]:
if not isinstance(stats_obj, dict):
return None, None
if "latents_mean" in stats_obj or "latents_std" in stats_obj:
return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None)
mean = stats_obj.get("mean", None)
var = stats_obj.get("var", None)
if mean is None and var is None:
return None, None
latents_std = None
if var is not None:
if isinstance(var, torch.Tensor):
latents_std = torch.sqrt(var + 1e-5)
else:
latents_std = torch.sqrt(torch.tensor(var) + 1e-5)
return mean, latents_std
def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]:
"""Remove final layernorm weight/bias from encoder state dict.
RAE uses non-affine layernorm (weight=1, bias=0 is the default identity).
Stripping these keys means the model keeps its default init values, which
is functionally equivalent to setting elementwise_affine=False.
"""
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]:
"""Download the HF encoder and extract the state dict for the inner model."""
if encoder_type == "dinov2":
from transformers import Dinov2WithRegistersModel
hf_model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
sd = hf_model.state_dict()
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
elif encoder_type == "siglip2":
from transformers import SiglipModel
# SiglipModel.vision_model is a SiglipVisionTransformer.
# Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it
# under .vision_model, so we add the prefix to match the diffusers key layout.
hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
sd = {f"vision_model.{k}": v for k, v in hf_model.state_dict().items()}
return _strip_final_layernorm_affine(sd, prefix="vision_model.post_layernorm.")
elif encoder_type == "mae":
from transformers import ViTMAEForPreTraining
hf_model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
sd = hf_model.state_dict()
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
else:
raise ValueError(f"Unknown encoder_type: {encoder_type}")
def convert(args: argparse.Namespace) -> None:
accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir)
encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type]
decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint)
stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint)
print(f"Using decoder checkpoint: {decoder_relpath}")
if stats_relpath is not None:
print(f"Using stats checkpoint: {stats_relpath}")
else:
print("No stats checkpoint found; conversion will proceed without latent stats.")
if args.dry_run:
return
decoder_path = accessor.fetch(decoder_relpath)
decoder_obj = torch.load(decoder_path, map_location="cpu")
decoder_state_dict = unwrap_state_dict(decoder_obj)
decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict)
latents_mean, latents_std = None, None
if stats_relpath is not None:
stats_path = accessor.fetch(stats_relpath)
stats_obj = torch.load(stats_path, map_location="cpu")
latents_mean, latents_std = extract_latent_stats(stats_obj)
decoder_cfg = DECODER_CONFIGS[args.decoder_config_name]
# Read encoder normalization stats from the HF image processor (only place that downloads encoder info)
from transformers import AutoConfig, AutoImageProcessor
proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
encoder_norm_mean = list(proc.image_mean)
encoder_norm_std = list(proc.image_std)
# Read encoder hidden size and patch size from HF config
encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type]
encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type]
try:
hf_config = AutoConfig.from_pretrained(encoder_name_or_path)
# For models like SigLIP that nest vision config
if hasattr(hf_config, "vision_config"):
hf_config = hf_config.vision_config
encoder_hidden_size = hf_config.hidden_size
encoder_patch_size = hf_config.patch_size
except Exception:
pass
# Load the actual encoder weights from HF to include in the saved model
encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path)
# Build model on meta device to avoid double init overhead
with torch.device("meta"):
model = AutoencoderRAE(
encoder_type=args.encoder_type,
encoder_hidden_size=encoder_hidden_size,
encoder_patch_size=encoder_patch_size,
encoder_input_size=args.encoder_input_size,
patch_size=args.patch_size,
image_size=args.image_size,
num_channels=args.num_channels,
encoder_norm_mean=encoder_norm_mean,
encoder_norm_std=encoder_norm_std,
decoder_hidden_size=decoder_cfg["decoder_hidden_size"],
decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"],
decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"],
decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"],
latents_mean=latents_mean,
latents_std=latents_std,
scaling_factor=args.scaling_factor,
)
# Assemble full state dict and load with assign=True
full_state_dict = {}
# Encoder weights (prefixed with "encoder.")
for k, v in encoder_state_dict.items():
full_state_dict[f"encoder.{k}"] = v
# Decoder weights (prefixed with "decoder.")
for k, v in decoder_state_dict.items():
full_state_dict[f"decoder.{k}"] = v
# Buffers from config
full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
if latents_mean is not None:
latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean)
full_state_dict["_latents_mean"] = latents_mean_t
else:
full_state_dict["_latents_mean"] = torch.zeros(1)
if latents_std is not None:
latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std)
full_state_dict["_latents_std"] = latents_std_t
else:
full_state_dict["_latents_std"] = torch.ones(1)
model.load_state_dict(full_state_dict, strict=False, assign=True)
# Verify no critical keys are missing
model_keys = {name for name, _ in model.named_parameters()}
model_keys |= {name for name, _ in model.named_buffers()}
loaded_keys = set(full_state_dict.keys())
missing = model_keys - loaded_keys
# trainable_cls_token and decoder_pos_embed are initialized, not loaded from original checkpoint
allowed_missing = {"decoder.trainable_cls_token", "decoder.decoder_pos_embed"}
if missing - allowed_missing:
print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")
output_path = Path(args.output_path)
output_path.mkdir(parents=True, exist_ok=True)
model.save_pretrained(output_path)
if args.verify_load:
print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...")
loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False)
if not isinstance(loaded_model, AutoencoderRAE):
raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.")
print("Verification passed.")
print(f"Saved converted AutoencoderRAE to: {output_path}")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format")
parser.add_argument(
"--repo_or_path", type=str, required=True, help="Hub repo id (e.g. nyu-visionx/RAE-collections) or local path"
)
parser.add_argument("--output_path", type=str, required=True, help="Directory to save converted model")
parser.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True)
parser.add_argument(
"--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override"
)
parser.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name")
parser.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name")
parser.add_argument(
"--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path"
)
parser.add_argument(
"--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path"
)
parser.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL")
parser.add_argument("--encoder_input_size", type=int, default=224)
parser.add_argument("--patch_size", type=int, default=16)
parser.add_argument("--image_size", type=int, default=None)
parser.add_argument("--num_channels", type=int, default=3)
parser.add_argument("--scaling_factor", type=float, default=1.0)
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files")
parser.add_argument(
"--verify_load",
action="store_true",
help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).",
)
return parser.parse_args()
def main() -> None:
args = parse_args()
convert(args)
if __name__ == "__main__":
main()

View File

@@ -101,7 +101,6 @@ _deps = [
"datasets",
"filelock",
"flax>=0.4.1",
"ftfy",
"hf-doc-builder>=0.3.0",
"httpx<1.0.0",
"huggingface-hub>=0.34.0,<2.0",
@@ -112,6 +111,7 @@ _deps = [
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
"k-diffusion==0.0.12",
"torchsde",
"note_seq",
"librosa",
@@ -222,14 +222,13 @@ extras["docs"] = deps_list("hf-doc-builder")
extras["training"] = deps_list("accelerate", "datasets", "protobuf", "tensorboard", "Jinja2", "peft", "timm")
extras["test"] = deps_list(
"compel",
"ftfy",
"GitPython",
"datasets",
"Jinja2",
"invisible-watermark",
"k-diffusion",
"librosa",
"parameterized",
"protobuf",
"pytest",
"pytest-timeout",
"pytest-xdist",
@@ -238,7 +237,6 @@ extras["test"] = deps_list(
"sentencepiece",
"scipy",
"tiktoken",
"torchsde",
"torchvision",
"transformers",
"phonemizer",

View File

@@ -10,6 +10,7 @@ from .utils import (
is_bitsandbytes_available,
is_flax_available,
is_gguf_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_nvidia_modelopt_available,
@@ -49,6 +50,8 @@ _import_structure = {
"is_flax_available",
"is_inflect_available",
"is_invisible_watermark_available",
"is_k_diffusion_available",
"is_k_diffusion_version",
"is_librosa_available",
"is_note_seq_available",
"is_onnx_available",
@@ -202,7 +205,6 @@ else:
"AutoencoderKLTemporalDecoder",
"AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderRAE",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
@@ -228,7 +230,6 @@ else:
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageTransformer2DModel",
"HeliosTransformer3DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -361,8 +362,6 @@ else:
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
"FlowMatchLCMScheduler",
"HeliosDMDScheduler",
"HeliosScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
@@ -519,8 +518,6 @@ else:
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImagePipeline",
"HeliosPipeline",
"HeliosPyramidPipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -734,6 +731,19 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["pipelines"].extend(["ConsisIDPipeline"])
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torch_and_transformers_and_k_diffusion_objects # noqa F403
_import_structure["utils.dummy_torch_and_transformers_and_k_diffusion_objects"] = [
name for name in dir(dummy_torch_and_transformers_and_k_diffusion_objects) if not name.startswith("_")
]
else:
_import_structure["pipelines"].extend(["StableDiffusionKDiffusionPipeline", "StableDiffusionXLKDiffusionPipeline"])
try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()
@@ -975,7 +985,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderRAE,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
@@ -1001,7 +1010,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HeliosTransformer3DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -1130,8 +1138,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
FlowMatchLCMScheduler,
HeliosDMDScheduler,
HeliosScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
@@ -1267,8 +1273,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPipeline,
FluxPriorReduxPipeline,
GlmImagePipeline,
HeliosPipeline,
HeliosPyramidPipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,
@@ -1465,6 +1469,14 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ZImagePipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torch_and_transformers_and_k_diffusion_objects import * # noqa F403
else:
from .pipelines import StableDiffusionKDiffusionPipeline, StableDiffusionXLKDiffusionPipeline
try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()

View File

@@ -89,6 +89,8 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:

View File

@@ -107,38 +107,6 @@ class ConfigMixin:
has_compatibles = False
_deprecated_kwargs = []
_auto_class = None
@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
"""
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
trust_remote_code=True)`.
When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
to this class's module and class name.
Args:
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
Currently only `"AutoModel"` is supported.
Example:
```python
from diffusers import ModelMixin, ConfigMixin
class MyCustomModel(ModelMixin, ConfigMixin): ...
MyCustomModel.register_for_auto_class("AutoModel")
```
"""
if auto_class != "AutoModel":
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")
cls._auto_class = auto_class
def register_to_config(self, **kwargs):
if self.config_name is None:
@@ -653,12 +621,6 @@ class ConfigMixin:
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = config_dict.pop("_pre_quantization_dtype", None)
if getattr(self, "_auto_class", None) is not None:
module = self.__class__.__module__.split(".")[-1]
auto_map = config_dict.get("auto_map", {})
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
config_dict["auto_map"] = auto_map
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: str | os.PathLike):

View File

@@ -8,7 +8,6 @@ deps = {
"datasets": "datasets",
"filelock": "filelock",
"flax": "flax>=0.4.1",
"ftfy": "ftfy",
"hf-doc-builder": "hf-doc-builder>=0.3.0",
"httpx": "httpx<1.0.0",
"huggingface-hub": "huggingface-hub>=0.34.0,<2.0",
@@ -19,6 +18,7 @@ deps = {
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
"k-diffusion": "k-diffusion==0.0.12",
"torchsde": "torchsde",
"note_seq": "note_seq",
"librosa": "librosa",

View File

@@ -54,7 +54,7 @@ class BaseGuidance(ConfigMixin, PushToHubMixin):
self._timestep: torch.LongTensor = None
self._count_prepared = 0
self._input_fields: dict[str, str | tuple[str, str]] = None
self._enabled = enabled
self._enabled = True
if not (0.0 <= start < 1.0):
raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")

View File

@@ -48,7 +48,6 @@ _GO_LC_SUPPORTED_PYTORCH_LAYERS = (
torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d,
torch.nn.Linear,
torch.nn.Embedding,
# TODO(aryan): look into torch.nn.LayerNorm, torch.nn.GroupNorm later, seems to be causing some issues with CogVideoX
# because of double invocation of the same norm layer in CogVideoXLayerNorm
)

View File

@@ -307,17 +307,6 @@ class GroupOffloadingHook(ModelHook):
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
else:
# onload_self=False means this group relies on prefetching from a previous group.
# However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios),
# the prefetch chain may not cover them if they were absent during the first forward pass
# when the execution order was traced. In that case, their weights remain on offload_device,
# so we fall back to a synchronous onload here.
params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters)
if params and params[0].device == self.group.offload_device:
self.group.onload_()
if self.group.stream is not None:
self.group.stream.synchronize()
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:

View File

@@ -78,7 +78,6 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"HeliosLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
@@ -119,7 +118,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4LoraLoaderMixin,
Flux2LoraLoaderMixin,
FluxLoraLoaderMixin,
HeliosLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,

View File

@@ -856,7 +856,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
)
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("text_encoders.t5xxl.transformer.")}
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_", "lora_te1_")) for k in state_dict)
has_diffb = any("diff_b" in k and k.startswith(("lora_unet_", "lora_te_")) for k in state_dict)
if has_diffb:
zero_status_diff_b = state_dict_all_zero(state_dict, ".diff_b")
if zero_status_diff_b:
@@ -895,7 +895,7 @@ def _convert_kohya_flux_lora_to_diffusers(state_dict):
state_dict = {
_custom_replace(k, limit_substrings): v
for k, v in state_dict.items()
if k.startswith(("lora_unet_", "lora_te_", "lora_te1_"))
if k.startswith(("lora_unet_", "lora_te_"))
}
if any("text_projection" in k for k in state_dict):
@@ -2519,13 +2519,6 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
# Normalize ZImage-specific dot-separated module names to underscore form so they
# match the diffusers model parameter names (context_refiner, noise_refiner).
state_dict = {
k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
for k, v in state_dict.items()
}
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
@@ -2536,18 +2529,19 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
if has_non_diffusers_lora_id:
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
@@ -2560,69 +2554,13 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
# Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
# Already in diffusers format (lora_A/lora_B), just pop
elif has_diffusers_lora_id:
for k in all_keys:
if k.endswith(a_key):
diffusers_up_key = k.replace(a_key, b_key)
alpha_key = k.replace(a_key, ".alpha")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(diffusers_up_key)
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[k] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
# Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
# Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
# lora weight names and also include redundant keys:
# - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
# - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
# - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
lora_dot_down_key = ".lora.down.weight"
lora_dot_up_key = ".lora.up.weight"
has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict)
if has_lora_dot_format:
dot_keys = list(state_dict.keys())
for k in dot_keys:
if lora_dot_down_key not in k:
continue
if k not in state_dict:
continue # already popped by a prior iteration
base = k[: -len(lora_dot_down_key)]
# Skip combined "qkv" projection — individual to.q/k/v keys are also present.
if base.endswith(".qkv"):
if a_key in k or b_key in k:
converted_state_dict[k] = state_dict.pop(k)
elif ".alpha" in k:
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
state_dict.pop(base + ".alpha", None)
continue
# Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
if re.search(r"\.out$", base) and ".to_out" not in base:
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
continue
# Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
norm_k = re.sub(
r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
r".to_\1" + lora_dot_down_key,
k,
)
norm_base = norm_k[: -len(lora_dot_down_key)]
alpha_key = norm_base + ".alpha"
diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight")
diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down] = down_weight * scale_down
converted_state_dict[diffusers_up] = up_weight * scale_up
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

View File

@@ -3440,207 +3440,6 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class HeliosLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
elif any(k.startswith("lora_unet_") for k in state_dict):
state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
adapter_name: str | None = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: dict | None = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: list[str] | None = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
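# Illustration only (hypothetical repo ids, not real checkpoints): pipelines that mix this
# loader in expose the entry points shown above in the usual way.
# pipe = HeliosPipeline.from_pretrained("org/helios-checkpoint", torch_dtype=torch.bfloat16)
# pipe.load_lora_weights("org/helios-lora", weight_name="pytorch_lora_weights.safetensors", adapter_name="style")
# pipe.fuse_lora(components=["transformer"], lora_scale=0.8)  # bake the adapter into the weights
# pipe.unfuse_lora(components=["transformer"])                # restore the original weights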
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].
@@ -5673,10 +5472,6 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
if is_peft_format:
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}
is_ai_toolkit = any(k.startswith("diffusion_model.") for k in state_dict)
if is_ai_toolkit:
state_dict = _convert_non_diffusers_flux2_lora_to_diffusers(state_dict)

View File

@@ -51,7 +51,6 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,

View File

@@ -22,12 +22,7 @@ from tokenizers import Tokenizer as TokenizerFast
from torch import nn
from ..models.modeling_utils import load_state_dict
from ..utils import (
_get_model_file,
is_accelerate_available,
is_transformers_available,
logging,
)
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging
if is_transformers_available():

View File

@@ -49,7 +49,6 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
@@ -101,7 +100,6 @@ if is_torch_available():
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_helios"] = ["HeliosTransformer3DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -169,7 +167,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderRAE,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,
@@ -215,7 +212,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HeliosTransformer3DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,

View File

@@ -62,8 +62,6 @@ _REQUIRED_FLEX_VERSION = "2.5.0"
_REQUIRED_XLA_VERSION = "2.2"
_REQUIRED_XFORMERS_VERSION = "0.0.29"
logger = get_logger(__name__) # pylint: disable=invalid-name
_CAN_USE_FLASH_ATTN = is_flash_attn_available() and is_flash_attn_version(">=", _REQUIRED_FLASH_VERSION)
_CAN_USE_FLASH_ATTN_3 = is_flash_attn_3_available()
_CAN_USE_AITER_ATTN = is_aiter_available() and is_aiter_version(">=", _REQUIRED_AITER_VERSION)
@@ -75,18 +73,8 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _
if _CAN_USE_FLASH_ATTN:
try:
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
except (ImportError, OSError, RuntimeError) as e:
# Handle ABI mismatch or other import failures gracefully.
# This can happen when flash_attn was compiled against a different PyTorch version.
logger.warning(f"flash_attn is installed but failed to import: {e}. Falling back to native PyTorch attention.")
_CAN_USE_FLASH_ATTN = False
flash_attn_func = None
flash_attn_varlen_func = None
_wrapped_flash_attn_backward = None
_wrapped_flash_attn_forward = None
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
flash_attn_func = None
flash_attn_varlen_func = None
@@ -95,47 +83,26 @@ else:
if _CAN_USE_FLASH_ATTN_3:
try:
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flash_attn_3 failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLASH_ATTN_3 = False
flash_attn_3_func = None
flash_attn_3_varlen_func = None
from flash_attn_interface import flash_attn_func as flash_attn_3_func
from flash_attn_interface import flash_attn_varlen_func as flash_attn_3_varlen_func
else:
flash_attn_3_func = None
flash_attn_3_varlen_func = None
if _CAN_USE_AITER_ATTN:
try:
from aiter import flash_attn_func as aiter_flash_attn_func
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"aiter failed to import: {e}. Falling back to native attention.")
_CAN_USE_AITER_ATTN = False
aiter_flash_attn_func = None
from aiter import flash_attn_func as aiter_flash_attn_func
else:
aiter_flash_attn_func = None
if _CAN_USE_SAGE_ATTN:
try:
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"sageattention failed to import: {e}. Falling back to native attention.")
_CAN_USE_SAGE_ATTN = False
sageattn = None
sageattn_qk_int8_pv_fp8_cuda = None
sageattn_qk_int8_pv_fp8_cuda_sm90 = None
sageattn_qk_int8_pv_fp16_cuda = None
sageattn_qk_int8_pv_fp16_triton = None
sageattn_varlen = None
from sageattention import (
sageattn,
sageattn_qk_int8_pv_fp8_cuda,
sageattn_qk_int8_pv_fp8_cuda_sm90,
sageattn_qk_int8_pv_fp16_cuda,
sageattn_qk_int8_pv_fp16_triton,
sageattn_varlen,
)
else:
sageattn = None
sageattn_qk_int8_pv_fp16_cuda = None
@@ -146,48 +113,26 @@ else:
if _CAN_USE_FLEX_ATTN:
try:
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"flex_attention failed to import: {e}. Falling back to native attention.")
_CAN_USE_FLEX_ATTN = False
flex_attention = None
else:
flex_attention = None
# We cannot import the flex_attention function from the package directly because it is expected (from the
# pytorch documentation) that the user may compile it. If we import directly, we will not have access to the
# compiled function.
import torch.nn.attention.flex_attention as flex_attention
if _CAN_USE_NPU_ATTN:
try:
from torch_npu import npu_fusion_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_npu failed to import: {e}. Falling back to native attention.")
_CAN_USE_NPU_ATTN = False
npu_fusion_attention = None
from torch_npu import npu_fusion_attention
else:
npu_fusion_attention = None
if _CAN_USE_XLA_ATTN:
try:
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"torch_xla failed to import: {e}. Falling back to native attention.")
_CAN_USE_XLA_ATTN = False
xla_flash_attention = None
from torch_xla.experimental.custom_kernel import flash_attention as xla_flash_attention
else:
xla_flash_attention = None
if _CAN_USE_XFORMERS_ATTN:
try:
import xformers.ops as xops
except (ImportError, OSError, RuntimeError) as e:
logger.warning(f"xformers failed to import: {e}. Falling back to native attention.")
_CAN_USE_XFORMERS_ATTN = False
xops = None
import xformers.ops as xops
else:
xops = None
@@ -213,6 +158,8 @@ else:
_register_fake = register_fake_no_op
logger = get_logger(__name__) # pylint: disable=invalid-name
# TODO(aryan): Add support for the following:
# - Sage Attention++
# - block sparse, radial and other attention methods
@@ -319,21 +266,13 @@ class _HubKernelConfig:
function_attr: str
revision: str | None = None
kernel_fn: Callable | None = None
wrapped_forward_attr: str | None = None
wrapped_backward_attr: str | None = None
wrapped_forward_fn: Callable | None = None
wrapped_backward_fn: Callable | None = None
# Registry for hub-based attention kernels
_HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# TODO: temporary revision for now. Remove when merged upstream into `main`.
AttentionBackendName._FLASH_3_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
function_attr="flash_attn_func",
revision="fake-ops-return-probs",
wrapped_forward_attr="flash_attn_interface._flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._flash_attn_backward",
repo_id="kernels-community/flash-attn3", function_attr="flash_attn_func", revision="fake-ops-return-probs"
),
AttentionBackendName._FLASH_3_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn3",
@@ -341,11 +280,7 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
# revision="fake-ops-return-probs",
),
AttentionBackendName.FLASH_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2",
function_attr="flash_attn_func",
revision=None,
wrapped_forward_attr="flash_attn_interface._wrapped_flash_attn_forward",
wrapped_backward_attr="flash_attn_interface._wrapped_flash_attn_backward",
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_func", revision=None
),
AttentionBackendName.FLASH_VARLEN_HUB: _HubKernelConfig(
repo_id="kernels-community/flash-attn2", function_attr="flash_attn_varlen_func", revision=None
@@ -670,39 +605,22 @@ def _flex_attention_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
# ===== Helpers for downloading kernels =====
def _resolve_kernel_attr(module, attr_path: str):
target = module
for attr in attr_path.split("."):
if not hasattr(target, attr):
raise AttributeError(f"Kernel module '{module.__name__}' does not define attribute path '{attr_path}'.")
target = getattr(target, attr)
return target
def _maybe_download_kernel_for_backend(backend: AttentionBackendName) -> None:
if backend not in _HUB_KERNELS_REGISTRY:
return
config = _HUB_KERNELS_REGISTRY[backend]
needs_kernel = config.kernel_fn is None
needs_wrapped_forward = config.wrapped_forward_attr is not None and config.wrapped_forward_fn is None
needs_wrapped_backward = config.wrapped_backward_attr is not None and config.wrapped_backward_fn is None
if not (needs_kernel or needs_wrapped_forward or needs_wrapped_backward):
if config.kernel_fn is not None:
return
try:
from kernels import get_kernel
kernel_module = get_kernel(config.repo_id, revision=config.revision)
if needs_kernel:
config.kernel_fn = _resolve_kernel_attr(kernel_module, config.function_attr)
kernel_func = getattr(kernel_module, config.function_attr)
if needs_wrapped_forward:
config.wrapped_forward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_forward_attr)
if needs_wrapped_backward:
config.wrapped_backward_fn = _resolve_kernel_attr(kernel_module, config.wrapped_backward_attr)
# Cache the downloaded kernel function in the config object
config.kernel_fn = kernel_func
except Exception as e:
logger.error(f"An error occurred while fetching kernel '{config.repo_id}' from the Hub: {e}")
@@ -733,7 +651,7 @@ def _wrapped_flash_attn_3(
) -> tuple[torch.Tensor, torch.Tensor]:
# Hardcoded for now because pytorch does not support tuple/int type hints
window_size = (-1, -1)
result = flash_attn_3_func(
out, lse, *_ = flash_attn_3_func(
q=q,
k=k,
v=v,
@@ -750,9 +668,7 @@ def _wrapped_flash_attn_3(
pack_gqa=pack_gqa,
deterministic=deterministic,
sm_margin=sm_margin,
return_attn_probs=True,
)
out, lse, *_ = result
lse = lse.permute(0, 2, 1)
return out, lse
@@ -1155,258 +1071,6 @@ def _flash_attention_backward_op(
return grad_query, grad_key, grad_value
def _flash_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_forward_fn is None or wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_forward` and `_wrapped_flash_attn_backward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
window_size = (-1, -1)
softcap = 0.0
alibi_slopes = None
deterministic = False
grad_enabled = any(x.requires_grad for x in (query, key, value))
if grad_enabled or (_parallel_config is not None and _parallel_config.context_parallel_config._world_size > 1):
dropout_p = dropout_p if dropout_p > 0 else 1e-30
with torch.set_grad_enabled(grad_enabled):
out, lse, S_dmask, rng_state = wrapped_forward_fn(
query,
key,
value,
dropout_p,
scale,
is_causal,
window_size[0],
window_size[1],
softcap,
alibi_slopes,
return_lse,
)
lse = lse.permute(0, 2, 1).contiguous()
if _save_ctx:
ctx.save_for_backward(query, key, value, out, lse, rng_state)
ctx.dropout_p = dropout_p
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.alibi_slopes = alibi_slopes
ctx.deterministic = deterministic
return (out, lse) if return_lse else out
def _flash_attention_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention hub kernels must expose `_wrapped_flash_attn_backward` for context parallel execution."
)
query, key, value, out, lse, rng_state = ctx.saved_tensors
grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)
_ = wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
lse,
grad_query,
grad_key,
grad_value,
ctx.dropout_p,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.alibi_slopes,
ctx.deterministic,
rng_state,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _flash_attention_3_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
*,
window_size: tuple[int, int] = (-1, -1),
softcap: float = 0.0,
num_splits: int = 1,
pack_gqa: bool | None = None,
deterministic: bool = False,
sm_margin: int = 0,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for flash-attn 3 hub kernels.")
if dropout_p != 0.0:
raise ValueError("`dropout_p` is not yet supported for flash-attn 3 hub kernels.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for flash-attn 3 hub kernels.")
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_forward_fn = config.wrapped_forward_fn
if wrapped_forward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_forward` "
"for context parallel execution."
)
if scale is None:
scale = query.shape[-1] ** (-0.5)
out, softmax_lse, *_ = wrapped_forward_fn(
query,
key,
value,
None,
None, # k_new, v_new
None, # qv
None, # out
None,
None,
None, # cu_seqlens_q/k/k_new
None,
None, # seqused_q/k
None,
None, # max_seqlen_q/k
None,
None,
None, # page_table, kv_batch_idx, leftpad_k
None,
None,
None, # rotary_cos/sin, seqlens_rotary
None,
None,
None, # q_descale, k_descale, v_descale
scale,
causal=is_causal,
window_size_left=window_size[0],
window_size_right=window_size[1],
attention_chunk=0,
softcap=softcap,
num_splits=num_splits,
pack_gqa=pack_gqa,
sm_margin=sm_margin,
)
lse = softmax_lse.permute(0, 2, 1).contiguous() if return_lse else None
if _save_ctx:
ctx.save_for_backward(query, key, value, out, softmax_lse)
ctx.scale = scale
ctx.is_causal = is_causal
ctx.window_size = window_size
ctx.softcap = softcap
ctx.deterministic = deterministic
ctx.sm_margin = sm_margin
return (out, lse) if return_lse else out
def _flash_attention_3_hub_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
*args,
**kwargs,
):
config = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB]
wrapped_backward_fn = config.wrapped_backward_fn
if wrapped_backward_fn is None:
raise RuntimeError(
"Flash attention 3 hub kernels must expose `flash_attn_interface._flash_attn_backward` "
"for context parallel execution."
)
query, key, value, out, softmax_lse = ctx.saved_tensors
grad_query = torch.empty_like(query)
grad_key = torch.empty_like(key)
grad_value = torch.empty_like(value)
wrapped_backward_fn(
grad_out,
query,
key,
value,
out,
softmax_lse,
None,
None, # cu_seqlens_q, cu_seqlens_k
None,
None, # seqused_q, seqused_k
None,
None, # max_seqlen_q, max_seqlen_k
grad_query,
grad_key,
grad_value,
ctx.scale,
ctx.is_causal,
ctx.window_size[0],
ctx.window_size[1],
ctx.softcap,
ctx.deterministic,
ctx.sm_margin,
)
grad_query = grad_query[..., : grad_out.shape[-1]]
grad_key = grad_key[..., : grad_out.shape[-1]]
grad_value = grad_value[..., : grad_out.shape[-1]]
return grad_query, grad_key, grad_value
def _sage_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1445,46 +1109,6 @@ def _sage_attention_forward_op(
return (out, lse) if return_lse else out
def _sage_attention_hub_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: torch.Tensor | None = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: float | None = None,
enable_gqa: bool = False,
return_lse: bool = False,
_save_ctx: bool = True,
_parallel_config: "ParallelConfig" | None = None,
):
if attn_mask is not None:
raise ValueError("`attn_mask` is not yet supported for Sage attention.")
if dropout_p > 0.0:
raise ValueError("`dropout_p` is not yet supported for Sage attention.")
if enable_gqa:
raise ValueError("`enable_gqa` is not yet supported for Sage attention.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.SAGE_HUB].kernel_fn
out = func(
q=query,
k=key,
v=value,
tensor_layout="NHD",
is_causal=is_causal,
sm_scale=scale,
return_lse=return_lse,
)
lse = None
if return_lse:
out, lse, *_ = out
lse = lse.permute(0, 2, 1).contiguous()
return (out, lse) if return_lse else out
def _sage_attention_backward_op(
ctx: torch.autograd.function.FunctionCtx,
grad_out: torch.Tensor,
@@ -1493,26 +1117,6 @@ def _sage_attention_backward_op(
raise NotImplementedError("Backward pass is not implemented for Sage attention.")
def _maybe_modify_attn_mask_npu(query: torch.Tensor, key: torch.Tensor, attn_mask: torch.Tensor | None = None):
# Skip Attention Mask if all values are 1, `None` mask can speedup the computation
if attn_mask is not None and torch.all(attn_mask != 0):
attn_mask = None
# Reshape Attention Mask: [batch_size, seq_len_k] -> [batch_size, 1, seq_len_q, seq_len_k]
# https://www.hiascend.com/document/detail/zh/Pytorch/730/apiref/torchnpuCustomsapi/docs/context/torch_npu-npu_fusion_attention.md
if (
attn_mask is not None
and attn_mask.ndim == 2
and attn_mask.shape[0] == query.shape[0]
and attn_mask.shape[1] == key.shape[1]
):
B, Sq, Skv = attn_mask.shape[0], query.shape[1], key.shape[1]
attn_mask = ~attn_mask.to(torch.bool)
attn_mask = attn_mask.unsqueeze(1).expand(B, Sq, Skv).unsqueeze(1).contiguous()
return attn_mask
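# Shape illustration (hypothetical sizes) for the transform above: a (B, Skv) padding mask
# with 1 for valid tokens becomes an inverted boolean mask of shape (B, 1, Sq, Skv).
_B, _Sq, _Skv = 2, 16, 16
_mask = torch.ones(_B, _Skv, dtype=torch.bool)
_mask[:, -4:] = False  # last four key positions are padding
_npu_mask = (~_mask).unsqueeze(1).expand(_B, _Sq, _Skv).unsqueeze(1).contiguous()  # (B, 1, Sq, Skv)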
def _npu_attention_forward_op(
ctx: torch.autograd.function.FunctionCtx,
query: torch.Tensor,
@@ -1530,14 +1134,11 @@ def _npu_attention_forward_op(
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -2341,7 +1942,7 @@ def _flash_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.FLASH_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _flash_attention_hub(
query: torch.Tensor,
@@ -2359,35 +1960,17 @@ def _flash_attention_hub(
raise ValueError("`attn_mask` is not supported for flash-attn 2.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_HUB].kernel_fn
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
dropout_p,
is_causal,
scale,
False,
return_lse,
forward_op=_flash_attention_hub_forward_op,
backward_op=_flash_attention_hub_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
out = func(
q=query,
k=key,
v=value,
dropout_p=dropout_p,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if return_lse:
out, lse, *_ = out
return (out, lse) if return_lse else out
@@ -2534,7 +2117,7 @@ def _flash_attention_3(
@_AttentionBackendRegistry.register(
AttentionBackendName._FLASH_3_HUB,
constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _flash_attention_3_hub(
query: torch.Tensor,
@@ -2549,68 +2132,33 @@ def _flash_attention_3_hub(
return_attn_probs: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if _parallel_config:
raise NotImplementedError(f"{AttentionBackendName._FLASH_3_HUB.value} is not implemented for parallelism yet.")
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for flash-attn 3.")
func = _HUB_KERNELS_REGISTRY[AttentionBackendName._FLASH_3_HUB].kernel_fn
if _parallel_config is None:
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
return (out[0], out[1]) if return_attn_probs else out
forward_op = functools.partial(
_flash_attention_3_hub_forward_op,
out = func(
q=query,
k=key,
v=value,
softmax_scale=scale,
causal=is_causal,
qv=None,
q_descale=None,
k_descale=None,
v_descale=None,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
return_attn_probs=return_attn_probs,
)
backward_op = functools.partial(
_flash_attention_3_hub_backward_op,
window_size=window_size,
softcap=softcap,
num_splits=1,
pack_gqa=None,
deterministic=deterministic,
sm_margin=0,
)
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_attn_probs,
forward_op=forward_op,
backward_op=backward_op,
_parallel_config=_parallel_config,
)
if return_attn_probs:
out, lse = out
return out, lse
return out
# When `return_attn_probs` is True, the above returns a tuple of
# actual outputs and lse.
return (out[0], out[1]) if return_attn_probs else out
@_AttentionBackendRegistry.register(
@@ -2703,7 +2251,7 @@ def _flash_varlen_attention_3(
key_packed = torch.cat(key_valid, dim=0)
value_packed = torch.cat(value_valid, dim=0)
result = flash_attn_3_varlen_func(
out, lse, *_ = flash_attn_3_varlen_func(
q=query_packed,
k=key_packed,
v=value_packed,
@@ -2713,13 +2261,7 @@ def _flash_varlen_attention_3(
max_seqlen_k=max_seqlen_k,
softmax_scale=scale,
causal=is_causal,
return_attn_probs=return_lse,
)
if isinstance(result, tuple):
out, lse, *_ = result
else:
out = result
lse = None
out = out.unflatten(0, (batch_size, -1))
return (out, lse) if return_lse else out
@@ -3126,17 +2668,16 @@ def _native_npu_attention(
return_lse: bool = False,
_parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
if attn_mask is not None:
raise ValueError("`attn_mask` is not supported for NPU attention")
if return_lse:
raise ValueError("NPU attention backend does not support setting `return_lse=True`.")
if _parallel_config is None:
attn_mask = _maybe_modify_attn_mask_npu(query, key, attn_mask)
out = npu_fusion_attention(
query,
key,
value,
query.size(2), # num_heads
atten_mask=attn_mask,
input_layout="BSND",
pse=None,
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
@@ -3151,7 +2692,7 @@ def _native_npu_attention(
query,
key,
value,
attn_mask,
None,
dropout_p,
None,
scale,
@@ -3248,7 +2789,7 @@ def _sage_attention(
@_AttentionBackendRegistry.register(
AttentionBackendName.SAGE_HUB,
constraints=[_check_device_cuda, _check_qkv_dtype_bf16_or_fp16, _check_shape],
supports_context_parallel=True,
supports_context_parallel=False,
)
def _sage_attention_hub(
query: torch.Tensor,
@@ -3276,23 +2817,6 @@ def _sage_attention_hub(
)
if return_lse:
out, lse, *_ = out
else:
out = _templated_context_parallel_attention(
query,
key,
value,
None,
0.0,
is_causal,
scale,
False,
return_lse,
forward_op=_sage_attention_hub_forward_op,
backward_op=_sage_attention_backward_op,
_parallel_config=_parallel_config,
)
if return_lse:
out, lse = out
return (out, lse) if return_lse else out

View File

@@ -30,126 +30,10 @@ class AutoModel(ConfigMixin):
def __init__(self, *args, **kwargs):
raise EnvironmentError(
f"{self.__class__.__name__} is designed to be instantiated "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)`, "
f"`{self.__class__.__name__}.from_config(config)`, or "
f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
)
@classmethod
def from_config(cls, pretrained_model_name_or_path_or_dict: str | os.PathLike | dict | None = None, **kwargs):
r"""
Instantiate a model from a config dictionary or a pretrained model configuration file with random weights (no
pretrained weights are loaded).
Parameters:
pretrained_model_name_or_path_or_dict (`str`, `os.PathLike`, or `dict`):
Can be either:
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model
configuration hosted on the Hub.
- A path to a *directory* (for example `./my_model_directory`) containing a model configuration
file.
- A config dictionary.
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
is not used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model configuration, overriding the cached version if
it exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether to only load local model configuration files or not.
token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files.
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use.
trust_remote_code (`bool`, *optional*, defaults to `False`):
Whether to trust remote code.
subfolder (`str`, *optional*, defaults to `""`):
The subfolder location of a model file within a larger model repository on the Hub or locally.
Returns:
A model object instantiated from the config with random weights.
Example:
```py
from diffusers import AutoModel
model = AutoModel.from_config("stable-diffusion-v1-5/stable-diffusion-v1-5", subfolder="unet")
```
"""
subfolder = kwargs.pop("subfolder", None)
trust_remote_code = kwargs.pop("trust_remote_code", False)
hub_kwargs_names = [
"cache_dir",
"force_download",
"local_files_only",
"proxies",
"revision",
"token",
]
hub_kwargs = {name: kwargs.pop(name, None) for name in hub_kwargs_names}
if pretrained_model_name_or_path_or_dict is None:
raise ValueError(
"Please provide a `pretrained_model_name_or_path_or_dict` as the first positional argument."
)
if isinstance(pretrained_model_name_or_path_or_dict, (str, os.PathLike)):
pretrained_model_name_or_path = pretrained_model_name_or_path_or_dict
config = cls.load_config(pretrained_model_name_or_path, subfolder=subfolder, **hub_kwargs)
else:
config = pretrained_model_name_or_path_or_dict
pretrained_model_name_or_path = config.get("_name_or_path", None)
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
trust_remote_code = resolve_trust_remote_code(
trust_remote_code, pretrained_model_name_or_path, has_remote_code
)
if has_remote_code and trust_remote_code:
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
model_cls = get_class_from_dynamic_module(
pretrained_model_name_or_path,
subfolder=subfolder,
module_file=module_file,
class_name=class_name,
**hub_kwargs,
)
else:
if "_class_name" in config:
class_name = config["_class_name"]
library = "diffusers"
elif "model_type" in config:
class_name = "AutoModel"
library = "transformers"
else:
raise ValueError(
f"Couldn't find a model class associated with the config: {config}. Make sure the config "
"contains a `_class_name` or `model_type` key."
)
from ..pipelines.pipeline_loading_utils import ALL_IMPORTABLE_CLASSES, get_class_obj_and_candidates
model_cls, _ = get_class_obj_and_candidates(
library_name=library,
class_name=class_name,
importable_classes=ALL_IMPORTABLE_CLASSES,
pipelines=None,
is_pipeline_module=False,
)
if model_cls is None:
raise ValueError(f"AutoModel can't find a model linked to {class_name}.")
return model_cls.from_config(config, **kwargs)
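# Illustration only: the dict path of `from_config` above resolves the model class from the
# `_class_name` key (diffusers) or the `model_type` key (transformers) before instantiating
# it with random weights. The config below is a hypothetical minimal example.
# config = {"_class_name": "AutoencoderTiny"}
# model = AutoModel.from_config(config)  # -> AutoencoderTiny built with default arguments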
@classmethod
@validate_hf_hub_args
def from_pretrained(cls, pretrained_model_or_path: str | os.PathLike | None = None, **kwargs):

View File

@@ -18,7 +18,6 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_rae import AutoencoderRAE
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .vq_model import VQModel

View File

@@ -1,692 +0,0 @@
# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import sqrt
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.import_utils import is_transformers_available
from ...utils.torch_utils import randn_tensor
if is_transformers_available():
from transformers import (
Dinov2WithRegistersConfig,
Dinov2WithRegistersModel,
SiglipVisionConfig,
SiglipVisionModel,
ViTMAEConfig,
ViTMAEModel,
)
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import Attention
from ..embeddings import get_2d_sincos_pos_embed
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
logger = logging.get_logger(__name__)
# ---------------------------------------------------------------------------
# Per-encoder forward functions
# ---------------------------------------------------------------------------
# Each function takes the raw transformers model + images and returns patch
# tokens of shape (B, N, C), stripping CLS / register tokens as needed.
def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
outputs = model(images, output_hidden_states=True)
unused_token_num = 5 # 1 CLS + 4 register tokens
return outputs.last_hidden_state[:, unused_token_num:]
def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True)
return outputs.last_hidden_state
def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor:
h, w = images.shape[2], images.shape[3]
patch_num = int(h * w // patch_size**2)
if patch_num * patch_size**2 != h * w:
raise ValueError("Image size should be divisible by patch size.")
noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype)
outputs = model(images, noise, interpolate_pos_encoding=True)
return outputs.last_hidden_state[:, 1:] # remove cls token
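# Shape illustration (hypothetical sizes): each forward above returns patch tokens of shape
# (B, N, C). For a 224x224 input with patch size 14, N = (224 // 14) ** 2 = 256, so the
# DINOv2 hidden state carries 1 CLS + 4 register + 256 patch tokens and is sliced at 5.
_B, _C = 2, 768
_dinov2_hidden = torch.randn(_B, 1 + 4 + 256, _C)  # CLS + 4 registers + 256 patches
_patch_tokens = _dinov2_hidden[:, 5:]               # -> (2, 256, 768)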
# ---------------------------------------------------------------------------
# Encoder construction helpers
# ---------------------------------------------------------------------------
def _build_encoder(
encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
) -> nn.Module:
"""Build a frozen encoder from config (no pretrained download)."""
num_attention_heads = hidden_size // head_dim # all supported encoders use head_dim=64
if encoder_type == "dinov2":
config = Dinov2WithRegistersConfig(
hidden_size=hidden_size,
patch_size=patch_size,
image_size=518,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
)
model = Dinov2WithRegistersModel(config)
# RAE strips the final layernorm affine params (identity LN). Remove them from
# the architecture so `from_pretrained` doesn't leave them on the meta device.
model.layernorm.weight = None
model.layernorm.bias = None
elif encoder_type == "siglip2":
config = SiglipVisionConfig(
hidden_size=hidden_size,
patch_size=patch_size,
image_size=256,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
)
model = SiglipVisionModel(config)
# See dinov2 comment above.
model.vision_model.post_layernorm.weight = None
model.vision_model.post_layernorm.bias = None
elif encoder_type == "mae":
config = ViTMAEConfig(
hidden_size=hidden_size,
patch_size=patch_size,
image_size=224,
num_attention_heads=num_attention_heads,
num_hidden_layers=num_hidden_layers,
mask_ratio=0.0,
)
model = ViTMAEModel(config)
# See dinov2 comment above.
model.layernorm.weight = None
model.layernorm.bias = None
else:
raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae")
model.requires_grad_(False)
return model
_ENCODER_FORWARD_FNS = {
"dinov2": _dinov2_encoder_forward,
"siglip2": _siglip2_encoder_forward,
"mae": _mae_encoder_forward,
}
@dataclass
class RAEDecoderOutput(BaseOutput):
"""
Output of `RAEDecoder`.
Args:
logits (`torch.Tensor`):
Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`.
"""
logits: torch.Tensor
class ViTMAEIntermediate(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
super().__init__()
self.dense = nn.Linear(hidden_size, intermediate_size)
self.intermediate_act_fn = get_activation(hidden_act)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
class ViTMAEOutput(nn.Module):
def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
super().__init__()
self.dense = nn.Linear(intermediate_size, hidden_size)
self.dropout = nn.Dropout(hidden_dropout_prob)
def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states + input_tensor
return hidden_states
class ViTMAELayer(nn.Module):
"""
This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block).
"""
def __init__(
self,
*,
hidden_size: int,
num_attention_heads: int,
intermediate_size: int,
qkv_bias: bool = True,
layer_norm_eps: float = 1e-12,
hidden_dropout_prob: float = 0.0,
attention_probs_dropout_prob: float = 0.0,
hidden_act: str = "gelu",
):
super().__init__()
if hidden_size % num_attention_heads != 0:
raise ValueError(
f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}"
)
self.attention = Attention(
query_dim=hidden_size,
heads=num_attention_heads,
dim_head=hidden_size // num_attention_heads,
dropout=attention_probs_dropout_prob,
bias=qkv_bias,
)
self.intermediate = ViTMAEIntermediate(
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act
)
self.output = ViTMAEOutput(
hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob
)
self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attention_output = self.attention(self.layernorm_before(hidden_states))
hidden_states = attention_output + hidden_states
layer_output = self.layernorm_after(hidden_states)
layer_output = self.intermediate(layer_output)
layer_output = self.output(layer_output, hidden_states)
return layer_output
class RAEDecoder(nn.Module):
"""Lightweight RAE decoder."""
def __init__(
self,
hidden_size: int = 768,
decoder_hidden_size: int = 512,
decoder_num_hidden_layers: int = 8,
decoder_num_attention_heads: int = 16,
decoder_intermediate_size: int = 2048,
num_patches: int = 256,
patch_size: int = 16,
num_channels: int = 3,
image_size: int = 256,
qkv_bias: bool = True,
layer_norm_eps: float = 1e-12,
hidden_dropout_prob: float = 0.0,
attention_probs_dropout_prob: float = 0.0,
hidden_act: str = "gelu",
):
super().__init__()
self.decoder_hidden_size = decoder_hidden_size
self.patch_size = patch_size
self.num_channels = num_channels
self.image_size = image_size
self.num_patches = num_patches
self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))
self.decoder_layers = nn.ModuleList(
[
ViTMAELayer(
hidden_size=decoder_hidden_size,
num_attention_heads=decoder_num_attention_heads,
intermediate_size=decoder_intermediate_size,
qkv_bias=qkv_bias,
layer_norm_eps=layer_norm_eps,
hidden_dropout_prob=hidden_dropout_prob,
attention_probs_dropout_prob=attention_probs_dropout_prob,
hidden_act=hidden_act,
)
for _ in range(decoder_num_hidden_layers)
]
)
self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)
self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True)
self.gradient_checkpointing = False
self._initialize_weights(num_patches)
self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
def _initialize_weights(self, num_patches: int):
# Skip initialization when parameters are on meta device (e.g. during
# accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
# The actual values are filled in later when the checkpoint state dict is loaded.
if self.decoder_pos_embed.device.type == "meta":
return
grid_size = int(num_patches**0.5)
pos_embed = get_2d_sincos_pos_embed(
self.decoder_pos_embed.shape[-1],
grid_size,
cls_token=True,
extra_tokens=1,
output_type="pt",
device=self.decoder_pos_embed.device,
)
self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype))
def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
embeddings_positions = embeddings.shape[1] - 1
num_positions = self.decoder_pos_embed.shape[1] - 1
class_pos_embed = self.decoder_pos_embed[:, 0, :]
patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
dim = self.decoder_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
patch_pos_embed = F.interpolate(
patch_pos_embed,
scale_factor=(1, embeddings_positions / num_positions),
mode="bicubic",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
b, l, c = x.shape
if l == self.num_patches:
return x
h = w = int(l**0.5)
x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
return x
def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):
patch_size, num_channels = self.patch_size, self.num_channels
original_image_size = (
original_image_size if original_image_size is not None else (self.image_size, self.image_size)
)
original_height, original_width = original_image_size
num_patches_h = original_height // patch_size
num_patches_w = original_width // patch_size
if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
raise ValueError(
f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
)
batch_size = patchified_pixel_values.shape[0]
patchified_pixel_values = patchified_pixel_values.reshape(
batch_size,
num_patches_h,
num_patches_w,
patch_size,
patch_size,
num_channels,
)
patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
pixel_values = patchified_pixel_values.reshape(
batch_size,
num_channels,
num_patches_h * patch_size,
num_patches_w * patch_size,
)
return pixel_values
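# Shape illustration (hypothetical sizes) of the einsum round-trip above: (B, N, p*p*C)
# patches become a (B, C, H, W) image with H = W = sqrt(N) * p.
_B, _p, _C, _nh, _nw = 2, 16, 3, 16, 16
_patches = torch.randn(_B, _nh * _nw, _p * _p * _C)
_x = _patches.reshape(_B, _nh, _nw, _p, _p, _C)
_x = torch.einsum("nhwpqc->nchpwq", _x)
_pixels = _x.reshape(_B, _C, _nh * _p, _nw * _p)  # (2, 3, 256, 256)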
def forward(
self,
hidden_states: torch.Tensor,
*,
interpolate_pos_encoding: bool = False,
drop_cls_token: bool = False,
return_dict: bool = True,
) -> RAEDecoderOutput | tuple[torch.Tensor]:
x = self.decoder_embed(hidden_states)
if drop_cls_token:
x_ = x[:, 1:, :]
x_ = self.interpolate_latent(x_)
else:
x_ = self.interpolate_latent(x)
cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
x = torch.cat([cls_token, x_], dim=1)
if interpolate_pos_encoding:
if not drop_cls_token:
raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
decoder_pos_embed = self.interpolate_pos_encoding(x)
else:
decoder_pos_embed = self.decoder_pos_embed
hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)
for layer_module in self.decoder_layers:
hidden_states = layer_module(hidden_states)
hidden_states = self.decoder_norm(hidden_states)
logits = self.decoder_pred(hidden_states)
logits = logits[:, 1:, :]
if not return_dict:
return (logits,)
return RAEDecoderOutput(logits=logits)
class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
r"""
Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.
This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
images from learned representations.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Args:
encoder_type (`str`, *optional*, defaults to `"dinov2"`):
Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
encoder_hidden_size (`int`, *optional*, defaults to `768`):
Hidden size of the encoder model.
encoder_patch_size (`int`, *optional*, defaults to `14`):
Patch size of the encoder model.
encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
Number of hidden layers in the encoder model.
patch_size (`int`, *optional*, defaults to `16`):
Decoder patch size (used for unpatchify and decoder head).
encoder_input_size (`int`, *optional*, defaults to `224`):
Input size expected by the encoder.
image_size (`int`, *optional*):
Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
encoder_patch_size) ** 2`.
num_channels (`int`, *optional*, defaults to `3`):
Number of input/output channels.
encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
Channel-wise mean for encoder input normalization (ImageNet defaults).
encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
Channel-wise std for encoder input normalization (ImageNet defaults).
latents_mean (`list` or `tuple`, *optional*):
Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
lists.
latents_std (`list` or `tuple`, *optional*):
Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
config-serializable lists.
noise_tau (`float`, *optional*, defaults to `0.0`):
Noise level for training (adds noise to latents during training).
reshape_to_2d (`bool`, *optional*, defaults to `True`):
Whether to reshape latents to 2D (B, C, H, W) format.
use_encoder_loss (`bool`, *optional*, defaults to `False`):
Whether to use encoder hidden states in the loss (for advanced training).
"""
# NOTE: gradient checkpointing is not wired up for this model yet.
_supports_gradient_checkpointing = False
_no_split_modules = ["ViTMAELayer"]
@register_to_config
def __init__(
self,
encoder_type: str = "dinov2",
encoder_hidden_size: int = 768,
encoder_patch_size: int = 14,
encoder_num_hidden_layers: int = 12,
decoder_hidden_size: int = 512,
decoder_num_hidden_layers: int = 8,
decoder_num_attention_heads: int = 16,
decoder_intermediate_size: int = 2048,
patch_size: int = 16,
encoder_input_size: int = 224,
image_size: int | None = None,
num_channels: int = 3,
encoder_norm_mean: list | None = None,
encoder_norm_std: list | None = None,
latents_mean: list | tuple | torch.Tensor | None = None,
latents_std: list | tuple | torch.Tensor | None = None,
noise_tau: float = 0.0,
reshape_to_2d: bool = True,
use_encoder_loss: bool = False,
scaling_factor: float = 1.0,
):
super().__init__()
if encoder_type not in _ENCODER_FORWARD_FNS:
raise ValueError(
f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
)
if encoder_input_size % encoder_patch_size != 0:
raise ValueError(
f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
)
decoder_patch_size = patch_size
if decoder_patch_size <= 0:
raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
num_patches = (encoder_input_size // encoder_patch_size) ** 2
grid = int(sqrt(num_patches))
if grid * grid != num_patches:
raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")
derived_image_size = decoder_patch_size * grid
if image_size is None:
image_size = derived_image_size
else:
image_size = int(image_size)
if image_size != derived_image_size:
raise ValueError(
f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
)
def _to_config_compatible(value: Any) -> Any:
if isinstance(value, torch.Tensor):
return value.detach().cpu().tolist()
if isinstance(value, tuple):
return [_to_config_compatible(v) for v in value]
if isinstance(value, list):
return [_to_config_compatible(v) for v in value]
return value
def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None:
if value is None:
return None
if isinstance(value, torch.Tensor):
return value.detach().clone()
return torch.tensor(value, dtype=torch.float32)
latents_std_tensor = _as_optional_tensor(latents_std)
# Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
self.register_to_config(
latents_mean=_to_config_compatible(latents_mean),
latents_std=_to_config_compatible(latents_std),
)
# Frozen representation encoder (built from config, no downloads)
self.encoder: nn.Module = _build_encoder(
encoder_type=encoder_type,
hidden_size=encoder_hidden_size,
patch_size=encoder_patch_size,
num_hidden_layers=encoder_num_hidden_layers,
)
self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
num_patches = (encoder_input_size // encoder_patch_size) ** 2
# Encoder input normalization stats (ImageNet defaults)
if encoder_norm_mean is None:
encoder_norm_mean = [0.485, 0.456, 0.406]
if encoder_norm_std is None:
encoder_norm_std = [0.229, 0.224, 0.225]
encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True)
self.register_buffer("encoder_std", encoder_std_tensor, persistent=True)
# Latent normalization buffers (defaults are no-ops; actual values come from checkpoint)
latents_mean_tensor = _as_optional_tensor(latents_mean)
if latents_mean_tensor is None:
latents_mean_tensor = torch.zeros(1)
self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)
if latents_std_tensor is None:
latents_std_tensor = torch.ones(1)
self.register_buffer("_latents_std", latents_std_tensor, persistent=True)
# ViT-MAE style decoder
self.decoder = RAEDecoder(
hidden_size=int(encoder_hidden_size),
decoder_hidden_size=int(decoder_hidden_size),
decoder_num_hidden_layers=int(decoder_num_hidden_layers),
decoder_num_attention_heads=int(decoder_num_attention_heads),
decoder_intermediate_size=int(decoder_intermediate_size),
num_patches=int(num_patches),
patch_size=int(decoder_patch_size),
num_channels=int(num_channels),
image_size=int(image_size),
)
self.num_patches = int(num_patches)
self.decoder_patch_size = int(decoder_patch_size)
self.decoder_image_size = int(image_size)
# Slicing support (batch dimension) similar to other diffusers autoencoders
self.use_slicing = False
def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
# Per-sample random sigma in [0, noise_tau]
noise_sigma = self.config.noise_tau * torch.rand(
(x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator
)
return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)
def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
_, _, h, w = x.shape
if h != self.config.encoder_input_size or w != self.config.encoder_input_size:
x = F.interpolate(
x,
size=(self.config.encoder_input_size, self.config.encoder_input_size),
mode="bicubic",
align_corners=False,
)
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
return (x - mean) / std
def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor:
mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
std = self.encoder_std.to(device=x.device, dtype=x.dtype)
return x * std + mean
def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor:
latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
return (z - latents_mean) / (latents_std + 1e-5)
def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor:
latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
return z * (latents_std + 1e-5) + latents_mean
def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
x = self._resize_and_normalize(x)
if self.config.encoder_type == "mae":
tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size)
else:
tokens = self._encoder_forward_fn(self.encoder, x) # (B, N, C)
if self.training and self.config.noise_tau > 0:
tokens = self._noising(tokens, generator=generator)
if self.config.reshape_to_2d:
b, n, c = tokens.shape
side = int(sqrt(n))
if side * side != n:
raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.")
z = tokens.transpose(1, 2).contiguous().view(b, c, side, side) # (B, C, h, w)
else:
z = tokens
z = self._normalize_latents(z)
# Follow diffusers convention: optionally scale latents for diffusion
if self.config.scaling_factor != 1.0:
z = z * self.config.scaling_factor
return z
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> EncoderOutput | tuple[torch.Tensor]:
if self.use_slicing and x.shape[0] > 1:
latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0)
else:
latents = self._encode(x, generator=generator)
if not return_dict:
return (latents,)
return EncoderOutput(latent=latents)
def _decode(self, z: torch.Tensor) -> torch.Tensor:
# Undo scaling factor if applied at encode time
if self.config.scaling_factor != 1.0:
z = z / self.config.scaling_factor
z = self._denormalize_latents(z)
if self.config.reshape_to_2d:
b, c, h, w = z.shape
tokens = z.view(b, c, h * w).transpose(1, 2).contiguous() # (B, N, C)
else:
tokens = z
logits = self.decoder(tokens, return_dict=True).logits
x_rec = self.decoder.unpatchify(logits)
x_rec = self._denormalize_image(x_rec)
return x_rec
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
if self.use_slicing and z.shape[0] > 1:
decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
) -> DecoderOutput | tuple[torch.Tensor]:
latents = self.encode(sample, return_dict=False, generator=generator)[0]
decoded = self.decode(latents, return_dict=False)[0]
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
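For orientation, here is a standalone sketch of the token-to-2D reshape and the latent (de)normalization that `_encode`/`_decode` perform above; shapes and statistics are illustrative placeholders, not checkpoint values.

import torch

b, n, c = 2, 256, 768                                        # batch, tokens, channels (example values)
side = int(n ** 0.5)                                          # 16; the token count must be a perfect square
tokens = torch.randn(b, n, c)
# (B, N, C) -> (B, C, h, w), as in _encode when reshape_to_2d=True
z = tokens.transpose(1, 2).contiguous().view(b, c, side, side)
latents_mean, latents_std = torch.zeros(1), torch.ones(1)     # no-op defaults used when no stats are configured
z_norm = (z - latents_mean) / (latents_std + 1e-5)            # _normalize_latents
z_back = z_norm * (latents_std + 1e-5) + latents_mean         # _denormalize_latents
# (B, C, h, w) -> (B, N, C), as in _decode
tokens_back = z_back.view(b, c, side * side).transpose(1, 2).contiguous()
assert torch.allclose(tokens, tokens_back, atol=1e-5)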

View File

@@ -191,12 +191,7 @@ class CosmosControlNetModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
dim=1,
)
if condition_mask is not None:
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
else:
control_hidden_states = torch.cat(
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
)
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
padding_mask_resized = transforms.functional.resize(
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST

View File

@@ -28,7 +28,6 @@ if is_torch_available():
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel
from .transformer_helios import HeliosTransformer3DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel

View File

@@ -424,7 +424,7 @@ class Flux2SingleTransformerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None,
temb_mod: torch.Tensor,
temb_mod_params: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
split_hidden_states: bool = False,
@@ -436,7 +436,7 @@ class Flux2SingleTransformerBlock(nn.Module):
text_seq_len = encoder_hidden_states.shape[1]
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
mod_shift, mod_scale, mod_gate = Flux2Modulation.split(temb_mod, 1)[0]
mod_shift, mod_scale, mod_gate = temb_mod_params
norm_hidden_states = self.norm(hidden_states)
norm_hidden_states = (1 + mod_scale) * norm_hidden_states + mod_shift
@@ -498,18 +498,16 @@ class Flux2TransformerBlock(nn.Module):
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb_mod_img: torch.Tensor,
temb_mod_txt: torch.Tensor,
temb_mod_params_img: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
temb_mod_params_txt: tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...],
image_rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
joint_attention_kwargs: dict[str, Any] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
joint_attention_kwargs = joint_attention_kwargs or {}
# Modulation parameters shape: [1, 1, self.dim]
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = Flux2Modulation.split(temb_mod_img, 2)
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = Flux2Modulation.split(
temb_mod_txt, 2
)
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = temb_mod_params_img
(c_shift_msa, c_scale_msa, c_gate_msa), (c_shift_mlp, c_scale_mlp, c_gate_mlp) = temb_mod_params_txt
# Img stream
norm_hidden_states = self.norm1(hidden_states)
@@ -629,19 +627,15 @@ class Flux2Modulation(nn.Module):
self.linear = nn.Linear(dim, dim * 3 * self.mod_param_sets, bias=bias)
self.act_fn = nn.SiLU()
def forward(self, temb: torch.Tensor) -> torch.Tensor:
def forward(self, temb: torch.Tensor) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
mod = self.act_fn(temb)
mod = self.linear(mod)
return mod
@staticmethod
# split inside the transformer blocks, to avoid passing tuples into checkpoints https://github.com/huggingface/diffusers/issues/12776
def split(mod: torch.Tensor, mod_param_sets: int) -> tuple[tuple[torch.Tensor, torch.Tensor, torch.Tensor], ...]:
if mod.ndim == 2:
mod = mod.unsqueeze(1)
mod_params = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
mod_params = torch.chunk(mod, 3 * self.mod_param_sets, dim=-1)
# Return tuple of 3-tuples of modulation params shift/scale/gate
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))
return tuple(mod_params[3 * i : 3 * (i + 1)] for i in range(self.mod_param_sets))
class Flux2Transformer2DModel(
@@ -830,7 +824,7 @@ class Flux2Transformer2DModel(
double_stream_mod_img = self.double_stream_modulation_img(temb)
double_stream_mod_txt = self.double_stream_modulation_txt(temb)
single_stream_mod = self.single_stream_modulation(temb)
single_stream_mod = self.single_stream_modulation(temb)[0]
# 2. Input projection for image (hidden_states) and conditioning text (encoder_hidden_states)
hidden_states = self.x_embedder(hidden_states)
@@ -867,8 +861,8 @@ class Flux2Transformer2DModel(
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb_mod_img=double_stream_mod_img,
temb_mod_txt=double_stream_mod_txt,
temb_mod_params_img=double_stream_mod_img,
temb_mod_params_txt=double_stream_mod_txt,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
@@ -890,7 +884,7 @@ class Flux2Transformer2DModel(
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=None,
temb_mod=single_stream_mod,
temb_mod_params=single_stream_mod,
image_rotary_emb=concat_rotary_emb,
joint_attention_kwargs=joint_attention_kwargs,
)
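For reference, a minimal sketch of what the refactored modulation split computes: the projected modulation tensor is chunked into `mod_param_sets` groups of (shift, scale, gate) before being passed to the blocks. Dimensions below are illustrative.

import torch

def split_modulation(mod: torch.Tensor, mod_param_sets: int):
    if mod.ndim == 2:
        mod = mod.unsqueeze(1)                                 # [B, D] -> [B, 1, D]
    chunks = torch.chunk(mod, 3 * mod_param_sets, dim=-1)
    # Tuple of (shift, scale, gate) triples, one per parameter set
    return tuple(chunks[3 * i : 3 * (i + 1)] for i in range(mod_param_sets))

mod = torch.randn(2, 2 * 3 * 64)                               # two (shift, scale, gate) sets with dim 64
(shift_msa, scale_msa, gate_msa), (shift_mlp, scale_mlp, gate_mlp) = split_modulation(mod, 2)
assert shift_msa.shape == (2, 1, 64)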

View File

@@ -1,814 +0,0 @@
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import lru_cache
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def pad_for_3d_conv(x, kernel_size):
b, c, t, h, w = x.shape
pt, ph, pw = kernel_size
pad_t = (pt - (t % pt)) % pt
pad_h = (ph - (h % ph)) % ph
pad_w = (pw - (w % pw)) % pw
return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
def center_down_sample_3d(x, kernel_size):
return torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
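A quick standalone illustration of the two helpers above: replicate-pad the temporal/spatial dims up to multiples of the kernel, then average-pool with stride equal to the kernel. Shapes are made up for the example.

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 5, 7, 9)                                 # (B, C, T, H, W)
kt, kh, kw = 2, 4, 4
pad = (0, (kw - x.shape[4] % kw) % kw,                         # W, then H, then T (last dim first)
       0, (kh - x.shape[3] % kh) % kh,
       0, (kt - x.shape[2] % kt) % kt)
x = F.pad(x, pad, mode="replicate")                            # pad_for_3d_conv
pooled = F.avg_pool3d(x, (kt, kh, kw), stride=(kt, kh, kw))    # center_down_sample_3d
assert pooled.shape == (1, 3, 3, 2, 3)                         # ceil(5/2), ceil(7/4), ceil(9/4)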
def apply_rotary_emb_transposed(
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
):
x_1, x_2 = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
out = torch.empty_like(hidden_states)
out[..., 0::2] = x_1 * cos[..., 0::2] - x_2 * sin[..., 1::2]
out[..., 1::2] = x_1 * sin[..., 1::2] + x_2 * cos[..., 0::2]
return out.type_as(hidden_states)
def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
# encoder_hidden_states is only passed for cross-attention
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
# In cross-attention layers, we can only fuse the KV projections into a single linear
query = attn.to_q(hidden_states)
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
return query, key, value
class HeliosOutputNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
super().__init__()
self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False)
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int):
temb = temb[:, -original_context_length:, :]
shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
shift, scale = shift.squeeze(2).to(hidden_states.device), scale.squeeze(2).to(hidden_states.device)
hidden_states = hidden_states[:, -original_context_length:, :]
hidden_states = (self.norm(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
return hidden_states
class HeliosAttnProcessor:
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError(
"HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
)
def __call__(
self,
attn: "HeliosAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
original_context_length: int = None,
) -> torch.Tensor:
query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
query = attn.norm_q(query)
key = attn.norm_k(key)
query = query.unflatten(2, (attn.heads, -1))
key = key.unflatten(2, (attn.heads, -1))
value = value.unflatten(2, (attn.heads, -1))
if rotary_emb is not None:
query = apply_rotary_emb_transposed(query, rotary_emb)
key = apply_rotary_emb_transposed(key, rotary_emb)
if not attn.is_cross_attention and attn.is_amplify_history:
history_seq_len = hidden_states.shape[1] - original_context_length
if history_seq_len > 0:
scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0)
if attn.history_scale_mode == "per_head":
scale_key = scale_key.view(1, 1, -1, 1)
key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1)
hidden_states = dispatch_attention_fn(
query,
key,
value,
attn_mask=attention_mask,
dropout_p=0.0,
is_causal=False,
backend=self._attention_backend,
# Reference: https://github.com/huggingface/diffusers/pull/12909
parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
)
hidden_states = hidden_states.flatten(2, 3)
hidden_states = hidden_states.type_as(query)
hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class HeliosAttention(torch.nn.Module, AttentionModuleMixin):
_default_processor_cls = HeliosAttnProcessor
_available_processors = [HeliosAttnProcessor]
def __init__(
self,
dim: int,
heads: int = 8,
dim_head: int = 64,
eps: float = 1e-5,
dropout: float = 0.0,
added_kv_proj_dim: int | None = None,
cross_attention_dim_head: int | None = None,
processor=None,
is_cross_attention=None,
is_amplify_history=False,
history_scale_mode="per_head", # [scalar, per_head]
):
super().__init__()
self.inner_dim = dim_head * heads
self.heads = heads
self.added_kv_proj_dim = added_kv_proj_dim
self.cross_attention_dim_head = cross_attention_dim_head
self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
self.to_out = torch.nn.ModuleList(
[
torch.nn.Linear(self.inner_dim, dim, bias=True),
torch.nn.Dropout(dropout),
]
)
self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
self.add_k_proj = self.add_v_proj = None
if added_kv_proj_dim is not None:
self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
if is_cross_attention is not None:
self.is_cross_attention = is_cross_attention
else:
self.is_cross_attention = cross_attention_dim_head is not None
self.set_processor(processor)
self.is_amplify_history = is_amplify_history
if is_amplify_history:
if history_scale_mode == "scalar":
self.history_key_scale = nn.Parameter(torch.ones(1))
elif history_scale_mode == "per_head":
self.history_key_scale = nn.Parameter(torch.ones(heads))
else:
raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}")
self.history_scale_mode = history_scale_mode
self.max_scale = 10.0
def fuse_projections(self):
if getattr(self, "fused_projections", False):
return
if not self.is_cross_attention:
concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_qkv = nn.Linear(in_features, out_features, bias=True)
self.to_qkv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
else:
concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_kv = nn.Linear(in_features, out_features, bias=True)
self.to_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
if self.added_kv_proj_dim is not None:
concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
out_features, in_features = concatenated_weights.shape
with torch.device("meta"):
self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
self.to_added_kv.load_state_dict(
{"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
)
self.fused_projections = True
@torch.no_grad()
def unfuse_projections(self):
if not getattr(self, "fused_projections", False):
return
if hasattr(self, "to_qkv"):
delattr(self, "to_qkv")
if hasattr(self, "to_kv"):
delattr(self, "to_kv")
if hasattr(self, "to_added_kv"):
delattr(self, "to_added_kv")
self.fused_projections = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
original_context_length: int = None,
**kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states,
attention_mask,
rotary_emb,
original_context_length,
**kwargs,
)
class HeliosTimeTextEmbedding(nn.Module):
def __init__(
self,
dim: int,
time_freq_dim: int,
time_proj_dim: int,
text_embed_dim: int,
):
super().__init__()
self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
self.act_fn = nn.SiLU()
self.time_proj = nn.Linear(dim, time_proj_dim)
self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
def forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor | None = None,
is_return_encoder_hidden_states: bool = True,
):
timestep = self.timesteps_proj(timestep)
time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))
if encoder_hidden_states is not None and is_return_encoder_hidden_states:
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
return temb, timestep_proj, encoder_hidden_states
class HeliosRotaryPosEmbed(nn.Module):
def __init__(self, rope_dim, theta):
super().__init__()
self.DT, self.DY, self.DX = rope_dim
self.theta = theta
self.register_buffer("freqs_base_t", self._get_freqs_base(self.DT), persistent=False)
self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False)
self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False)
def _get_freqs_base(self, dim):
return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim))
@torch.no_grad()
def get_frequency_batched(self, freqs_base, pos):
freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos)
freqs = freqs.repeat_interleave(2, dim=0)
return freqs.cos(), freqs.sin()
@torch.no_grad()
@lru_cache(maxsize=32)
def _get_spatial_meshgrid(self, height, width, device_str):
device = torch.device(device_str)
grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)
grid_x_coords = torch.arange(width, device=device, dtype=torch.float32)
grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij")
return grid_y, grid_x
@torch.no_grad()
def forward(self, frame_indices, height, width, device):
batch_size = frame_indices.shape[0]
num_frames = frame_indices.shape[1]
frame_indices = frame_indices.to(device=device, dtype=torch.float32)
grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device))
grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width)
grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1)
grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1)
freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t)
freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch)
freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch)
result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0)
return result.permute(1, 0, 2, 3, 4)
@maybe_allow_in_graph
class HeliosTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
ffn_dim: int,
num_heads: int,
qk_norm: str = "rms_norm_across_heads",
cross_attn_norm: bool = False,
eps: float = 1e-6,
added_kv_proj_dim: int | None = None,
guidance_cross_attn: bool = False,
is_amplify_history: bool = False,
history_scale_mode: str = "per_head", # [scalar, per_head]
):
super().__init__()
# 1. Self-attention
self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.attn1 = HeliosAttention(
dim=dim,
heads=num_heads,
dim_head=dim // num_heads,
eps=eps,
cross_attention_dim_head=None,
processor=HeliosAttnProcessor(),
is_amplify_history=is_amplify_history,
history_scale_mode=history_scale_mode,
)
# 2. Cross-attention
self.attn2 = HeliosAttention(
dim=dim,
heads=num_heads,
dim_head=dim // num_heads,
eps=eps,
added_kv_proj_dim=added_kv_proj_dim,
cross_attention_dim_head=dim // num_heads,
processor=HeliosAttnProcessor(),
)
self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
# 3. Feed-forward
self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
# 4. Guidance cross-attention
self.guidance_cross_attn = guidance_cross_attn
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
rotary_emb: torch.Tensor,
original_context_length: int = None,
) -> torch.Tensor:
if temb.ndim == 4:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table.unsqueeze(0) + temb.float()
).chunk(6, dim=2)
# batch_size, seq_len, 1, inner_dim
shift_msa = shift_msa.squeeze(2)
scale_msa = scale_msa.squeeze(2)
gate_msa = gate_msa.squeeze(2)
c_shift_msa = c_shift_msa.squeeze(2)
c_scale_msa = c_scale_msa.squeeze(2)
c_gate_msa = c_gate_msa.squeeze(2)
else:
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
self.scale_shift_table + temb.float()
).chunk(6, dim=1)
# 1. Self-attention
norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
None,
None,
rotary_emb,
original_context_length,
)
hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
# 2. Cross-attention
if self.guidance_cross_attn:
history_seq_len = hidden_states.shape[1] - original_context_length
history_hidden_states, hidden_states = torch.split(
hidden_states, [history_seq_len, original_context_length], dim=1
)
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states,
None,
None,
original_context_length,
)
hidden_states = hidden_states + attn_output
hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1)
else:
norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states,
None,
None,
original_context_length,
)
hidden_states = hidden_states + attn_output
# 3. Feed-forward
norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
hidden_states
)
ff_output = self.ffn(norm_hidden_states)
hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
return hidden_states
class HeliosTransformer3DModel(
ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
r"""
A Transformer model for video-like data used in the Helios model.
Args:
patch_size (`tuple[int]`, defaults to `(1, 2, 2)`):
3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
num_attention_heads (`int`, defaults to `40`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `128`):
The number of channels in each head.
in_channels (`int`, defaults to `16`):
The number of channels in the input.
out_channels (`int`, defaults to `16`):
The number of channels in the output.
text_dim (`int`, defaults to `512`):
Input dimension for text embeddings.
freq_dim (`int`, defaults to `256`):
Dimension for sinusoidal time embeddings.
ffn_dim (`int`, defaults to `13824`):
Intermediate dimension in feed-forward network.
num_layers (`int`, defaults to `40`):
The number of layers of transformer blocks to use.
window_size (`tuple[int]`, defaults to `(-1, -1)`):
Window size for local attention (-1 indicates global attention).
cross_attn_norm (`bool`, defaults to `True`):
Enable cross-attention normalization.
qk_norm (`bool`, defaults to `True`):
Enable query/key normalization.
eps (`float`, defaults to `1e-6`):
Epsilon value for normalization layers.
add_img_emb (`bool`, defaults to `False`):
Whether to use img_emb.
added_kv_proj_dim (`int`, *optional*, defaults to `None`):
The number of channels to use for the added key and value projections. If `None`, no projection is used.
"""
_supports_gradient_checkpointing = True
_skip_layerwise_casting_patterns = [
"patch_embedding",
"patch_short",
"patch_mid",
"patch_long",
"condition_embedder",
"norm",
]
_no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"]
_keep_in_fp32_modules = [
"time_embedder",
"scale_shift_table",
"norm1",
"norm2",
"norm3",
"history_key_scale",
]
_keys_to_ignore_on_load_unexpected = ["norm_added_q"]
_repeated_blocks = ["HeliosTransformerBlock"]
_cp_plan = {
"blocks.0": {
"hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"blocks.*": {
"temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False),
"rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
},
"blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3),
}
@register_to_config
def __init__(
self,
patch_size: tuple[int, ...] = (1, 2, 2),
num_attention_heads: int = 40,
attention_head_dim: int = 128,
in_channels: int = 16,
out_channels: int = 16,
text_dim: int = 4096,
freq_dim: int = 256,
ffn_dim: int = 13824,
num_layers: int = 40,
cross_attn_norm: bool = True,
qk_norm: str | None = "rms_norm_across_heads",
eps: float = 1e-6,
added_kv_proj_dim: int | None = None,
rope_dim: tuple[int, ...] = (44, 42, 42),
rope_theta: float = 10000.0,
guidance_cross_attn: bool = True,
zero_history_timestep: bool = True,
has_multi_term_memory_patch: bool = True,
is_amplify_history: bool = False,
history_scale_mode: str = "per_head", # [scalar, per_head]
) -> None:
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
out_channels = out_channels or in_channels
# 1. Patch & position embedding
self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta)
self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
# 2. Initial Multi Term Memory Patch
self.zero_history_timestep = zero_history_timestep
if has_multi_term_memory_patch:
self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
self.patch_mid = nn.Conv3d(
in_channels,
inner_dim,
kernel_size=tuple(2 * p for p in patch_size),
stride=tuple(2 * p for p in patch_size),
)
self.patch_long = nn.Conv3d(
in_channels,
inner_dim,
kernel_size=tuple(4 * p for p in patch_size),
stride=tuple(4 * p for p in patch_size),
)
# 3. Condition embeddings
self.condition_embedder = HeliosTimeTextEmbedding(
dim=inner_dim,
time_freq_dim=freq_dim,
time_proj_dim=inner_dim * 6,
text_embed_dim=text_dim,
)
# 4. Transformer blocks
self.blocks = nn.ModuleList(
[
HeliosTransformerBlock(
inner_dim,
ffn_dim,
num_attention_heads,
qk_norm,
cross_attn_norm,
eps,
added_kv_proj_dim,
guidance_cross_attn=guidance_cross_attn,
is_amplify_history=is_amplify_history,
history_scale_mode=history_scale_mode,
)
for _ in range(num_layers)
]
)
# 5. Output norm & projection
self.norm_out = HeliosOutputNorm(inner_dim, eps, elementwise_affine=False)
self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
self.gradient_checkpointing = False
@apply_lora_scale("attention_kwargs")
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.LongTensor,
encoder_hidden_states: torch.Tensor,
# ------------ Stage 1 ------------
indices_hidden_states=None,
indices_latents_history_short=None,
indices_latents_history_mid=None,
indices_latents_history_long=None,
latents_history_short=None,
latents_history_mid=None,
latents_history_long=None,
return_dict: bool = True,
attention_kwargs: dict[str, Any] | None = None,
) -> torch.Tensor | dict[str, torch.Tensor]:
# 1. Input
batch_size = hidden_states.shape[0]
p_t, p_h, p_w = self.config.patch_size
# 2. Process noisy latents
hidden_states = self.patch_embedding(hidden_states)
_, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape
if indices_hidden_states is None:
indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1)
hidden_states = hidden_states.flatten(2).transpose(1, 2)
rotary_emb = self.rope(
frame_indices=indices_hidden_states,
height=post_patch_height,
width=post_patch_width,
device=hidden_states.device,
)
rotary_emb = rotary_emb.flatten(2).transpose(1, 2)
original_context_length = hidden_states.shape[1]
# 3. Process short history latents
if latents_history_short is not None and indices_latents_history_short is not None:
latents_history_short = self.patch_short(latents_history_short)
_, _, _, H1, W1 = latents_history_short.shape
latents_history_short = latents_history_short.flatten(2).transpose(1, 2)
rotary_emb_history_short = self.rope(
frame_indices=indices_latents_history_short,
height=H1,
width=W1,
device=latents_history_short.device,
)
rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2)
hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)
rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1)
# 4. Process mid history latents
if latents_history_mid is not None and indices_latents_history_mid is not None:
latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))
latents_history_mid = self.patch_mid(latents_history_mid)
latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)
rotary_emb_history_mid = self.rope(
frame_indices=indices_latents_history_mid,
height=H1,
width=W1,
device=latents_history_mid.device,
)
rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2))
rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2))
rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2)
hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)
rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1)
# 5. Process long history latents
if latents_history_long is not None and indices_latents_history_long is not None:
latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))
latents_history_long = self.patch_long(latents_history_long)
latents_history_long = latents_history_long.flatten(2).transpose(1, 2)
rotary_emb_history_long = self.rope(
frame_indices=indices_latents_history_long,
height=H1,
width=W1,
device=latents_history_long.device,
)
rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4))
rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4))
rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2)
hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)
rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1)
history_context_length = hidden_states.shape[1] - original_context_length
if indices_hidden_states is not None and self.zero_history_timestep:
timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device)
temb_t0, timestep_proj_t0, _ = self.condition_embedder(
timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False
)
temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1)
timestep_proj_t0 = (
timestep_proj_t0.unflatten(-1, (6, -1))
.view(1, 6, 1, -1)
.expand(batch_size, -1, history_context_length, -1)
)
temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
timestep_proj = timestep_proj.unflatten(-1, (6, -1))
if indices_hidden_states is not None and not self.zero_history_timestep:
main_repeat_size = hidden_states.shape[1]
else:
main_repeat_size = original_context_length
temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)
timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1)
if indices_hidden_states is not None and self.zero_history_timestep:
temb = torch.cat([temb_t0, temb], dim=1)
timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)
if timestep_proj.ndim == 4:
timestep_proj = timestep_proj.permute(0, 2, 1, 3)
# 6. Transformer blocks
hidden_states = hidden_states.contiguous()
encoder_hidden_states = encoder_hidden_states.contiguous()
rotary_emb = rotary_emb.contiguous()
if torch.is_grad_enabled() and self.gradient_checkpointing:
for block in self.blocks:
hidden_states = self._gradient_checkpointing_func(
block,
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
original_context_length,
)
else:
for block in self.blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states,
timestep_proj,
rotary_emb,
original_context_length,
)
# 7. Normalization
hidden_states = self.norm_out(hidden_states, temb, original_context_length)
hidden_states = self.proj_out(hidden_states)
# 8. Unpatchify
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
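The unpatchify at the end of this forward pass boils down to the following shape round trip, sketched here with illustrative dimensions.

import torch

B, C, p_t, p_h, p_w = 1, 16, 1, 2, 2
T, H, W = 4, 8, 8                                              # post-patch frames / height / width
x = torch.randn(B, T, H, W, p_t, p_h, p_w, C)                  # proj_out output after the reshape
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)                          # (B, C, T, p_t, H, p_h, W, p_w)
out = x.flatten(6, 7).flatten(4, 5).flatten(2, 3)              # (B, C, T*p_t, H*p_h, W*p_w)
assert out.shape == (B, C, T * p_t, H * p_h, W * p_w)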

View File

@@ -164,11 +164,7 @@ def compute_text_seq_len_from_mask(
position_ids = torch.arange(text_seq_len, device=encoder_hidden_states.device, dtype=torch.long)
active_positions = torch.where(encoder_hidden_states_mask, position_ids, position_ids.new_zeros(()))
has_active = encoder_hidden_states_mask.any(dim=1)
per_sample_len = torch.where(
has_active,
active_positions.max(dim=1).values + 1,
torch.as_tensor(text_seq_len, device=encoder_hidden_states.device),
)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
return text_seq_len, per_sample_len, encoder_hidden_states_mask
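Concretely, the simplified per-sample length computation behaves like this sketch; a fully masked row falls back to the full sequence length.

import torch

mask = torch.tensor([[1, 1, 1, 0, 0],
                     [0, 0, 0, 0, 0]], dtype=torch.bool)
text_seq_len = mask.shape[1]
position_ids = torch.arange(text_seq_len)
active_positions = torch.where(mask, position_ids, position_ids.new_zeros(()))
has_active = mask.any(dim=1)
per_sample_len = torch.where(has_active, active_positions.max(dim=1).values + 1, torch.as_tensor(text_seq_len))
assert per_sample_len.tolist() == [3, 5]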

View File

@@ -21,8 +21,21 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_blocks_flux"] = ["FluxAutoBlocks"]
_import_structure["modular_blocks_flux_kontext"] = ["FluxKontextAutoBlocks"]
_import_structure["encoders"] = ["FluxTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"AUTO_BLOCKS_KONTEXT",
"FLUX_KONTEXT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
"FluxKontextAutoBlocks",
"FluxKontextAutoDenoiseStep",
"FluxKontextBeforeDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -32,8 +45,21 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_blocks_flux import FluxAutoBlocks
from .modular_blocks_flux_kontext import FluxKontextAutoBlocks
from .encoders import FluxTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
AUTO_BLOCKS_KONTEXT,
FLUX_KONTEXT_BLOCKS,
TEXT2IMAGE_BLOCKS,
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
FluxKontextAutoBlocks,
FluxKontextAutoDenoiseStep,
FluxKontextBeforeDenoiseStep,
)
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
else:
import sys

View File

@@ -205,7 +205,7 @@ class FluxKontextProcessImagesInputStep(ModularPipelineBlocks):
return components, state
class FluxVaeEncoderStep(ModularPipelineBlocks):
class FluxVaeEncoderDynamicStep(ModularPipelineBlocks):
model_name = "flux"
def __init__(

View File

@@ -121,7 +121,7 @@ class FluxTextInputStep(ModularPipelineBlocks):
# Adapted from `QwenImageAdditionalInputsStep`
class FluxAdditionalInputsStep(ModularPipelineBlocks):
class FluxInputsDynamicStep(ModularPipelineBlocks):
model_name = "flux"
def __init__(
@@ -243,7 +243,7 @@ class FluxAdditionalInputsStep(ModularPipelineBlocks):
return components, state
class FluxKontextAdditionalInputsStep(FluxAdditionalInputsStep):
class FluxKontextInputsDynamicStep(FluxInputsDynamicStep):
model_name = "flux-kontext"
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
@@ -256,7 +256,7 @@ class FluxKontextAdditionalInputsStep(FluxAdditionalInputsStep):
continue
# 1. Calculate height/width from latents
# Unlike the `FluxAdditionalInputsStep`, we don't overwrite the `block.height` and `block.width`
# Unlike the `FluxInputsDynamicStep`, we don't overwrite the `block.height` and `block.width`
height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
if not hasattr(block_state, "image_height"):
block_state.image_height = height
@@ -303,7 +303,6 @@ class FluxKontextAdditionalInputsStep(FluxAdditionalInputsStep):
class FluxKontextSetResolutionStep(ModularPipelineBlocks):
model_name = "flux-kontext"
@property
def description(self):
return (
"Determines the height and width to be used during the subsequent computations.\n"

View File

@@ -0,0 +1,446 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
FluxImg2ImgPrepareLatentsStep,
FluxImg2ImgSetTimestepsStep,
FluxKontextRoPEInputsStep,
FluxPrepareLatentsStep,
FluxRoPEInputsStep,
FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep, FluxKontextDenoiseStep
from .encoders import (
FluxKontextProcessImagesInputStep,
FluxProcessImagesInputStep,
FluxTextEncoderStep,
FluxVaeEncoderDynamicStep,
)
from .inputs import (
FluxInputsDynamicStep,
FluxKontextInputsDynamicStep,
FluxKontextSetResolutionStep,
FluxTextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# vae encoder (run before before_denoise)
FluxImg2ImgVaeEncoderBlocks = InsertableDict(
[("preprocess", FluxProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep())]
)
class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
model_name = "flux"
block_classes = FluxImg2ImgVaeEncoderBlocks.values()
block_names = FluxImg2ImgVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [FluxImg2ImgVaeEncoderStep]
block_names = ["img2img"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for img2img tasks.\n"
+ " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if `image` is not provided, step will be skipped."
)
# Flux Kontext vae encoder (run before before_denoise)
FluxKontextVaeEncoderBlocks = InsertableDict(
[("preprocess", FluxKontextProcessImagesInputStep()), ("encode", FluxVaeEncoderDynamicStep(sample_mode="argmax"))]
)
class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = FluxKontextVaeEncoderBlocks.values()
block_names = FluxKontextVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
block_classes = [FluxKontextVaeEncoderStep]
block_names = ["img2img"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for img2img tasks.\n"
+ " - `FluxKontextVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if `image` is not provided, step will be skipped."
)
# before_denoise: text2img
FluxBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxSetTimestepsStep()),
("prepare_rope_inputs", FluxRoPEInputsStep()),
]
)
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = FluxBeforeDenoiseBlocks.values()
block_names = FluxBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
# before_denoise: img2img
FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
("prepare_rope_inputs", FluxRoPEInputsStep()),
]
)
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
@property
def description(self):
return "Before denoise step that prepare the inputs for the denoise step for img2img task."
# before_denoise: all task (text2img, img2img)
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
model_name = "flux-kontext"
block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
)
# before_denoise: FluxKontext
FluxKontextBeforeDenoiseBlocks = InsertableDict(
[
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxSetTimestepsStep()),
("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
]
)
class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
block_classes = FluxKontextBeforeDenoiseBlocks.values()
block_names = FluxKontextBeforeDenoiseBlocks.keys()
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step\n"
"for img2img/text2img task for Flux Kontext."
)
class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxKontextBeforeDenoiseStep, FluxBeforeDenoiseStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ " - `FluxKontextBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
)
# denoise: text2image
class FluxAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# denoise: Flux Kontext
class FluxKontextAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [FluxKontextDenoiseStep]
block_names = ["denoise"]
block_trigger_inputs = [None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents for Flux Kontext. "
"This is a auto pipeline block that works for text2image and img2img tasks."
" - `FluxDenoiseStep` (denoise) for text2image and img2img tasks."
)
# decode: all task (text2img, img2img)
class FluxAutoDecodeStep(AutoPipelineBlocks):
block_classes = [FluxDecodeStep]
block_names = ["non-inpaint"]
block_trigger_inputs = [None]
@property
def description(self):
return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
# inputs: text2image/img2img
FluxImg2ImgBlocks = InsertableDict(
[("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
)
class FluxImg2ImgInputStep(SequentialPipelineBlocks):
model_name = "flux"
block_classes = FluxImg2ImgBlocks.values()
block_names = FluxImg2ImgBlocks.keys()
@property
def description(self):
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
class FluxAutoInputStep(AutoPipelineBlocks):
block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
)
# inputs: Flux Kontext
FluxKontextBlocks = InsertableDict(
[
("set_resolution", FluxKontextSetResolutionStep()),
("text_inputs", FluxTextInputStep()),
("additional_inputs", FluxKontextInputsDynamicStep()),
]
)
class FluxKontextInputStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = FluxKontextBlocks.values()
block_names = FluxKontextBlocks.keys()
@property
def description(self):
return (
"Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
class FluxKontextAutoInputStep(AutoPipelineBlocks):
block_classes = [FluxKontextInputStep, FluxTextInputStep]
# block_classes = [FluxKontextInputStep]
block_names = ["img2img", "text2img"]
# block_names = ["img2img"]
block_trigger_inputs = ["image_latents", None]
# block_trigger_inputs = ["image_latents"]
@property
def description(self):
return (
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ " - `FluxKontextInputStep` (img2img) is used when `image_latents` is provided.\n"
+ " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present."
)
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux"
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `FluxAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
)
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux-kontext"
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextAutoDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
def description(self):
return (
"Core step that performs the denoising process. \n"
+ " - `FluxKontextAutoInputStep` (input) standardizes the inputs for the denoising step.\n"
+ " - `FluxKontextAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n"
+ " - `FluxKontextAutoDenoiseStep` (denoise) iteratively denoises the latents.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
)
# Auto blocks (text2image and img2img)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("vae_encoder", FluxAutoVaeEncoderStep()),
("denoise", FluxCoreDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
AUTO_BLOCKS_KONTEXT = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
("denoise", FluxKontextCoreDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
class FluxAutoBlocks(SequentialPipelineBlocks):
model_name = "flux"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image and image-to-image using Flux.\n"
+ "- for text-to-image generation, all you need to provide is `prompt`\n"
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`"
)
class FluxKontextAutoBlocks(FluxAutoBlocks):
model_name = "flux-kontext"
block_classes = AUTO_BLOCKS_KONTEXT.values()
block_names = AUTO_BLOCKS_KONTEXT.keys()
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("input", FluxTextInputStep()),
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxSetTimestepsStep()),
("prepare_rope_inputs", FluxRoPEInputsStep()),
("denoise", FluxDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("vae_encoder", FluxVaeEncoderDynamicStep()),
("input", FluxImg2ImgInputStep()),
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxImg2ImgSetTimestepsStep()),
("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
("prepare_rope_inputs", FluxRoPEInputsStep()),
("denoise", FluxDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
FLUX_KONTEXT_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("vae_encoder", FluxVaeEncoderDynamicStep(sample_mode="argmax")),
("input", FluxKontextInputStep()),
("prepare_latents", FluxPrepareLatentsStep()),
("set_timesteps", FluxSetTimestepsStep()),
("prepare_rope_inputs", FluxKontextRoPEInputsStep()),
("denoise", FluxKontextDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"auto": AUTO_BLOCKS,
"auto_kontext": AUTO_BLOCKS_KONTEXT,
"kontext": FLUX_KONTEXT_BLOCKS,
}
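A hedged usage sketch for the presets above: it assembles the text-to-image preset into a sequential block and turns it into a runnable pipeline. The import path, repo id, and the `from_blocks_dict`/`init_pipeline`/`load_default_components` calls are assumed from other diffusers modular pipelines and may differ for these Flux blocks.

import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks  # import path assumed

# Assemble the text2image preset defined above into one sequential block.
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
# Method names and repo id below are assumptions for illustration only.
pipe = t2i_blocks.init_pipeline("black-forest-labs/FLUX.1-dev")
pipe.load_default_components(torch_dtype=torch.bfloat16)
image = pipe(prompt="a photo of an astronaut riding a horse", output="images")[0]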

View File

@@ -1,586 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
FluxImg2ImgPrepareLatentsStep,
FluxImg2ImgSetTimestepsStep,
FluxPrepareLatentsStep,
FluxRoPEInputsStep,
FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxDenoiseStep
from .encoders import (
FluxProcessImagesInputStep,
FluxTextEncoderStep,
FluxVaeEncoderStep,
)
from .inputs import (
FluxAdditionalInputsStep,
FluxTextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# vae encoder (run before before_denoise)
# auto_docstring
class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
"""
Vae encoder step that preprocesses and encodes the image inputs into their latent representations.
Components:
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
Inputs:
resized_image (`None`, *optional*):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
processed_image (`None`):
TODO: Add description.
image_latents (`Tensor`):
The latents representing the reference image
"""
model_name = "flux"
block_classes = [FluxProcessImagesInputStep(), FluxVaeEncoderStep()]
block_names = ["preprocess", "encode"]
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
# auto_docstring
class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
"""
Vae encoder step that encodes the image inputs into their latent representations.
This is an auto pipeline block that works for img2img tasks.
- `FluxImg2ImgVaeEncoderStep` (img2img) is used when `image` is provided.
- if `image` is not provided, this step will be skipped.
Components:
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
Inputs:
resized_image (`None`, *optional*):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
processed_image (`None`):
TODO: Add description.
image_latents (`Tensor`):
The latents representing the reference image
"""
model_name = "flux"
block_classes = [FluxImg2ImgVaeEncoderStep]
block_names = ["img2img"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for img2img tasks.\n"
+ " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
+ " - if `image` is not provided, step will be skipped."
)
# before_denoise: text2img
# auto_docstring
class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepares the inputs for the denoise step in text-to-image generation.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_images_per_prompt (`int`, *optional*, defaults to 1):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
Can be generated in input step.
dtype (`dtype`, *optional*):
The dtype of the model inputs
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
prompt_embeds (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
The initial latents to use for the denoising process
timesteps (`Tensor`):
The timesteps to use for inference
num_inference_steps (`int`):
The number of denoising steps to perform at inference time
guidance (`Tensor`):
Optional guidance to be used.
txt_ids (`list`):
The sequence lengths of the prompt embeds, used for RoPE calculation.
img_ids (`list`):
The sequence lengths of the image latents, used for RoPE calculation.
"""
model_name = "flux"
block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()]
block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"]
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
# before_denoise: img2img
# auto_docstring
class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepares the inputs for the denoise step for the img2img task.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_images_per_prompt (`int`, *optional*, defaults to 1):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
Can be generated in input step.
dtype (`dtype`, *optional*):
The dtype of the model inputs
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.6):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
image_latents (`Tensor`):
The image latents to use for the denoising process. Can be generated in vae encoder and packed in input
step.
prompt_embeds (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
The initial latents to use for the denoising process
timesteps (`Tensor`):
The timesteps to use for inference
num_inference_steps (`int`):
The number of denoising steps to perform at inference time
guidance (`Tensor`):
Optional guidance to be used.
initial_noise (`Tensor`):
The initial random noise used for inpainting denoising.
txt_ids (`list`):
The sequence lengths of the prompt embeds, used for RoPE calculation.
img_ids (`list`):
The sequence lengths of the image latents, used for RoPE calculation.
"""
model_name = "flux"
block_classes = [
FluxPrepareLatentsStep(),
FluxImg2ImgSetTimestepsStep(),
FluxImg2ImgPrepareLatentsStep(),
FluxRoPEInputsStep(),
]
block_names = ["prepare_latents", "set_timesteps", "prepare_img2img_latents", "prepare_rope_inputs"]
@property
def description(self):
return "Before denoise step that prepare the inputs for the denoise step for img2img task."
# before_denoise: all tasks (text2img, img2img)
# auto_docstring
class FluxAutoBeforeDenoiseStep(AutoPipelineBlocks):
"""
Before denoise step that prepares the inputs for the denoise step.
This is an auto pipeline block that works for text2image and img2img tasks.
- `FluxBeforeDenoiseStep` (text2image) is used by default.
- `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when `image_latents` is provided.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
height (`int`):
TODO: Add description.
width (`int`):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_images_per_prompt (`int`, *optional*, defaults to 1):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
Can be generated in input step.
dtype (`dtype`, *optional*):
The dtype of the model inputs
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.6):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
image_latents (`Tensor`, *optional*):
The image latents to use for the denoising process. Can be generated in vae encoder and packed in input
step.
prompt_embeds (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
The initial latents to use for the denoising process
timesteps (`Tensor`):
The timesteps to use for inference
num_inference_steps (`int`):
The number of denoising steps to perform at inference time
guidance (`Tensor`):
Optional guidance to be used.
initial_noise (`Tensor`):
The initial random noise used for inpainting denoising.
txt_ids (`list`):
The sequence lengths of the prompt embeds, used for RoPE calculation.
img_ids (`list`):
The sequence lengths of the image latents, used for RoPE calculation.
"""
model_name = "flux"
block_classes = [FluxImg2ImgBeforeDenoiseStep, FluxBeforeDenoiseStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxBeforeDenoiseStep` (text2image) is used.\n"
+ " - `FluxImg2ImgBeforeDenoiseStep` (img2img) is used when only `image_latents` is provided.\n"
)
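To make the trigger mechanism concrete: `block_trigger_inputs` is checked against the pipeline state in order, and a `None` trigger acts as the fallback branch. A minimal, self-contained sketch of that selection logic (illustrative only, not the actual `AutoPipelineBlocks` implementation):
def select_block(block_names, block_trigger_inputs, state):
    # Walk the (name, trigger) pairs in order; a None trigger is the default branch.
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None or state.get(trigger) is not None:
            return name
    return None
# FluxAutoBeforeDenoiseStep: the img2img branch runs when `image_latents` is present,
# otherwise the text2image branch is used.
assert select_block(["img2img", "text2image"], ["image_latents", None], {"image_latents": object()}) == "img2img"
assert select_block(["img2img", "text2image"], ["image_latents", None], {}) == "text2image"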
# inputs: text2image/img2img
# auto_docstring
class FluxImg2ImgInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for the img2img denoising step. It:
- makes sure the text embeddings have a batch size consistent with the additional inputs (`image_latents`).
- updates height/width based on `image_latents`, and patchifies `image_latents`.
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
pooled_prompt_embeds (`Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
Outputs:
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
dtype (`dtype`):
Data type of model tensor inputs (determined by `prompt_embeds`)
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation
pooled_prompt_embeds (`Tensor`):
pooled text embeddings used to guide the image generation
image_height (`int`):
The height of the image latents
image_width (`int`):
The width of the image latents
"""
model_name = "flux"
block_classes = [FluxTextInputStep(), FluxAdditionalInputsStep()]
block_names = ["text_inputs", "additional_inputs"]
@property
def description(self):
return "Input step that prepares the inputs for the img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
# auto_docstring
class FluxAutoInputStep(AutoPipelineBlocks):
"""
Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch
size and are patchified.
This is an auto pipeline block that works for text2image/img2img tasks.
- `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.
- `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
pooled_prompt_embeds (`Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
Outputs:
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
dtype (`dtype`):
Data type of model tensor inputs (determined by `prompt_embeds`)
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation
pooled_prompt_embeds (`Tensor`):
pooled text embeddings used to guide the image generation
image_height (`int`):
The height of the image latents
image_width (`int`):
The width of the image latents
"""
model_name = "flux"
block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
block_names = ["img2img", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+ " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
)
# auto_docstring
class FluxCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core step that performs the denoising process for Flux.
This step supports text-to-image and image-to-image tasks for Flux:
- for image-to-image generation, you need to provide `image_latents`
- for text-to-image generation, all you need to provide is prompt embeddings.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
pooled_prompt_embeds (`Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.6):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux"
block_classes = [FluxAutoInputStep, FluxAutoBeforeDenoiseStep, FluxDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
def description(self):
return (
"Core step that performs the denoising process for Flux.\n"
+ "This step supports text-to-image and image-to-image tasks for Flux:\n"
+ " - for image-to-image generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
)
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# Auto blocks (text2image and img2img)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("vae_encoder", FluxAutoVaeEncoderStep()),
("denoise", FluxCoreDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
# auto_docstring
class FluxAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image and image-to-image using Flux.
Supported workflows:
- `text2image`: requires `prompt`
- `image2image`: requires `image`, `prompt`
Components:
text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2
(`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler
(`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
prompt_2 (`None`, *optional*):
TODO: Add description.
max_sequence_length (`int`, *optional*, defaults to 512):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
resized_image (`None`, *optional*):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.6):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
output_type (`None`, *optional*, defaults to pil):
TODO: Add description.
Outputs:
images (`list`):
Generated images.
"""
model_name = "flux"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"image": True, "prompt": True},
}
@property
def description(self):
return "Auto Modular pipeline for text-to-image and image-to-image using Flux."
@property
def outputs(self):
return [OutputParam.template("images")]
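As a hedged usage sketch for `FluxAutoBlocks` (the repo id, loading calls, and call signature are assumptions; the point is that passing `image` switches the auto blocks from the text2image path to the image2image path):
import torch
from diffusers.utils import load_image
blocks = FluxAutoBlocks()
pipe = blocks.init_pipeline("black-forest-labs/FLUX.1-dev")  # illustrative repo id
pipe.load_components(torch_dtype=torch.bfloat16)
# text2image: only `prompt` is required.
t2i_image = pipe(prompt="a watercolor fox", output="images")[0]
# image2image: providing `image` triggers the vae_encoder and img2img branches.
init_image = load_image("https://example.com/fox.png")  # placeholder URL
i2i_image = pipe(prompt="a watercolor fox", image=init_image, strength=0.6, output="images")[0]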

View File

@@ -1,585 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
FluxKontextRoPEInputsStep,
FluxPrepareLatentsStep,
FluxRoPEInputsStep,
FluxSetTimestepsStep,
)
from .decoders import FluxDecodeStep
from .denoise import FluxKontextDenoiseStep
from .encoders import (
FluxKontextProcessImagesInputStep,
FluxTextEncoderStep,
FluxVaeEncoderStep,
)
from .inputs import (
FluxKontextAdditionalInputsStep,
FluxKontextSetResolutionStep,
FluxTextInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# Flux Kontext vae encoder (run before before_denoise)
# auto_docstring
class FluxKontextVaeEncoderStep(SequentialPipelineBlocks):
"""
Vae encoder step that preprocesses and encodes the image inputs into their latent representations.
Components:
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
_auto_resize (`bool`, *optional*, defaults to True):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
processed_image (`None`):
TODO: Add description.
image_latents (`Tensor`):
The latents representing the reference image
"""
model_name = "flux-kontext"
block_classes = [FluxKontextProcessImagesInputStep(), FluxVaeEncoderStep(sample_mode="argmax")]
block_names = ["preprocess", "encode"]
@property
def description(self) -> str:
return "Vae encoder step that preprocess andencode the image inputs into their latent representations."
# auto_docstring
class FluxKontextAutoVaeEncoderStep(AutoPipelineBlocks):
"""
Vae encoder step that encodes the image inputs into their latent representations.
This is an auto pipeline block that works for image-conditioned tasks.
- `FluxKontextVaeEncoderStep` (image_conditioned) is used when `image` is provided.
- if `image` is not provided, this step will be skipped.
Components:
image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
_auto_resize (`bool`, *optional*, defaults to True):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
processed_image (`None`):
TODO: Add description.
image_latents (`Tensor`):
The latents representing the reference image
"""
model_name = "flux-kontext"
block_classes = [FluxKontextVaeEncoderStep]
block_names = ["image_conditioned"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"Vae encoder step that encode the image inputs into their latent representations.\n"
+ "This is an auto pipeline block that works for image-conditioned tasks.\n"
+ " - `FluxKontextVaeEncoderStep` (image_conditioned) is used when only `image` is provided."
+ " - if `image` is not provided, step will be skipped."
)
# before_denoise: text2img
# auto_docstring
class FluxKontextBeforeDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepares the inputs for the denoise step for Flux Kontext
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_images_per_prompt (`int`, *optional*, defaults to 1):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
Can be generated in input step.
dtype (`dtype`, *optional*):
The dtype of the model inputs
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
prompt_embeds (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
The initial latents to use for the denoising process
timesteps (`Tensor`):
The timesteps to use for inference
num_inference_steps (`int`):
The number of denoising steps to perform at inference time
guidance (`Tensor`):
Optional guidance to be used.
txt_ids (`list`):
The sequence lengths of the prompt embeds, used for RoPE calculation.
img_ids (`list`):
The sequence lengths of the image latents, used for RoPE calculation.
"""
model_name = "flux-kontext"
block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxRoPEInputsStep()]
block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"]
@property
def description(self):
return "Before denoise step that prepares the inputs for the denoise step for Flux Kontext\n"
"for text-to-image tasks."
# before_denoise: image-conditioned
# auto_docstring
class FluxKontextImageConditionedBeforeDenoiseStep(SequentialPipelineBlocks):
"""
Before denoise step that prepares the inputs for the denoise step for Flux Kontext
for image-conditioned tasks.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_images_per_prompt (`int`, *optional*, defaults to 1):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
Can be generated in input step.
dtype (`dtype`, *optional*):
The dtype of the model inputs
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
image_height (`None`, *optional*):
TODO: Add description.
image_width (`None`, *optional*):
TODO: Add description.
prompt_embeds (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
The initial latents to use for the denoising process
timesteps (`Tensor`):
The timesteps to use for inference
num_inference_steps (`int`):
The number of denoising steps to perform at inference time
guidance (`Tensor`):
Optional guidance to be used.
txt_ids (`list`):
The sequence lengths of the prompt embeds, used for RoPE calculation.
img_ids (`list`):
The sequence lengths of the image latents, used for RoPE calculation.
"""
model_name = "flux-kontext"
block_classes = [FluxPrepareLatentsStep(), FluxSetTimestepsStep(), FluxKontextRoPEInputsStep()]
block_names = ["prepare_latents", "set_timesteps", "prepare_rope_inputs"]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step for Flux Kontext\n"
"for image-conditioned tasks."
)
# auto_docstring
class FluxKontextAutoBeforeDenoiseStep(AutoPipelineBlocks):
"""
Before denoise step that prepares the inputs for the denoise step.
This is an auto pipeline block that works for text2image and image-conditioned tasks.
- `FluxKontextBeforeDenoiseStep` (text2image) is used by default.
- `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when `image_latents` is
provided.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`)
Inputs:
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_images_per_prompt (`int`, *optional*, defaults to 1):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be `batch_size * num_images_per_prompt`.
Can be generated in input step.
dtype (`dtype`, *optional*):
The dtype of the model inputs
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
image_height (`None`, *optional*):
TODO: Add description.
image_width (`None`, *optional*):
TODO: Add description.
prompt_embeds (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
The initial latents to use for the denoising process
timesteps (`Tensor`):
The timesteps to use for inference
num_inference_steps (`int`):
The number of denoising steps to perform at inference time
guidance (`Tensor`):
Optional guidance to be used.
txt_ids (`list`):
The sequence lengths of the prompt embeds, used for RoPE calculation.
img_ids (`list`):
The sequence lengths of the image latents, used for RoPE calculation.
"""
model_name = "flux-kontext"
block_classes = [FluxKontextImageConditionedBeforeDenoiseStep, FluxKontextBeforeDenoiseStep]
block_names = ["image_conditioned", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Before denoise step that prepare the inputs for the denoise step.\n"
+ "This is an auto pipeline block that works for text2image.\n"
+ " - `FluxKontextBeforeDenoiseStep` (text2image) is used.\n"
+ " - `FluxKontextImageConditionedBeforeDenoiseStep` (image_conditioned) is used when only `image_latents` is provided.\n"
)
# inputs: Flux Kontext
# auto_docstring
class FluxKontextInputStep(SequentialPipelineBlocks):
"""
Input step that prepares the inputs for both the text2img and img2img denoising steps. It:
- makes sure the text embeddings have a batch size consistent with the additional inputs (`image_latents`).
- updates height/width based on `image_latents`, and patchifies `image_latents`.
Inputs:
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
max_area (`int`, *optional*, defaults to 1048576):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
pooled_prompt_embeds (`Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
image_latents (`None`, *optional*):
TODO: Add description.
Outputs:
height (`int`):
The height of the initial noisy latents
width (`int`):
The width of the initial noisy latents
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
dtype (`dtype`):
Data type of model tensor inputs (determined by `prompt_embeds`)
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation
pooled_prompt_embeds (`Tensor`):
pooled text embeddings used to guide the image generation
image_height (`int`):
The height of the image latents
image_width (`int`):
The width of the image latents
"""
model_name = "flux-kontext"
block_classes = [FluxKontextSetResolutionStep(), FluxTextInputStep(), FluxKontextAdditionalInputsStep()]
block_names = ["set_resolution", "text_inputs", "additional_inputs"]
@property
def description(self):
return (
"Input step that prepares the inputs for the both text2img and img2img denoising step. It:\n"
" - make sure the text embeddings have consistent batch size as well as the additional inputs (`image_latents`).\n"
" - update height/width based `image_latents`, patchify `image_latents`."
)
# auto_docstring
class FluxKontextAutoInputStep(AutoPipelineBlocks):
"""
Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch
size and are patchified.
This is an auto pipeline block that works for text2image/img2img tasks.
- `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided.
- `FluxKontextInputStep` also handles the text2image task when `image_latents` isn't present.
Inputs:
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
max_area (`int`, *optional*, defaults to 1048576):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
pooled_prompt_embeds (`Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
image_latents (`None`, *optional*):
TODO: Add description.
Outputs:
height (`int`):
The height of the initial noisy latents
width (`int`):
The width of the initial noisy latents
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt
dtype (`dtype`):
Data type of model tensor inputs (determined by `prompt_embeds`)
prompt_embeds (`Tensor`):
text embeddings used to guide the image generation
pooled_prompt_embeds (`Tensor`):
pooled text embeddings used to guide the image generation
image_height (`int`):
The height of the image latents
image_width (`int`):
The width of the image latents
"""
model_name = "flux-kontext"
block_classes = [FluxKontextInputStep, FluxTextInputStep]
block_names = ["image_conditioned", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Input step that standardize the inputs for the denoising step, e.g. make sure inputs have consistent batch size, and patchified. \n"
" This is an auto pipeline block that works for text2image/img2img tasks.\n"
+ " - `FluxKontextInputStep` (image_conditioned) is used when `image_latents` is provided.\n"
+ " - `FluxKontextInputStep` is also capable of handling text2image task when `image_latent` isn't present."
)
# auto_docstring
class FluxKontextCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core step that performs the denoising process for Flux Kontext.
This step supports text-to-image and image-conditioned tasks for Flux Kontext:
- for image-conditioned generation, you need to provide `image_latents`
- for text-to-image generation, all you need to provide is prompt embeddings.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
Inputs:
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
max_area (`int`, *optional*, defaults to 1048576):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
pooled_prompt_embeds (`Tensor`, *optional*):
Pre-generated pooled text embeddings. Can be generated from text_encoder step.
image_latents (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux-kontext"
block_classes = [FluxKontextAutoInputStep, FluxKontextAutoBeforeDenoiseStep, FluxKontextDenoiseStep]
block_names = ["input", "before_denoise", "denoise"]
@property
def description(self):
return (
"Core step that performs the denoising process for Flux Kontext.\n"
+ "This step supports text-to-image and image-conditioned tasks for Flux Kontext:\n"
+ " - for image-conditioned generation, you need to provide `image_latents`\n"
+ " - for text-to-image generation, all you need to provide is prompt embeddings."
)
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
AUTO_BLOCKS_KONTEXT = InsertableDict(
[
("text_encoder", FluxTextEncoderStep()),
("vae_encoder", FluxKontextAutoVaeEncoderStep()),
("denoise", FluxKontextCoreDenoiseStep()),
("decode", FluxDecodeStep()),
]
)
# auto_docstring
class FluxKontextAutoBlocks(SequentialPipelineBlocks):
"""
Modular pipeline for text-to-image and image-conditioned generation using Flux Kontext.
Supported workflows:
- `image_conditioned`: requires `image`, `prompt`
- `text2image`: requires `prompt`
Components:
text_encoder (`CLIPTextModel`) tokenizer (`CLIPTokenizer`) text_encoder_2 (`T5EncoderModel`) tokenizer_2
(`T5TokenizerFast`) image_processor (`VaeImageProcessor`) vae (`AutoencoderKL`) scheduler
(`FlowMatchEulerDiscreteScheduler`) transformer (`FluxTransformer2DModel`)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
prompt_2 (`None`, *optional*):
TODO: Add description.
max_sequence_length (`int`, *optional*, defaults to 512):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
_auto_resize (`bool`, *optional*, defaults to True):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
max_area (`int`, *optional*, defaults to 1048576):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 3.5):
TODO: Add description.
output_type (`None`, *optional*, defaults to pil):
TODO: Add description.
Outputs:
images (`list`):
Generated images.
"""
model_name = "flux-kontext"
block_classes = AUTO_BLOCKS_KONTEXT.values()
block_names = AUTO_BLOCKS_KONTEXT.keys()
_workflow_map = {
"image_conditioned": {"image": True, "prompt": True},
"text2image": {"prompt": True},
}
@property
def description(self):
return "Modular pipeline for image-to-image using Flux Kontext."
@property
def outputs(self):
return [OutputParam.template("images")]
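Because `AUTO_BLOCKS_KONTEXT` is an `InsertableDict`, a preset can be trimmed or reordered before building a pipeline. A small sketch, assuming the dict supports standard `items()`/`pop()` access (an assumption, not shown in this diff):
# Variant that stops at the denoise step and returns latents instead of decoded images.
custom_blocks = InsertableDict(list(AUTO_BLOCKS_KONTEXT.items()))
custom_blocks.pop("decode")
class MyKontextLatentBlocks(SequentialPipelineBlocks):
    model_name = "flux-kontext"
    block_classes = custom_blocks.values()
    block_names = custom_blocks.keys()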

View File

@@ -21,14 +21,44 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["encoders"] = ["Flux2RemoteTextEncoderStep"]
_import_structure["modular_blocks_flux2"] = ["Flux2AutoBlocks"]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks"]
_import_structure["modular_blocks_flux2_klein_base"] = ["Flux2KleinBaseAutoBlocks"]
_import_structure["encoders"] = [
"Flux2TextEncoderStep",
"Flux2RemoteTextEncoderStep",
"Flux2VaeEncoderStep",
]
_import_structure["before_denoise"] = [
"Flux2SetTimestepsStep",
"Flux2PrepareLatentsStep",
"Flux2RoPEInputsStep",
"Flux2PrepareImageLatentsStep",
]
_import_structure["denoise"] = [
"Flux2LoopDenoiser",
"Flux2LoopAfterDenoiser",
"Flux2DenoiseLoopWrapper",
"Flux2DenoiseStep",
]
_import_structure["decoders"] = ["Flux2DecodeStep"]
_import_structure["inputs"] = [
"Flux2ProcessImagesInputStep",
"Flux2TextInputStep",
]
_import_structure["modular_blocks_flux2"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"REMOTE_AUTO_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"IMAGE_CONDITIONED_BLOCKS",
"Flux2AutoBlocks",
"Flux2AutoVaeEncoderStep",
"Flux2CoreDenoiseStep",
"Flux2VaeEncoderSequentialStep",
]
_import_structure["modular_blocks_flux2_klein"] = ["Flux2KleinAutoBlocks", "Flux2KleinBaseAutoBlocks"]
_import_structure["modular_pipeline"] = [
"Flux2KleinBaseModularPipeline",
"Flux2KleinModularPipeline",
"Flux2ModularPipeline",
"Flux2KleinModularPipeline",
"Flux2KleinBaseModularPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -38,10 +68,43 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .encoders import Flux2RemoteTextEncoderStep
from .modular_blocks_flux2 import Flux2AutoBlocks
from .modular_blocks_flux2_klein import Flux2KleinAutoBlocks
from .modular_blocks_flux2_klein_base import Flux2KleinBaseAutoBlocks
from .before_denoise import (
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep
from .denoise import (
Flux2DenoiseLoopWrapper,
Flux2DenoiseStep,
Flux2LoopAfterDenoiser,
Flux2LoopDenoiser,
)
from .encoders import (
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
from .modular_blocks_flux2 import (
ALL_BLOCKS,
AUTO_BLOCKS,
IMAGE_CONDITIONED_BLOCKS,
REMOTE_AUTO_BLOCKS,
TEXT2IMAGE_BLOCKS,
Flux2AutoBlocks,
Flux2AutoVaeEncoderStep,
Flux2CoreDenoiseStep,
Flux2VaeEncoderSequentialStep,
)
from .modular_blocks_flux2_klein import (
Flux2KleinAutoBlocks,
Flux2KleinBaseAutoBlocks,
)
from .modular_pipeline import Flux2KleinBaseModularPipeline, Flux2KleinModularPipeline, Flux2ModularPipeline
else:
import sys

View File

@@ -12,6 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
@@ -26,6 +30,7 @@ from .before_denoise import (
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2DenoiseStep
from .encoders import (
Flux2RemoteTextEncoderStep,
Flux2TextEncoderStep,
Flux2VaeEncoderStep,
)
@@ -38,69 +43,26 @@ from .inputs import (
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# auto_docstring
Flux2VaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
]
)
class Flux2VaeEncoderSequentialStep(SequentialPipelineBlocks):
"""
VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning.
Components:
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
condition_images (`list`):
TODO: Add description.
image_latents (`list`):
List of latent representations for each reference image
"""
model_name = "flux2"
block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()]
block_names = ["preprocess", "encode"]
block_classes = Flux2VaeEncoderBlocks.values()
block_names = Flux2VaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses, encodes, and prepares image latents for Flux2 conditioning."
# auto_docstring
class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
"""
VAE encoder step that encodes the image inputs into their latent representations.
This is an auto pipeline block that works for image conditioning tasks.
- `Flux2VaeEncoderSequentialStep` is used when `image` is provided.
- If `image` is not provided, step will be skipped.
Components:
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
condition_images (`list`):
TODO: Add description.
image_latents (`list`):
List of latent representations for each reference image
"""
block_classes = [Flux2VaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@@ -116,76 +78,6 @@ class Flux2AutoVaeEncoderStep(AutoPipelineBlocks):
Flux2CoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
# auto_docstring
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoise step that performs the denoising process for Flux2-dev.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 4.0):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latents (`Tensor`, *optional*):
Packed image latents for conditioning. Shape: (B, img_seq_len, C)
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux2"
block_classes = Flux2CoreDenoiseBlocks.values()
block_names = Flux2CoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-dev."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
Flux2ImageConditionedCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
@@ -199,158 +91,116 @@ Flux2ImageConditionedCoreDenoiseBlocks = InsertableDict(
)
# auto_docstring
class Flux2ImageConditionedCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoise step that performs the denoising process for Flux2-dev with image conditioning.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
image_latents (`list`, *optional*):
TODO: Add description.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 4.0):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
class Flux2CoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2"
block_classes = Flux2ImageConditionedCoreDenoiseBlocks.values()
block_names = Flux2ImageConditionedCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-dev with image conditioning."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
class Flux2AutoCoreDenoiseStep(AutoPipelineBlocks):
model_name = "flux2"
block_classes = [Flux2ImageConditionedCoreDenoiseStep, Flux2CoreDenoiseStep]
block_names = ["image_conditioned", "text2image"]
block_trigger_inputs = ["image_latents", None]
block_classes = Flux2CoreDenoiseBlocks.values()
block_names = Flux2CoreDenoiseBlocks.keys()
@property
def description(self):
return (
"Auto core denoise step that performs the denoising process for Flux2-dev."
"This is an auto pipeline block that works for text-to-image and image-conditioned generation."
" - `Flux2CoreDenoiseStep` is used for text-to-image generation.\n"
" - `Flux2ImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n"
"Core denoise step that performs the denoising process for Flux2-dev.\n"
" - `Flux2TextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2PrepareGuidanceStep` (prepare_guidance) prepares the guidance tensor for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2DenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2AutoCoreDenoiseStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
REMOTE_AUTO_BLOCKS = InsertableDict(
[
("text_encoder", Flux2RemoteTextEncoderStep()),
("vae_encoder", Flux2AutoVaeEncoderStep()),
("denoise", Flux2CoreDenoiseStep()),
("decode", Flux2DecodeStep()),
]
)
# auto_docstring
class Flux2AutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.
Supported workflows:
- `text2image`: requires `prompt`
- `image_conditioned`: requires `image`, `prompt`
Components:
text_encoder (`Mistral3ForConditionalGeneration`) tokenizer (`AutoProcessor`) image_processor
(`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer
(`Flux2Transformer2DModel`)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`int`, *optional*, defaults to 512):
TODO: Add description.
text_encoder_out_layers (`tuple`, *optional*, defaults to (10, 20, 30)):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
image_latents (`list`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`):
TODO: Add description.
num_inference_steps (`None`):
TODO: Add description.
timesteps (`None`):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
guidance_scale (`None`, *optional*, defaults to 4.0):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
output_type (`None`, *optional*, defaults to pil):
TODO: Add description.
Outputs:
images (`list`):
Generated images.
"""
model_name = "flux2"
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
_workflow_map = {
"text2image": {"prompt": True},
"image_conditioned": {"image": True, "prompt": True},
}
@property
def description(self):
return "Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2."
return (
"Auto Modular pipeline for text-to-image and image-conditioned generation using Flux2.\n"
"- For text-to-image generation, all you need to provide is `prompt`.\n"
"- For image-conditioned generation, you need to provide `image` (list of PIL images)."
)
@property
def outputs(self):
return [
OutputParam.template("images"),
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
IMAGE_CONDITIONED_BLOCKS = InsertableDict(
[
("text_encoder", Flux2TextEncoderStep()),
("text_input", Flux2TextInputStep()),
("preprocess_images", Flux2ProcessImagesInputStep()),
("vae_encoder", Flux2VaeEncoderStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_guidance", Flux2PrepareGuidanceStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2DenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
("decode", Flux2DecodeStep()),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"image_conditioned": IMAGE_CONDITIONED_BLOCKS,
"auto": AUTO_BLOCKS,
"remote": REMOTE_AUTO_BLOCKS,
}
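The `remote` preset differs from `auto` only in using `Flux2RemoteTextEncoderStep`, so prompt embeddings come from a remote endpoint rather than a locally loaded Mistral encoder. A hedged sketch of selecting and assembling one of these presets (`from_blocks_dict` and `init_pipeline` are assumed from the modular-pipelines API, and the repo id is illustrative):
from diffusers.modular_pipelines import SequentialPipelineBlocks
preset = ALL_BLOCKS["remote"]  # or "text2image", "image_conditioned", "auto"
blocks = SequentialPipelineBlocks.from_blocks_dict(preset)
pipe = blocks.init_pipeline("black-forest-labs/FLUX.2-dev")  # illustrative repo id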

View File

@@ -12,23 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
import PIL.Image
import torch
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2KleinBaseRoPEInputsStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2RoPEInputsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2KleinDenoiseStep
from .denoise import Flux2KleinBaseDenoiseStep, Flux2KleinDenoiseStep
from .encoders import (
Flux2KleinBaseTextEncoderStep,
Flux2KleinTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2KleinBaseTextInputStep,
Flux2ProcessImagesInputStep,
Flux2TextInputStep,
)
@@ -40,72 +47,26 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# VAE encoder
################
Flux2KleinVaeEncoderBlocks = InsertableDict(
[
("preprocess", Flux2ProcessImagesInputStep()),
("encode", Flux2VaeEncoderStep()),
]
)
# auto_docstring
class Flux2KleinVaeEncoderSequentialStep(SequentialPipelineBlocks):
"""
VAE encoder step that preprocesses and encodes the image inputs into their latent representations.
model_name = "flux2"
Components:
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
condition_images (`list`):
TODO: Add description.
image_latents (`list`):
List of latent representations for each reference image
"""
model_name = "flux2-klein"
block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()]
block_names = ["preprocess", "encode"]
block_classes = Flux2KleinVaeEncoderBlocks.values()
block_names = Flux2KleinVaeEncoderBlocks.keys()
@property
def description(self) -> str:
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
# auto_docstring
class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
"""
VAE encoder step that encodes the image inputs into their latent representations.
This is an auto pipeline block that works for image conditioning tasks.
- `Flux2KleinVaeEncoderSequentialStep` is used when `image` is provided.
- If `image` is not provided, step will be skipped.
Components:
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
condition_images (`list`):
TODO: Add description.
image_latents (`list`):
List of latent representations for each reference image
"""
model_name = "flux2-klein"
block_classes = [Flux2KleinVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@@ -125,74 +86,6 @@ class Flux2KleinAutoVaeEncoderStep(AutoPipelineBlocks):
###
Flux2KleinCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2RoPEInputsStep()),
("denoise", Flux2KleinDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
# auto_docstring
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image
generation.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latents (`Tensor`, *optional*):
Packed image latents for conditioning. Shape: (B, img_seq_len, C)
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux2-klein"
block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model), for text-to-image generation."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
Flux2KleinImageConditionedCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2TextInputStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
@@ -205,196 +98,135 @@ Flux2KleinImageConditionedCoreDenoiseBlocks = InsertableDict(
)
# auto_docstring
class Flux2KleinImageConditionedCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
image_latents (`list`, *optional*):
TODO: Add description.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
class Flux2KleinCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinImageConditionedCoreDenoiseBlocks.values()
block_names = Flux2KleinImageConditionedCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (distilled model) with image conditioning."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# auto_docstring
class Flux2KleinAutoCoreDenoiseStep(AutoPipelineBlocks):
"""
Auto core denoise step that performs the denoising process for Flux2-Klein.
This is an auto pipeline block that works for text-to-image and image-conditioned generation.
- `Flux2KleinCoreDenoiseStep` is used for text-to-image generation.
- `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
image_latents (`list`, *optional*):
TODO: Add description.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`):
TODO: Add description.
timesteps (`None`):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux2-klein"
block_classes = [Flux2KleinImageConditionedCoreDenoiseStep, Flux2KleinCoreDenoiseStep]
block_names = ["image_conditioned", "text2image"]
block_trigger_inputs = ["image_latents", None]
block_classes = Flux2KleinCoreDenoiseBlocks.values()
block_names = Flux2KleinCoreDenoiseBlocks.keys()
@property
def description(self):
return (
"Auto core denoise step that performs the denoising process for Flux2-Klein.\n"
"This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n"
" - `Flux2KleinCoreDenoiseStep` is used for text-to-image generation.\n"
" - `Flux2KleinImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n"
"Core denoise step that performs the denoising process for Flux2-Klein (distilled model).\n"
" - `Flux2KleinTextInputStep` (input) standardizes the text inputs (prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2RoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids) for the denoising step.\n"
" - `Flux2KleinDenoiseStep` (denoise) iteratively denoises the latents.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2KleinBaseTextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model)."
return (
"Core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
" - `Flux2KleinBaseTextInputStep` (input) standardizes the text inputs (prompt_embeds + negative_prompt_embeds) for the denoising step.\n"
" - `Flux2PrepareImageLatentsStep` (prepare_image_latents) prepares the image latents and image_latent_ids for the denoising step.\n"
" - `Flux2PrepareLatentsStep` (prepare_latents) prepares the initial latents (latents) and latent_ids for the denoising step.\n"
" - `Flux2SetTimestepsStep` (set_timesteps) sets the timesteps for the denoising step.\n"
" - `Flux2KleinBaseRoPEInputsStep` (prepare_rope_inputs) prepares the RoPE inputs (txt_ids + negative_txt_ids) for the denoising step.\n"
" - `Flux2KleinBaseDenoiseStep` (denoise) iteratively denoises the latents using Classifier-Free Guidance.\n"
" - `Flux2UnpackLatentsStep` (after_denoise) unpacks the latents from the denoising step.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="latents",
type_hint=torch.Tensor,
description="The latents from the denoising step.",
)
]
###
### Auto blocks
###
# auto_docstring
class Flux2KleinAutoBlocks(SequentialPipelineBlocks):
"""
Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.
Supported workflows:
- `text2image`: requires `prompt`
- `image_conditioned`: requires `image`, `prompt`
Components:
text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) image_processor (`Flux2ImageProcessor`)
vae (`AutoencoderKLFlux2`) scheduler (`FlowMatchEulerDiscreteScheduler`) transformer
(`Flux2Transformer2DModel`)
Configs:
is_distilled (default: True)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`int`, *optional*, defaults to 512):
TODO: Add description.
text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
image_latents (`list`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`):
TODO: Add description.
num_inference_steps (`None`):
TODO: Add description.
timesteps (`None`):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
output_type (`None`, *optional*, defaults to pil):
TODO: Add description.
Outputs:
images (`list`):
Generated images.
"""
model_name = "flux2-klein"
block_classes = [
Flux2KleinTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinAutoCoreDenoiseStep(),
Flux2KleinCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
_workflow_map = {
"text2image": {"prompt": True},
"image_conditioned": {"image": True, "prompt": True},
}
@property
def description(self):
return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein."
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein.\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam.template("images"),
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]
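# A minimal usage sketch (an assumption, not part of this file): the repo id below is a placeholder,
# and the standard modular-pipeline workflow of `init_pipeline()` + `load_components()` + calling the
# pipeline with `output="images"` is assumed.
import torch

blocks = Flux2KleinAutoBlocks()
pipe = blocks.init_pipeline("<placeholder/flux2-klein-modular-repo>")  # placeholder repo id
pipe.load_components(torch_dtype=torch.bfloat16)

# text2image workflow: only `prompt` is required
image = pipe(prompt="a cat wearing sunglasses", output="images")[0]

# image_conditioned workflow: additionally pass `image` (a list of PIL images)
edited = pipe(prompt="make it a watercolor painting", image=[image], output="images")[0]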
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
model_name = "flux2-klein"
block_classes = [
Flux2KleinBaseTextEncoderStep(),
Flux2KleinAutoVaeEncoderStep(),
Flux2KleinBaseCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self):
return (
"Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).\n"
+ " - for image-conditioned generation, you need to provide `image` (list of PIL images).\n"
+ " - for text-to-image generation, all you need to provide is `prompt`.\n"
)
@property
def outputs(self):
return [
OutputParam(
name="images",
type_hint=List[PIL.Image.Image],
description="The images from the decoding step.",
)
]

View File

@@ -1,413 +0,0 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict, OutputParam
from .before_denoise import (
Flux2KleinBaseRoPEInputsStep,
Flux2PrepareImageLatentsStep,
Flux2PrepareLatentsStep,
Flux2SetTimestepsStep,
)
from .decoders import Flux2DecodeStep, Flux2UnpackLatentsStep
from .denoise import Flux2KleinBaseDenoiseStep
from .encoders import (
Flux2KleinBaseTextEncoderStep,
Flux2VaeEncoderStep,
)
from .inputs import (
Flux2KleinBaseTextInputStep,
Flux2ProcessImagesInputStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
################
# VAE encoder
################
# auto_docstring
class Flux2KleinBaseVaeEncoderSequentialStep(SequentialPipelineBlocks):
"""
VAE encoder step that preprocesses and encodes the image inputs into their latent representations.
Components:
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
condition_images (`list`):
TODO: Add description.
image_latents (`list`):
List of latent representations for each reference image
"""
model_name = "flux2"
block_classes = [Flux2ProcessImagesInputStep(), Flux2VaeEncoderStep()]
block_names = ["preprocess", "encode"]
@property
def description(self) -> str:
return "VAE encoder step that preprocesses and encodes the image inputs into their latent representations."
# auto_docstring
class Flux2KleinBaseAutoVaeEncoderStep(AutoPipelineBlocks):
"""
VAE encoder step that encodes the image inputs into their latent representations.
This is an auto pipeline block that works for image conditioning tasks.
- `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided.
- If `image` is not provided, step will be skipped.
Components:
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`)
Inputs:
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
condition_images (`list`):
TODO: Add description.
image_latents (`list`):
List of latent representations for each reference image
"""
block_classes = [Flux2KleinBaseVaeEncoderSequentialStep]
block_names = ["img_conditioning"]
block_trigger_inputs = ["image"]
@property
def description(self):
return (
"VAE encoder step that encodes the image inputs into their latent representations.\n"
"This is an auto pipeline block that works for image conditioning tasks.\n"
" - `Flux2KleinBaseVaeEncoderSequentialStep` is used when `image` is provided.\n"
" - If `image` is not provided, step will be skipped."
)
###
### Core denoise
###
Flux2KleinBaseCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2KleinBaseTextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
# auto_docstring
class Flux2KleinBaseCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider
(`ClassifierFreeGuidance`)
Configs:
is_distilled (default: False)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latents (`Tensor`, *optional*):
Packed image latents for conditioning. Shape: (B, img_seq_len, C)
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux2-klein"
block_classes = Flux2KleinBaseCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model), for text-to-image generation."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
Flux2KleinBaseImageConditionedCoreDenoiseBlocks = InsertableDict(
[
("input", Flux2KleinBaseTextInputStep()),
("prepare_latents", Flux2PrepareLatentsStep()),
("prepare_image_latents", Flux2PrepareImageLatentsStep()),
("set_timesteps", Flux2SetTimestepsStep()),
("prepare_rope_inputs", Flux2KleinBaseRoPEInputsStep()),
("denoise", Flux2KleinBaseDenoiseStep()),
("after_denoise", Flux2UnpackLatentsStep()),
]
)
# auto_docstring
class Flux2KleinBaseImageConditionedCoreDenoiseStep(SequentialPipelineBlocks):
"""
Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider
(`ClassifierFreeGuidance`)
Configs:
is_distilled (default: False)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
image_latents (`list`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux2-klein"
block_classes = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.values()
block_names = Flux2KleinBaseImageConditionedCoreDenoiseBlocks.keys()
@property
def description(self):
return "Core denoise step that performs the denoising process for Flux2-Klein (base model) with image conditioning."
@property
def outputs(self):
return [
OutputParam.template("latents"),
]
# auto_docstring
class Flux2KleinBaseAutoCoreDenoiseStep(AutoPipelineBlocks):
"""
Auto core denoise step that performs the denoising process for Flux2-Klein (base model).
This is an auto pipeline block that works for text-to-image and image-conditioned generation.
- `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation.
- `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation.
Components:
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`) guider
(`ClassifierFreeGuidance`)
Configs:
is_distilled (default: False)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
image_latents (`list`, *optional*):
TODO: Add description.
num_inference_steps (`None`):
TODO: Add description.
timesteps (`None`):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "flux2-klein"
block_classes = [Flux2KleinBaseImageConditionedCoreDenoiseStep, Flux2KleinBaseCoreDenoiseStep]
block_names = ["image_conditioned", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self):
return (
"Auto core denoise step that performs the denoising process for Flux2-Klein (base model).\n"
"This is an auto pipeline block that works for text-to-image and image-conditioned generation.\n"
" - `Flux2KleinBaseCoreDenoiseStep` is used for text-to-image generation.\n"
" - `Flux2KleinBaseImageConditionedCoreDenoiseStep` is used for image-conditioned generation.\n"
)
###
### Auto blocks
###
# auto_docstring
class Flux2KleinBaseAutoBlocks(SequentialPipelineBlocks):
"""
Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model).
Supported workflows:
- `text2image`: requires `prompt`
- `image_conditioned`: requires `image`, `prompt`
Components:
text_encoder (`Qwen3ForCausalLM`) tokenizer (`Qwen2TokenizerFast`) guider (`ClassifierFreeGuidance`)
image_processor (`Flux2ImageProcessor`) vae (`AutoencoderKLFlux2`) scheduler
(`FlowMatchEulerDiscreteScheduler`) transformer (`Flux2Transformer2DModel`)
Configs:
is_distilled (default: False)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`int`, *optional*, defaults to 512):
TODO: Add description.
text_encoder_out_layers (`tuple`, *optional*, defaults to (9, 18, 27)):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
latents (`Tensor | NoneType`):
TODO: Add description.
image_latents (`list`, *optional*):
TODO: Add description.
num_inference_steps (`None`):
TODO: Add description.
timesteps (`None`):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
joint_attention_kwargs (`None`, *optional*):
TODO: Add description.
image_latent_ids (`Tensor`, *optional*):
Position IDs for image latents. Shape: (B, img_seq_len, 4)
output_type (`None`, *optional*, defaults to pil):
TODO: Add description.
Outputs:
images (`list`):
Generated images.
"""
model_name = "flux2-klein"
block_classes = [
Flux2KleinBaseTextEncoderStep(),
Flux2KleinBaseAutoVaeEncoderStep(),
Flux2KleinBaseAutoCoreDenoiseStep(),
Flux2DecodeStep(),
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
_workflow_map = {
"text2image": {"prompt": True},
"image_conditioned": {"image": True, "prompt": True},
}
@property
def description(self):
return "Auto blocks that perform the text-to-image and image-conditioned generation using Flux2-Klein (base model)."
@property
def outputs(self):
return [
OutputParam.template("images"),
]

View File

@@ -14,7 +14,6 @@
import importlib
import inspect
import os
import sys
import traceback
import warnings
from collections import OrderedDict
@@ -29,16 +28,10 @@ from tqdm.auto import tqdm
from typing_extensions import Self
from ..configuration_utils import ConfigMixin, FrozenDict
from ..pipelines.pipeline_loading_utils import (
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_unwrap_model,
simple_get_class_obj,
)
from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
from ..utils import PushToHubMixin, is_accelerate_available, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
MODULAR_MODEL_CARD_TEMPLATE,
@@ -47,12 +40,8 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
combine_inputs,
combine_outputs,
format_components,
format_configs,
format_workflow,
generate_modular_model_card_content,
make_doc_string,
)
@@ -298,8 +287,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_requirements: dict[str, str] | None = None
_workflow_map = None
@classmethod
def _get_signature_keys(cls, obj):
@@ -355,35 +342,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
def outputs(self) -> list[OutputParam]:
return self._get_outputs()
# currently only ConditionalPipelineBlocks and SequentialPipelineBlocks support `get_execution_blocks`
def get_execution_blocks(self, **kwargs):
"""
Get the block(s) that would execute given the inputs. Must be implemented by subclasses that support
conditional block selection.
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
"""
raise NotImplementedError(f"`get_execution_blocks` is not implemented for {self.__class__.__name__}")
# currently only SequentialPipelineBlocks support workflows
@property
def available_workflows(self):
"""
Returns a list of available workflow names. Must be implemented by subclasses that define `_workflow_map`.
"""
raise NotImplementedError(f"`available_workflows` is not implemented for {self.__class__.__name__}")
def get_workflow(self, workflow_name: str):
"""
Get the execution blocks for a specific workflow. Must be implemented by subclasses that define
`_workflow_map`.
Args:
workflow_name: Name of the workflow to retrieve.
"""
raise NotImplementedError(f"`get_workflow` is not implemented for {self.__class__.__name__}")
@classmethod
def from_pretrained(
cls,
@@ -413,9 +371,6 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
@@ -440,13 +395,8 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -530,6 +480,72 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
if current_value is not param: # Using identity comparison to check if object was modified
state.set(param_name, param, input_param.kwargs_type)
@staticmethod
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if
current default value is None and new default value is not None. Warns if multiple non-None default values
exist for the same input.
Args:
named_input_lists: list of tuples containing (block_name, input_param_list) pairs
Returns:
list[InputParam]: Combined list of unique InputParam objects
"""
combined_dict = {} # name -> InputParam
value_sources = {} # name -> block_name
for block_name, inputs in named_input_lists:
for input_param in inputs:
if input_param.name is None and input_param.kwargs_type is not None:
input_name = "*_" + input_param.kwargs_type
else:
input_name = input_param.name
if input_name in combined_dict:
current_param = combined_dict[input_name]
if (
current_param.default is not None
and input_param.default is not None
and current_param.default != input_param.default
):
warnings.warn(
f"Multiple different default values found for input '{input_name}': "
f"{current_param.default} (from block '{value_sources[input_name]}') and "
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
)
if current_param.default is None and input_param.default is not None:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
else:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
return list(combined_dict.values())
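# A hypothetical illustration of the merge rule above ("text_encoder"/"denoise" are placeholder
# block names, not real blocks):
inputs_a = ("text_encoder", [InputParam(name="num_inference_steps", default=None)])
inputs_b = ("denoise", [InputParam(name="num_inference_steps", default=50)])
merged = ModularPipelineBlocks.combine_inputs(inputs_a, inputs_b)
# `merged` holds a single "num_inference_steps" InputParam with default=50: the None default from
# "text_encoder" is replaced by the non-None default from "denoise". If both defaults were non-None
# and different, a warning is emitted and the first value is kept.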
@staticmethod
def combine_outputs(*named_output_lists: list[tuple[str, list[OutputParam]]]) -> list[OutputParam]:
"""
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
occurrence of each output name.
Args:
named_output_lists: list of tuples containing (block_name, output_param_list) pairs
Returns:
list[OutputParam]: Combined list of unique OutputParam objects
"""
combined_dict = {} # name -> OutputParam
for block_name, outputs in named_output_lists:
for output_param in outputs:
if (output_param.name not in combined_dict) or (
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
):
combined_dict[output_param.name] = output_param
return list(combined_dict.values())
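# A matching hypothetical illustration for the rule above (placeholder block names): the first
# occurrence of each output name wins.
outputs_a = ("denoise", [OutputParam(name="latents", description="Denoised latents.")])
outputs_b = ("after_denoise", [OutputParam(name="latents", description="Unpacked latents.")])
merged = ModularPipelineBlocks.combine_outputs(outputs_a, outputs_b)
# `merged` contains a single "latents" OutputParam, the one declared by "denoise".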
@property
def input_names(self) -> list[str]:
return [input_param.name for input_param in self.inputs if input_param.name is not None]
@@ -561,8 +577,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
class ConditionalPipelineBlocks(ModularPipelineBlocks):
"""
A Pipeline Blocks that conditionally selects a block to run based on the inputs. Subclasses must implement the
`select_block` method to define the logic for selecting the block. Currently, we only support selection logic based
on the presence or absence of inputs (i.e., whether they are `None` or not)
`select_block` method to define the logic for selecting the block.
This class inherits from [`ModularPipelineBlocks`]. Check the superclass documentation for the generic methods the
library implements for all the pipeline blocks (such as loading or saving etc.)
@@ -570,20 +585,15 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
> [!WARNING]
> This is an experimental feature and is likely to change in the future.
Attributes:
block_classes: List of block classes to be used. Must have the same length as `block_names`.
block_names: List of names for each block. Must have the same length as `block_classes`.
block_trigger_inputs: List of input names that `select_block()` uses to determine which block to run.
For `ConditionalPipelineBlocks`, this does not need to correspond to `block_names` and `block_classes`. For
`AutoPipelineBlocks`, this must have the same length as `block_names` and `block_classes`, where each
element specifies the trigger input for the corresponding block.
default_block_name: Name of the default block to run when no trigger inputs match.
If None, this block can be skipped entirely when no trigger inputs are provided.
block_classes: List of block classes to be used
block_names: List of prefixes for each block
block_trigger_inputs: List of input names that select_block() uses to determine which block to run
"""
block_classes = []
block_names = []
block_trigger_inputs = []
default_block_name = None
default_block_name = None  # name of the default block to run when no trigger inputs are provided; if None, this block is skipped entirely
def __init__(self):
sub_blocks = InsertableDict()
@@ -647,7 +657,7 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
@property
def inputs(self) -> list[tuple[str, Any]]:
named_inputs = [(name, block.inputs) for name, block in self.sub_blocks.items()]
combined_inputs = combine_inputs(*named_inputs)
combined_inputs = self.combine_inputs(*named_inputs)
# mark Required inputs only if that input is required by all the blocks
for input_param in combined_inputs:
if input_param.name in self.required_inputs:
@@ -659,25 +669,15 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> list[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
combined_outputs = combine_outputs(*named_outputs)
combined_outputs = self.combine_outputs(*named_outputs)
return combined_outputs
@property
def outputs(self) -> list[str]:
named_outputs = [(name, block.outputs) for name, block in self.sub_blocks.items()]
combined_outputs = combine_outputs(*named_outputs)
combined_outputs = self.combine_outputs(*named_outputs)
return combined_outputs
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
# used for `__repr__`
def _get_trigger_inputs(self) -> set:
"""
Returns a set of all unique trigger input values found in this block and nested blocks.
@@ -706,16 +706,16 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
return all_triggers
@property
def trigger_inputs(self):
"""All trigger inputs including from nested blocks."""
return self._get_trigger_inputs()
def select_block(self, **kwargs) -> str | None:
"""
Select the block to run based on the trigger inputs. Subclasses must implement this method to define the logic
for selecting the block.
Note: When trigger inputs include intermediate outputs from earlier blocks, the selection logic should only
depend on the presence or absence of the input (i.e., whether it is None or not), not on its actual value. This
is because `get_execution_blocks()` resolves conditions statically by propagating intermediate output names
without their runtime values.
Args:
**kwargs: Trigger input names and their values from the state.
@@ -750,39 +750,6 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
logger.error(error_msg)
raise
def get_execution_blocks(self, **kwargs) -> ModularPipelineBlocks | None:
"""
Get the block(s) that would execute given the inputs.
Recursively resolves nested ConditionalPipelineBlocks until reaching either:
- A leaf block (no sub_blocks or LoopSequentialPipelineBlocks) → returns single `ModularPipelineBlocks`
- A `SequentialPipelineBlocks` → delegates to its `get_execution_blocks()` which returns
a `SequentialPipelineBlocks` containing the resolved execution blocks
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Returns:
- `ModularPipelineBlocks`: A leaf block or resolved `SequentialPipelineBlocks`
- `None`: If this block would be skipped (no trigger matched and no default)
"""
trigger_kwargs = {name: kwargs.get(name) for name in self.block_trigger_inputs if name is not None}
block_name = self.select_block(**trigger_kwargs)
if block_name is None:
block_name = self.default_block_name
if block_name is None:
return None
block = self.sub_blocks[block_name]
# Recursively resolve until we hit a leaf block
if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
return block.get_execution_blocks(**kwargs)
return block
def __repr__(self):
class_name = self.__class__.__name__
base_class = self.__class__.__bases__[0].__name__
@@ -790,11 +757,11 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
)
if self._get_trigger_inputs():
if self.trigger_inputs:
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {sorted(self._get_trigger_inputs())}\n"
header += f" Trigger Inputs: {sorted(self.trigger_inputs)}\n"
header += " " + "=" * 100 + "\n\n"
# Format description with proper indentation
@@ -861,56 +828,24 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
class AutoPipelineBlocks(ConditionalPipelineBlocks):
"""
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
This is a specialized version of `ConditionalPipelineBlocks` where:
- Each block has one corresponding trigger input (1:1 mapping)
- Block selection is automatic: the first block whose trigger input is present gets selected
- `block_trigger_inputs` must have the same length as `block_names` and `block_classes`
- Use `None` in `block_trigger_inputs` to specify the default block, i.e. the block that will run if no trigger
inputs are present.
Attributes:
block_classes:
List of block classes to be used. Must have the same length as `block_names` and
`block_trigger_inputs`.
block_names:
List of names for each block. Must have the same length as `block_classes` and `block_trigger_inputs`.
block_trigger_inputs:
List of input names where each element specifies the trigger input for the corresponding block. Use
`None` to mark the default block.
Example:
```python
class MyAutoBlock(AutoPipelineBlocks):
block_classes = [InpaintEncoderBlock, ImageEncoderBlock, TextEncoderBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask_image", "image", None] # text2img is the default
```
With this definition:
- As long as `mask_image` is provided, "inpaint" block runs (regardless of `image` being provided or not)
- If `mask_image` is not provided but `image` is provided, "img2img" block runs
- Otherwise, "text2img" block runs (default, trigger is `None`)
A Pipeline Blocks that automatically selects a block to run based on the presence of trigger inputs.
"""
def __init__(self):
super().__init__()
if self.default_block_name is not None:
raise ValueError(
f"In {self.__class__.__name__}, do not set `default_block_name` for AutoPipelineBlocks. "
f"Use `None` in `block_trigger_inputs` to specify the default block."
)
if not (len(self.block_classes) == len(self.block_names) == len(self.block_trigger_inputs)):
raise ValueError(
f"In {self.__class__.__name__}, the number of block_classes, block_names, and block_trigger_inputs must be the same."
)
@property
def default_block_name(self) -> str | None:
"""Derive default_block_name from block_trigger_inputs (None entry)."""
if None in self.block_trigger_inputs:
idx = self.block_trigger_inputs.index(None)
self.default_block_name = self.block_names[idx]
return self.block_names[idx]
return None
def select_block(self, **kwargs) -> str | None:
"""Select block based on which trigger input is present (not None)."""
@@ -964,29 +899,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs.append(config)
return expected_configs
@property
def available_workflows(self):
if self._workflow_map is None:
raise NotImplementedError(
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
)
return list(self._workflow_map.keys())
def get_workflow(self, workflow_name: str):
if self._workflow_map is None:
raise NotImplementedError(
f"workflows is not supported because _workflow_map is not set for {self.__class__.__name__}"
)
if workflow_name not in self._workflow_map:
raise ValueError(f"Workflow {workflow_name} not found in {self.__class__.__name__}")
trigger_inputs = self._workflow_map[workflow_name]
workflow_blocks = self.get_execution_blocks(**trigger_inputs)
return workflow_blocks
@classmethod
def from_blocks_dict(
cls, blocks_dict: dict[str, Any], description: str | None = None
@@ -1082,7 +994,7 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
# filter out them here so they do not end up as intermediate_outputs
if name not in inp_names:
named_outputs.append((name, block.intermediate_outputs))
combined_outputs = combine_outputs(*named_outputs)
combined_outputs = self.combine_outputs(*named_outputs)
return combined_outputs
# YiYi TODO: I think we can remove the outputs property
@@ -1106,7 +1018,6 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
raise
return pipeline, state
# used for `__repr__`
def _get_trigger_inputs(self):
"""
Returns a set of all unique trigger input values found in the blocks.
@@ -1130,56 +1041,89 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
return fn_recursive_get_trigger(self.sub_blocks)
def get_execution_blocks(self, **kwargs) -> "SequentialPipelineBlocks":
@property
def trigger_inputs(self):
return self._get_trigger_inputs()
def _traverse_trigger_blocks(self, active_inputs):
"""
Get the blocks that would execute given the specified inputs.
As the traversal walks through sequential blocks, intermediate outputs from resolved blocks are added to the
active inputs. This means conditional blocks that depend on intermediates (e.g., "run img2img if image_latents
is present") will resolve correctly, as long as the condition is based on presence/absence (None or not None),
not on the actual value.
Traverse blocks and select which ones would run given the active inputs.
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
active_inputs: Dict of input names to values that are "present"
Returns:
SequentialPipelineBlocks containing only the blocks that would execute
OrderedDict of block_name -> block that would execute
"""
# Copy kwargs so we can add outputs as we traverse
active_inputs = dict(kwargs)
def fn_recursive_traverse(block, block_name, active_inputs):
result_blocks = OrderedDict()
# ConditionalPipelineBlocks (includes AutoPipelineBlocks)
if isinstance(block, ConditionalPipelineBlocks):
block = block.get_execution_blocks(**active_inputs)
if block is None:
trigger_kwargs = {name: active_inputs.get(name) for name in block.block_trigger_inputs}
selected_block_name = block.select_block(**trigger_kwargs)
if selected_block_name is None:
selected_block_name = block.default_block_name
if selected_block_name is None:
return result_blocks
# Has sub_blocks (SequentialPipelineBlocks/ConditionalPipelineBlocks)
if block.sub_blocks and not isinstance(block, LoopSequentialPipelineBlocks):
selected_block = block.sub_blocks[selected_block_name]
if selected_block.sub_blocks:
result_blocks.update(fn_recursive_traverse(selected_block, block_name, active_inputs))
else:
result_blocks[block_name] = selected_block
if hasattr(selected_block, "outputs"):
for out in selected_block.outputs:
active_inputs[out.name] = True
return result_blocks
# SequentialPipelineBlocks or LoopSequentialPipelineBlocks
if block.sub_blocks:
for sub_block_name, sub_block in block.sub_blocks.items():
nested_blocks = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
nested_blocks = {f"{block_name}.{k}": v for k, v in nested_blocks.items()}
result_blocks.update(nested_blocks)
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_inputs)
blocks_to_update = {f"{block_name}.{k}": v for k, v in blocks_to_update.items()}
result_blocks.update(blocks_to_update)
else:
# Leaf block: single ModularPipelineBlocks or LoopSequentialPipelineBlocks
result_blocks[block_name] = block
# Add outputs to active_inputs so subsequent blocks can use them as triggers
if hasattr(block, "intermediate_outputs"):
for out in block.intermediate_outputs:
if hasattr(block, "outputs"):
for out in block.outputs:
active_inputs[out.name] = True
return result_blocks
all_blocks = OrderedDict()
for block_name, block in self.sub_blocks.items():
nested_blocks = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(nested_blocks)
blocks_to_update = fn_recursive_traverse(block, block_name, active_inputs)
all_blocks.update(blocks_to_update)
return all_blocks
return SequentialPipelineBlocks.from_blocks_dict(all_blocks)
def get_execution_blocks(self, **kwargs):
"""
Get the blocks that would execute given the specified inputs.
Args:
**kwargs: Input names and values. Only trigger inputs affect block selection.
Pass any inputs that would be non-None at runtime.
Returns:
SequentialPipelineBlocks containing only the blocks that would execute
Example:
# Get blocks for inpainting workflow
blocks = pipeline.get_execution_blocks(prompt="a cat", mask=mask, image=image)
# Get blocks for text2image workflow
blocks = pipeline.get_execution_blocks(prompt="a cat")
"""
# Filter out None values
active_inputs = {k: v for k, v in kwargs.items() if v is not None}
blocks_triggered = self._traverse_trigger_blocks(active_inputs)
return SequentialPipelineBlocks.from_blocks_dict(blocks_triggered)
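# A hypothetical illustration (assumes `Flux2KleinAutoBlocks` from the Flux2-Klein file above is
# importable; only the presence of trigger inputs matters, not their actual values):
klein_blocks = Flux2KleinAutoBlocks()
print(klein_blocks.get_execution_blocks(prompt="a cat"))  # text-to-image path, VAE encoder skipped
print(klein_blocks.get_execution_blocks(prompt="a cat", image="non-None placeholder"))  # image-conditioned path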
def __repr__(self):
class_name = self.__class__.__name__
@@ -1188,23 +1132,18 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
f"{class_name}(\n Class: {base_class}\n" if base_class and base_class != "object" else f"{class_name}(\n"
)
if self._workflow_map is None and self._get_trigger_inputs():
if self.trigger_inputs:
header += "\n"
header += " " + "=" * 100 + "\n"
header += " This pipeline contains blocks that are selected at runtime based on inputs.\n"
header += f" Trigger Inputs: {[inp for inp in self._get_trigger_inputs() if inp is not None]}\n"
header += f" Trigger Inputs: {[inp for inp in self.trigger_inputs if inp is not None]}\n"
# Get first trigger input as example
example_input = next(t for t in self._get_trigger_inputs() if t is not None)
example_input = next(t for t in self.trigger_inputs if t is not None)
header += f" Use `get_execution_blocks()` to see selected blocks (e.g. `get_execution_blocks({example_input}=...)`).\n"
header += " " + "=" * 100 + "\n\n"
description = self.description
if self._workflow_map is not None:
workflow_str = format_workflow(self._workflow_map)
description = f"{self.description}\n\n{workflow_str}"
# Format description with proper indentation
desc_lines = description.split("\n")
desc_lines = self.description.split("\n")
desc = []
# First line with "Description:" label
desc.append(f" Description: {desc_lines[0]}")
@@ -1252,28 +1191,15 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
@property
def doc(self):
description = self.description
if self._workflow_map is not None:
workflow_str = format_workflow(self._workflow_map)
description = f"{self.description}\n\n{workflow_str}"
return make_doc_string(
self.inputs,
self.outputs,
description=description,
self.description,
class_name=self.__class__.__name__,
expected_components=self.expected_components,
expected_configs=self.expected_configs,
)
@property
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""
@@ -1401,7 +1327,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
@property
def intermediate_outputs(self) -> list[str]:
named_outputs = [(name, block.intermediate_outputs) for name, block in self.sub_blocks.items()]
combined_outputs = combine_outputs(*named_outputs)
combined_outputs = self.combine_outputs(*named_outputs)
for output in self.loop_intermediate_outputs:
if output.name not in {output.name for output in combined_outputs}:
combined_outputs.append(output)
@@ -1412,15 +1338,6 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def outputs(self) -> list[str]:
return next(reversed(self.sub_blocks.values())).intermediate_outputs
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
def __init__(self):
sub_blocks = InsertableDict()
for block_name, block in zip(self.block_names, self.block_classes):
@@ -1676,14 +1593,7 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
blocks_class_name = self.default_blocks_name
if blocks_class_name is not None:
diffusers_module = importlib.import_module("diffusers")
blocks_class = getattr(diffusers_module, blocks_class_name, None)
# If the blocks_class is not found or is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict) with empty block_classes
# fall back to default_blocks_name
if blocks_class is None or not blocks_class.block_classes:
blocks_class_name = self.default_blocks_name
blocks_class = getattr(diffusers_module, blocks_class_name)
if blocks_class is not None:
blocks_class = getattr(diffusers_module, blocks_class_name)
blocks = blocks_class()
else:
logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}")
@@ -1743,8 +1653,6 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)
self._pretrained_model_name_or_path = pretrained_model_name_or_path
@property
def default_call_parameters(self) -> dict[str, Any]:
"""
@@ -1871,136 +1779,44 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
return pipeline
def save_pretrained(
self,
save_directory: str | os.PathLike,
safe_serialization: bool = True,
variant: str | None = None,
max_shard_size: int | str | None = None,
push_to_hub: bool = False,
**kwargs,
):
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
"""
Save the pipeline and all its components to a directory, so that it can be re-loaded using the
[`~ModularPipeline.from_pretrained`] class method.
Save the pipeline to a directory. It does not save components; you need to save them separately.
Args:
save_directory (`str` or `os.PathLike`):
Directory to save the pipeline to. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `None`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
If expressed as an integer, the unit is bytes.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the pipeline to the Hugging Face model hub after saving it.
**kwargs: Additional keyword arguments:
- `overwrite_modular_index` (`bool`, *optional*, defaults to `False`):
When saving a Modular Pipeline, its components in `modular_model_index.json` may reference repos
different from the destination repo. Setting this to `True` updates all component references in
`modular_model_index.json` so they point to the repo specified by `repo_id`.
- `repo_id` (`str`, *optional*):
The repository ID to push the pipeline to. Defaults to the last component of `save_directory`.
- `commit_message` (`str`, *optional*):
Commit message for the push to hub operation.
- `private` (`bool`, *optional*):
Whether the repository should be private.
- `create_pr` (`bool`, *optional*, defaults to `False`):
Whether to create a pull request instead of pushing directly.
- `token` (`str`, *optional*):
The Hugging Face token to use for authentication.
Path to the directory where the pipeline will be saved.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the pipeline to the Hugging Face Hub.
**kwargs: Additional arguments passed to the `save_config()` method.
"""
overwrite_modular_index = kwargs.pop("overwrite_modular_index", False)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
update_model_card = kwargs.pop("update_model_card", False)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
for component_name, component_spec in self._component_specs.items():
if component_spec.default_creation_method != "from_pretrained":
continue
component = getattr(self, component_name, None)
if component is None:
continue
model_cls = component.__class__
if is_compiled_module(component):
component = _unwrap_model(component)
model_cls = component.__class__
save_method_name = None
for library_name, library_classes in LOADABLE_CLASSES.items():
if library_name in sys.modules:
library = importlib.import_module(library_name)
else:
logger.info(
f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}"
)
continue
for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class, None)
if class_candidate is not None and issubclass(model_cls, class_candidate):
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break
if save_method_name is None:
logger.warning(f"self.{component_name}={component} of type {type(component)} cannot be saved.")
continue
save_method = getattr(component, save_method_name)
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
save_kwargs["safe_serialization"] = safe_serialization
if save_method_accept_variant:
save_kwargs["variant"] = variant
if save_method_accept_max_shard_size and max_shard_size is not None:
save_kwargs["max_shard_size"] = max_shard_size
component_save_path = os.path.join(save_directory, component_name)
save_method(component_save_path, **save_kwargs)
if component_name not in self.config:
continue
has_no_load_id = not hasattr(component, "_diffusers_load_id") or component._diffusers_load_id == "null"
if overwrite_modular_index or has_no_load_id:
library, class_name, component_spec_dict = self.config[component_name]
component_spec_dict["pretrained_model_name_or_path"] = repo_id if push_to_hub else save_directory
component_spec_dict["subfolder"] = component_name
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
self.save_config(save_directory=save_directory)
if push_to_hub:
# Generate modular pipeline card content
card_content = generate_modular_model_card_content(self.blocks)
# Create a new empty model card and eventually tag it
model_card = load_or_create_model_card(
repo_id,
token=token,
is_pipeline=True,
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
is_modular=True,
update_model_card=update_model_card,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])
model_card.save(os.path.join(save_directory, "README.md"))
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
self.save_config(save_directory=save_directory)
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
@@ -2251,30 +2067,58 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
- the `config` dict, which will be saved as `modular_model_index.json` during `save_pretrained`
Args:
**kwargs: Component objects or configuration values to update:
- Component objects: Models loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()`
are automatically tagged with loading information. ConfigMixin objects without weights (e.g.,
schedulers, guiders) can be passed directly.
- Configuration values: Simple values to update configuration settings
(e.g., `requires_safety_checker=False`)
**kwargs: Component objects, ComponentSpec objects, or configuration values to update:
- Component objects: Only components whose specs can be extracted with the
`ComponentSpec.from_component()` method are supported, i.e. components created with `ComponentSpec.load()`
or ConfigMixin subclasses that aren't nn.Modules (e.g., `unet=new_unet, text_encoder=new_encoder`)
- ComponentSpec objects: Only specs with `default_creation_method == "from_config"` are supported; their
`create()` method is called to build a new component (e.g., `guider=ComponentSpec(name="guider",
type_hint=ClassifierFreeGuidance, config={...}, default_creation_method="from_config")`)
- Configuration values: Simple values to update configuration settings (e.g.,
`requires_safety_checker=False`)
Raises:
ValueError: If a component object is not supported by the `ComponentSpec.from_component()` method:
- nn.Module components without a valid `_diffusers_load_id` attribute
- Non-ConfigMixin components without a valid `_diffusers_load_id` attribute
Examples:
```python
# Update pre-trained model
# Update multiple components at once
pipeline.update_components(unet=new_unet_model, text_encoder=new_text_encoder)
# Update configuration values
pipeline.update_components(requires_safety_checker=False)
# Update both components and configs together
pipeline.update_components(unet=new_unet_model, requires_safety_checker=False)
# Update with ComponentSpec objects (from_config only)
pipeline.update_components(
guider=ComponentSpec(
name="guider",
type_hint=ClassifierFreeGuidance,
config={"guidance_scale": 5.0},
default_creation_method="from_config",
)
)
```
Notes:
- Components loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()` will have
loading specs preserved for serialization. Custom or locally loaded components without Hub references will
have their `modular_model_index.json` entries updated automatically during `save_pretrained()`.
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly.
- Components with trained weights must be created using `ComponentSpec.load()`. If the component has not been
shared on the Hugging Face Hub and you don't have loading specs, you can upload it using `push_to_hub()`.
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly
- ComponentSpec objects with default_creation_method="from_pretrained" are not supported in
update_components()
"""
passed_components = {k: kwargs.pop(k) for k in self._component_specs if k in kwargs}
# extract component_specs_updates & config_specs_updates from `specs`
passed_component_specs = {
k: kwargs.pop(k) for k in self._component_specs if k in kwargs and isinstance(kwargs[k], ComponentSpec)
}
passed_components = {
k: kwargs.pop(k) for k in self._component_specs if k in kwargs and not isinstance(kwargs[k], ComponentSpec)
}
passed_config_values = {k: kwargs.pop(k) for k in self._config_specs if k in kwargs}
for name, component in passed_components.items():
@@ -2292,10 +2136,13 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
new_component_spec = current_component_spec
if hasattr(self, name) and getattr(self, name) is not None:
logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)")
elif (
current_component_spec.default_creation_method == "from_pretrained"
and getattr(component, "_diffusers_load_id", None) is None
elif current_component_spec.default_creation_method == "from_pretrained" and not (
hasattr(component, "_diffusers_load_id") and component._diffusers_load_id is not None
):
logger.warning(
f"ModularPipeline.update_components: {name} has no valid _diffusers_load_id. "
f"This will result in empty loading spec, use ComponentSpec.load() for proper specs"
)
new_component_spec = ComponentSpec(name=name, type_hint=type(component))
else:
new_component_spec = ComponentSpec.from_component(name, component)
@@ -2310,14 +2157,33 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
if len(kwargs) > 0:
logger.warning(f"Unexpected keyword arguments, will be ignored: {kwargs.keys()}")
self.register_components(**passed_components)
created_components = {}
for name, component_spec in passed_component_specs.items():
if component_spec.default_creation_method == "from_pretrained":
raise ValueError(
"ComponentSpec object with default_creation_method == 'from_pretrained' is not supported in update_components() method"
)
created_components[name] = component_spec.create()
current_component_spec = self._component_specs[name]
# warn if type changed
if current_component_spec.type_hint is not None and not isinstance(
created_components[name], current_component_spec.type_hint
):
logger.info(
f"ModularPipeline.update_components: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}"
)
# update _component_specs based on the user passed component_spec
self._component_specs[name] = component_spec
self.register_components(**passed_components, **created_components)
config_to_register = {}
for name, new_value in passed_config_values.items():
# e.g. requires_aesthetics_score = False
self._config_specs[name].default = new_value
config_to_register[name] = new_value
self.register_to_config(**config_to_register)
# YiYi TODO: support map for additional from_pretrained kwargs
def load_components(self, names: list[str] | str | None = None, **kwargs):
"""
Load selected components from specs.
@@ -2368,49 +2234,17 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
elif "default" in value:
# check if the default is specified
component_load_kwargs[key] = value["default"]
# Only pass trust_remote_code to components from the same repo as the pipeline.
# When a user passes trust_remote_code=True, they intend to trust code from the
# pipeline's repo, not from external repos referenced in modular_model_index.json.
trust_remote_code_stripped = False
if (
"trust_remote_code" in component_load_kwargs
and self._pretrained_model_name_or_path is not None
and spec.pretrained_model_name_or_path != self._pretrained_model_name_or_path
):
component_load_kwargs.pop("trust_remote_code")
trust_remote_code_stripped = True
if not spec.pretrained_model_name_or_path:
logger.info(f"Skipping component `{name}`: no pretrained model path specified.")
continue
try:
components_to_register[name] = spec.load(**component_load_kwargs)
except Exception:
tb = traceback.format_exc()
if trust_remote_code_stripped and "trust_remote_code" in tb:
warning_msg = (
f"Failed to load component `{name}` from external repository "
f"`{spec.pretrained_model_name_or_path}`.\n\n"
f"`trust_remote_code=True` was not forwarded to `{name}` because it comes from "
f"a different repository than the pipeline (`{self._pretrained_model_name_or_path}`). "
f"For safety, `trust_remote_code` is only forwarded to components from the same "
f"repository as the pipeline.\n\n"
f"You need to load this component manually with `trust_remote_code=True` and pass it "
f"to the pipeline via `pipe.update_components()`. For example, if it is a custom model:\n\n"
f' {name} = AutoModel.from_pretrained("{spec.pretrained_model_name_or_path}", trust_remote_code=True)\n'
f" pipe.update_components({name}={name})\n"
)
else:
warning_msg = (
f"Failed to create component {name}:\n"
f"- Component spec: {spec}\n"
f"- load() called with kwargs: {component_load_kwargs}\n"
"If this component is not required for your workflow you can safely ignore this message.\n\n"
"Traceback:\n"
f"{tb}"
)
logger.warning(warning_msg)
logger.warning(
f"\nFailed to create component {name}:\n"
f"- Component spec: {spec}\n"
f"- load() called with kwargs: {component_load_kwargs}\n"
"If this component is not required for your workflow you can safely ignore this message.\n\n"
"Traceback:\n"
f"{traceback.format_exc()}"
)
# Register all components at once
self.register_components(**components_to_register)
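A short usage sketch for `load_components()`. The pipeline variable, dtype, and component names below are
illustrative; `names` may also be a single string, or omitted to load every component with a pretrained spec:
```python
import torch

# Assumes `pipe` is a ModularPipeline whose component specs point at a Hub repository.
pipe.load_components(names=["unet", "vae"], torch_dtype=torch.float16)
pipe.load_components("text_encoder")  # a single name also works
```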

View File

@@ -14,7 +14,6 @@
import inspect
import re
import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from types import UnionType
@@ -22,12 +21,10 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
import PIL.Image
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available():
@@ -52,7 +49,11 @@ This modular pipeline is composed of the following blocks:
{components_description} {configs_section}
{io_specification_section}
## Input/Output Specification
### Inputs {inputs_description}
### Outputs {outputs_description}
"""
@@ -309,12 +310,6 @@ class ComponentSpec:
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
)
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)
if self.type_hint is None:
try:
from diffusers import AutoModel
@@ -332,12 +327,6 @@ class ComponentSpec:
else getattr(self.type_hint, "from_pretrained")
)
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)
try:
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
@@ -514,10 +503,6 @@ OUTPUT_PARAM_TEMPLATES = {
"type_hint": list[PIL.Image.Image],
"description": "Generated images.",
},
"videos": {
"type_hint": list[PIL.Image.Image],
"description": "The generated videos.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Denoised latents.",
@@ -809,46 +794,6 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_params_markdown(params, header="Inputs"):
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
Suitable for model cards rendered on Hugging Face Hub.
Args:
params: list of InputParam or OutputParam objects to format
header: Header text (e.g. "Inputs" or "Outputs")
Returns:
A formatted markdown string, or empty string if params is empty.
"""
if not params:
return ""
def get_type_str(type_hint):
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
return " | ".join(type_strs)
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
lines = [f"**{header}:**\n"] if header else []
for param in params:
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"- `{name}` (`{type_str}`"
if hasattr(param, "required") and not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to `{param.default}`"
param_str += ")"
desc = param.description if param.description else "No description provided"
param_str += f": {desc}"
lines.append(param_str)
return "\n".join(lines)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
@@ -942,30 +887,6 @@ def format_configs(configs, indent_level=4, max_line_length=115, add_empty_lines
return "\n".join(formatted_configs)
def format_workflow(workflow_map):
"""Format a workflow map into a readable string representation.
Args:
workflow_map: Dictionary mapping workflow names to trigger inputs
Returns:
A formatted string representing all workflows
"""
if workflow_map is None:
return ""
lines = ["Supported workflows:"]
for workflow_name, trigger_inputs in workflow_map.items():
required_inputs = [k for k, v in trigger_inputs.items() if v]
if required_inputs:
inputs_str = ", ".join(f"`{t}`" for t in required_inputs)
lines.append(f" - `{workflow_name}`: requires {inputs_str}")
else:
lines.append(f" - `{workflow_name}`: default (no additional inputs required)")
return "\n".join(lines)
def make_doc_string(
inputs,
outputs,
@@ -1022,155 +943,6 @@ def make_doc_string(
return output
def _validate_requirements(reqs):
if reqs is None:
normalized_reqs = {}
else:
if not isinstance(reqs, dict):
raise ValueError(
"Requirements must be provided as a dictionary mapping package names to version specifiers."
)
normalized_reqs = _normalize_requirements(reqs)
if not normalized_reqs:
return {}
final: dict[str, str] = {}
for req, specified_ver in normalized_reqs.items():
req_available, req_actual_ver = _is_package_available(req)
if not req_available:
logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
if specified_ver:
try:
specifier = SpecifierSet(specified_ver)
except InvalidSpecifier as err:
raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
if req_actual_ver == "N/A":
logger.warning(
f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
)
elif not specifier.contains(req_actual_ver, prereleases=True):
logger.warning(
f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
)
final[req] = specified_ver
return final
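A hedged usage sketch (package names and versions are illustrative; both packages are assumed to be installed so
that only the spec normalization is visible):
```python
reqs = _validate_requirements({"torch": ">=2.0", "transformers": "4.57.1"})
print(reqs)
# {'torch': '>=2.0', 'transformers': '==4.57.1'}  (bare versions are normalized to '==')
```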
def _normalize_requirements(reqs):
if not reqs:
return {}
normalized: "OrderedDict[str, str]" = OrderedDict()
def _accumulate(mapping: dict[str, Any]):
for pkg, spec in mapping.items():
if isinstance(spec, dict):
# This is recursive because blocks are composable. This way, we can merge requirements
# from multiple blocks.
_accumulate(spec)
continue
pkg_name = str(pkg).strip()
if not pkg_name:
raise ValueError("Requirement package name cannot be empty.")
spec_str = "" if spec is None else str(spec).strip()
if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
spec_str = f"=={spec_str}"
existing_spec = normalized.get(pkg_name)
if existing_spec is not None:
if not existing_spec and spec_str:
normalized[pkg_name] = spec_str
elif existing_spec and spec_str and existing_spec != spec_str:
try:
combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
except InvalidSpecifier:
logger.warning(
f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
)
else:
normalized[pkg_name] = str(combined_spec)
continue
normalized[pkg_name] = spec_str
_accumulate(reqs)
return normalized
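An illustrative call showing the recursive merge across two blocks' requirements (names and specifiers are made
up):
```python
merged = _normalize_requirements(
    {
        "block_a": {"torch": ">=2.0", "peft": None},
        "block_b": {"torch": ">=2.1"},
    }
)
print(dict(merged))
# {'torch': '>=2.0,>=2.1', 'peft': ''}  (compatible specifiers are combined, None becomes '')
```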
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
default value is None and new default value is not None. Warns if multiple non-None default values exist for the
same input.
Args:
named_input_lists: List of tuples containing (block_name, input_param_list) pairs
Returns:
List[InputParam]: Combined list of unique InputParam objects
"""
combined_dict = {} # name -> InputParam
value_sources = {} # name -> block_name
for block_name, inputs in named_input_lists:
for input_param in inputs:
if input_param.name is None and input_param.kwargs_type is not None:
input_name = "*_" + input_param.kwargs_type
else:
input_name = input_param.name
if input_name in combined_dict:
current_param = combined_dict[input_name]
if (
current_param.default is not None
and input_param.default is not None
and current_param.default != input_param.default
):
warnings.warn(
f"Multiple different default values found for input '{input_name}': "
f"{current_param.default} (from block '{value_sources[input_name]}') and "
f"{input_param.default} (from block '{block_name}'). Using {current_param.default}."
)
if current_param.default is None and input_param.default is not None:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
else:
combined_dict[input_name] = input_param
value_sources[input_name] = block_name
return list(combined_dict.values())
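A small sketch, assuming `InputParam` defaults `required=False` and `default=None` as elsewhere in this module
(values illustrative):
```python
block_a_inputs = [InputParam(name="prompt"), InputParam(name="num_inference_steps")]
block_b_inputs = [InputParam(name="num_inference_steps", default=50)]

combined = combine_inputs(("block_a", block_a_inputs), ("block_b", block_b_inputs))
print([(p.name, p.default) for p in combined])
# [('prompt', None), ('num_inference_steps', 50)]  (the non-None default wins)
```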
def combine_outputs(*named_output_lists: list[tuple[str, list[OutputParam]]]) -> list[OutputParam]:
"""
Combines multiple lists of OutputParam objects from different blocks. For duplicate outputs, keeps the first
occurrence of each output name.
Args:
named_output_lists: List of tuples containing (block_name, output_param_list) pairs
Returns:
List[OutputParam]: Combined list of unique OutputParam objects
"""
combined_dict = {} # name -> OutputParam
for block_name, outputs in named_output_lists:
for output_param in outputs:
if (output_param.name not in combined_dict) or (
combined_dict[output_param.name].kwargs_type is None and output_param.kwargs_type is not None
):
combined_dict[output_param.name] = output_param
return list(combined_dict.values())
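And the matching sketch for outputs, where the first occurrence of a duplicate name is kept (values
illustrative):
```python
decode_outputs = [OutputParam(name="images"), OutputParam(name="latents")]
denoise_outputs = [OutputParam(name="images")]

combined = combine_outputs(("decode", decode_outputs), ("denoise", denoise_outputs))
print([o.name for o in combined])
# ['images', 'latents']
```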
def generate_modular_model_card_content(blocks) -> dict[str, Any]:
"""
Generate model card content for a modular pipeline.
@@ -1188,7 +960,8 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- io_specification_section: Input/Output specification (per-workflow or unified)
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
@@ -1207,6 +980,15 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if block_desc:
blocks_desc_parts.append(f" - {block_desc}")
# add sub-blocks if any
if hasattr(block, "sub_blocks") and block.sub_blocks:
for sub_name, sub_block in block.sub_blocks.items():
sub_class = sub_block.__class__.__name__
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
if sub_desc:
blocks_desc_parts.append(f" - {sub_desc}")
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
components = getattr(blocks, "expected_components", [])
@@ -1232,76 +1014,63 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
# Branch on whether workflows are defined
has_workflows = getattr(blocks, "_workflow_map", None) is not None
inputs = blocks.inputs
outputs = blocks.outputs
if has_workflows:
workflow_map = blocks._workflow_map
parts = []
# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
# use that as the shared output for all workflows
blocks_outputs = blocks.outputs
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
shared_outputs = (
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
)
if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
parts.append("## Workflow Input Specification\n")
if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
# Per-workflow details: show trigger inputs with full param descriptions
for wf_name, trigger_inputs in workflow_map.items():
trigger_input_names = set(trigger_inputs.keys())
try:
workflow_blocks = blocks.get_workflow(wf_name)
except Exception:
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append("*Could not resolve workflow blocks.*\n")
parts.append("</details>\n")
continue
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
wf_inputs = workflow_blocks.inputs
# Show only trigger inputs with full parameter descriptions
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
inputs_str = format_params_markdown(trigger_params, header=None)
parts.append(inputs_str if inputs_str else "No additional inputs required.")
parts.append("")
parts.append("</details>\n")
# Common Inputs & Outputs section (like non-workflow pipelines)
all_inputs = blocks.inputs
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
inputs_str = format_params_markdown(all_inputs, "Inputs")
outputs_str = format_params_markdown(all_outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
io_specification_section = "\n".join(parts)
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
trigger_inputs_section = ""
else:
# Unified I/O section (original behavior)
inputs = blocks.inputs
outputs = blocks.outputs
inputs_str = format_params_markdown(inputs, "Inputs")
outputs_str = format_params_markdown(outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution
This pipeline contains blocks that are selected at runtime based on inputs:
@@ -1314,18 +1083,7 @@ This pipeline contains blocks that are selected at runtime based on inputs:
if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)
if has_workflows:
# Derive tags from workflow names
workflow_names = set(blocks._workflow_map.keys())
if any("inpainting" in wf for wf in workflow_names):
tags.append("inpainting")
if any("image2image" in wf for wf in workflow_names):
tags.append("image-to-image")
if any("controlnet" in wf for wf in workflow_names):
tags.append("controlnet")
if any("text2image" in wf for wf in workflow_names):
tags.append("text-to-image")
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
@@ -1353,7 +1111,8 @@ This pipeline uses a {block_count}-block architecture that can be customized and
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"io_specification_section": io_specification_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

View File

@@ -21,15 +21,27 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_blocks_qwenimage"] = ["QwenImageAutoBlocks"]
_import_structure["modular_blocks_qwenimage_edit"] = ["QwenImageEditAutoBlocks"]
_import_structure["modular_blocks_qwenimage_edit_plus"] = ["QwenImageEditPlusAutoBlocks"]
_import_structure["modular_blocks_qwenimage_layered"] = ["QwenImageLayeredAutoBlocks"]
_import_structure["modular_blocks_qwenimage"] = [
"AUTO_BLOCKS",
"QwenImageAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_edit"] = [
"EDIT_AUTO_BLOCKS",
"QwenImageEditAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_edit_plus"] = [
"EDIT_PLUS_AUTO_BLOCKS",
"QwenImageEditPlusAutoBlocks",
]
_import_structure["modular_blocks_qwenimage_layered"] = [
"LAYERED_AUTO_BLOCKS",
"QwenImageLayeredAutoBlocks",
]
_import_structure["modular_pipeline"] = [
"QwenImageEditModularPipeline",
"QwenImageEditPlusModularPipeline",
"QwenImageLayeredModularPipeline",
"QwenImageModularPipeline",
"QwenImageLayeredModularPipeline",
]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -39,10 +51,22 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_blocks_qwenimage import QwenImageAutoBlocks
from .modular_blocks_qwenimage_edit import QwenImageEditAutoBlocks
from .modular_blocks_qwenimage_edit_plus import QwenImageEditPlusAutoBlocks
from .modular_blocks_qwenimage_layered import QwenImageLayeredAutoBlocks
from .modular_blocks_qwenimage import (
AUTO_BLOCKS,
QwenImageAutoBlocks,
)
from .modular_blocks_qwenimage_edit import (
EDIT_AUTO_BLOCKS,
QwenImageEditAutoBlocks,
)
from .modular_blocks_qwenimage_edit_plus import (
EDIT_PLUS_AUTO_BLOCKS,
QwenImageEditPlusAutoBlocks,
)
from .modular_blocks_qwenimage_layered import (
LAYERED_AUTO_BLOCKS,
QwenImageLayeredAutoBlocks,
)
from .modular_pipeline import (
QwenImageEditModularPipeline,
QwenImageEditPlusModularPipeline,

View File

@@ -558,7 +558,7 @@ class QwenImageSetTimestepsStep(ModularPipelineBlocks):
Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
latents (`Tensor`):
The initial random noised latents for the denoising process. Can be generated in prepare latents step.
@@ -644,7 +644,7 @@ class QwenImageLayeredSetTimestepsStep(ModularPipelineBlocks):
Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
@@ -725,7 +725,7 @@ class QwenImageSetTimestepsWithStrengthStep(ModularPipelineBlocks):
Inputs:
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
latents (`Tensor`):
The latents to use for the denoising process. Can be generated in prepare latents step.
@@ -842,7 +842,7 @@ class QwenImageRoPEInputsStep(ModularPipelineBlocks):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`list`):
img_shapes (`List`):
The shapes of the images latents, used for RoPE calculation
"""
@@ -917,7 +917,7 @@ class QwenImageEditRoPEInputsStep(ModularPipelineBlocks):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`list`):
img_shapes (`List`):
The shapes of the images latents, used for RoPE calculation
"""
@@ -995,9 +995,9 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
batch_size (`int`, *optional*, defaults to 1):
Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can
be generated in input step.
image_height (`list`):
image_height (`List`):
The heights of the reference images. Can be generated in input step.
image_width (`list`):
image_width (`List`):
The widths of the reference images. Can be generated in input step.
height (`int`):
The height in pixels of the generated image.
@@ -1009,11 +1009,11 @@ class QwenImageEditPlusRoPEInputsStep(ModularPipelineBlocks):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`list`):
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`list`):
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`list`):
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
"""
@@ -1123,11 +1123,11 @@ class QwenImageLayeredRoPEInputsStep(ModularPipelineBlocks):
mask for the negative text embeddings. Can be generated from text_encoder step.
Outputs:
img_shapes (`list`):
img_shapes (`List`):
The shapes of the image latents, used for RoPE calculation
txt_seq_lens (`list`):
txt_seq_lens (`List`):
The sequence lengths of the prompt embeds, used for RoPE calculation
negative_txt_seq_lens (`list`):
negative_txt_seq_lens (`List`):
The sequence lengths of the negative prompt embeds, used for RoPE calculation
additional_t_cond (`Tensor`):
The additional t cond, used for RoPE calculation
@@ -1238,7 +1238,7 @@ class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
The timesteps to use for the denoising process. Can be generated in set_timesteps step.
Outputs:
controlnet_keep (`list`):
controlnet_keep (`List`):
The controlnet keep values
"""

View File

@@ -191,7 +191,7 @@ class QwenImageDecoderStep(ModularPipelineBlocks):
step.
Outputs:
images (`list`):
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
@@ -268,7 +268,7 @@ class QwenImageLayeredDecoderStep(ModularPipelineBlocks):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`list`):
images (`List`):
Generated images.
"""
@@ -366,7 +366,7 @@ class QwenImageProcessImagesOutputStep(ModularPipelineBlocks):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`list`):
images (`List`):
Generated images.
"""
@@ -436,12 +436,12 @@ class QwenImageInpaintProcessImagesOutputStep(ModularPipelineBlocks):
the generated image tensor from decoders step
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`dict`, *optional*):
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`list`):
images (`List`):
Generated images.
"""

View File

@@ -518,11 +518,11 @@ class QwenImageDenoiseStep(QwenImageDenoiseLoopWrapper):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`list`):
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
Outputs:
@@ -576,11 +576,11 @@ class QwenImageInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
The number of denoising steps.
latents (`Tensor`):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`list`):
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
@@ -645,13 +645,13 @@ class QwenImageControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`list`):
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`list`):
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
Outputs:
@@ -711,13 +711,13 @@ class QwenImageInpaintControlNetDenoiseStep(QwenImageDenoiseLoopWrapper):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning. (updated in prepare_controlnet_inputs step.)
controlnet_keep (`list`):
controlnet_keep (`List`):
The controlnet keep values. Can be generated in prepare_controlnet_inputs step.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`list`):
img_shapes (`List`):
The shape of the image latents for RoPE calculation. can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
@@ -787,11 +787,11 @@ class QwenImageEditDenoiseStep(QwenImageDenoiseLoopWrapper):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`list`):
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
Outputs:
@@ -846,11 +846,11 @@ class QwenImageEditInpaintDenoiseStep(QwenImageDenoiseLoopWrapper):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`list`):
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
mask (`Tensor`):
The mask to use for the inpainting process. Can be generated in inpaint prepare latents step.
@@ -910,11 +910,11 @@ class QwenImageLayeredDenoiseStep(QwenImageDenoiseLoopWrapper):
The initial latents to use for the denoising process. Can be generated in prepare_latent step.
image_latents (`Tensor`):
image latents used to guide the image generation. Can be generated from vae_encoder step.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
img_shapes (`list`):
img_shapes (`List`):
The shape of the image latents for RoPE calculation. Can be generated in prepare_additional_inputs step.
Outputs:

View File

@@ -285,11 +285,11 @@ class QwenImageEditResizeStep(ModularPipelineBlocks):
image_resize_processor (`VaeImageProcessor`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
Outputs:
resized_image (`list`):
resized_image (`List`):
The resized images
"""
@@ -359,13 +359,13 @@ class QwenImageLayeredResizeStep(ModularPipelineBlocks):
image_resize_processor (`VaeImageProcessor`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
resolution (`int`, *optional*, defaults to 640):
The target area to resize the image to, can be 1024 or 640
Outputs:
resized_image (`list`):
resized_image (`List`):
The resized images
"""
@@ -452,13 +452,13 @@ class QwenImageEditPlusResizeStep(ModularPipelineBlocks):
image_resize_processor (`VaeImageProcessor`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
Outputs:
resized_image (`list`):
resized_image (`List`):
Images resized to 1024x1024 target area for VAE encoding
resized_cond_image (`list`):
resized_cond_image (`List`):
Images resized to 384x384 target area for VL text encoding
"""
@@ -1058,7 +1058,7 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
Inputs:
mask_image (`Image`):
Mask image for inpainting.
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
@@ -1072,7 +1072,7 @@ class QwenImageInpaintProcessImagesInputStep(ModularPipelineBlocks):
The processed image
processed_mask_image (`Tensor`):
The processed mask image
mask_overlay_kwargs (`dict`):
mask_overlay_kwargs (`Dict`):
The kwargs for the postprocess step to apply the mask overlay
"""
@@ -1177,7 +1177,7 @@ class QwenImageEditInpaintProcessImagesInputStep(ModularPipelineBlocks):
The processed image
processed_mask_image (`Tensor`):
The processed mask image
mask_overlay_kwargs (`dict`):
mask_overlay_kwargs (`Dict`):
The kwargs for the postprocess step to apply the mask overlay
"""
@@ -1256,7 +1256,7 @@ class QwenImageProcessImagesInputStep(ModularPipelineBlocks):
image_processor (`VaeImageProcessor`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
@@ -1340,7 +1340,7 @@ class QwenImageEditProcessImagesInputStep(ModularPipelineBlocks):
image_processor (`VaeImageProcessor`)
Inputs:
resized_image (`list`):
resized_image (`List`):
The resized image. should be generated using a resize step
Outputs:
@@ -1412,7 +1412,7 @@ class QwenImageEditPlusProcessImagesInputStep(ModularPipelineBlocks):
image_processor (`VaeImageProcessor`)
Inputs:
resized_image (`list`):
resized_image (`List`):
The resized image. should be generated using a resize step
Outputs:

View File

@@ -496,9 +496,9 @@ class QwenImageEditPlusAdditionalInputsStep(ModularPipelineBlocks):
image latents used to guide the image generation. Can be generated from vae_encoder step.
Outputs:
image_height (`list`):
image_height (`List`):
The image heights calculated from the image latents dimension
image_width (`list`):
image_width (`List`):
The image widths calculated from the image latents dimension
height (`int`):
if not provided, updated to image height

View File

@@ -119,7 +119,7 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
Inputs:
mask_image (`Image`):
Mask image for inpainting.
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
@@ -135,7 +135,7 @@ class QwenImageInpaintVaeEncoderStep(SequentialPipelineBlocks):
The processed image
processed_mask_image (`Tensor`):
The processed mask image
mask_overlay_kwargs (`dict`):
mask_overlay_kwargs (`Dict`):
The kwargs for the postprocess step to apply the mask overlay
image_latents (`Tensor`):
The latent representation of the input image.
@@ -164,7 +164,7 @@ class QwenImageImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
image_processor (`VaeImageProcessor`) vae (`AutoencoderKLQwenImage`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
@@ -476,9 +476,9 @@ class QwenImageCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -553,11 +553,11 @@ class QwenImageInpaintCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -632,11 +632,11 @@ class QwenImageImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -712,7 +712,7 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
control_guidance_start (`float`, *optional*, defaults to 0.0):
When to start applying ControlNet.
@@ -720,7 +720,7 @@ class QwenImageControlNetCoreDenoiseStep(SequentialPipelineBlocks):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -802,7 +802,7 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
@@ -812,7 +812,7 @@ class QwenImageControlNetInpaintCoreDenoiseStep(SequentialPipelineBlocks):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -894,7 +894,7 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
@@ -904,7 +904,7 @@ class QwenImageControlNetImg2ImgCoreDenoiseStep(SequentialPipelineBlocks):
When to stop applying ControlNet.
controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
Scale for ControlNet conditioning.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -1032,7 +1032,7 @@ class QwenImageDecodeStep(SequentialPipelineBlocks):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`list`):
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
@@ -1061,12 +1061,12 @@ class QwenImageInpaintDecodeStep(SequentialPipelineBlocks):
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`dict`, *optional*):
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`list`):
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
@@ -1113,14 +1113,10 @@ AUTO_BLOCKS = InsertableDict(
class QwenImageAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.
Supported workflows:
- `text2image`: requires `prompt`
- `image2image`: requires `prompt`, `image`
- `inpainting`: requires `prompt`, `mask_image`, `image`
- `controlnet_text2image`: requires `prompt`, `control_image`
- `controlnet_image2image`: requires `prompt`, `image`, `control_image`
- `controlnet_inpainting`: requires `prompt`, `mask_image`, `image`, `control_image`
- for image-to-image generation, you need to provide `image`
- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.
- to run the controlnet workflow, you need to provide `control_image`
- for text-to-image generation, all you need to provide is `prompt`
Components:
text_encoder (`Qwen2_5_VLForConditionalGeneration`): The text encoder to use tokenizer (`Qwen2Tokenizer`):
@@ -1138,7 +1134,7 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
Maximum sequence length for prompt encoding.
mask_image (`Image`, *optional*):
Mask image for inpainting.
image (`Image | list`, *optional*):
image (`Union[Image, List]`, *optional*):
Reference image(s) for denoising. Can be a single image or list of images.
height (`int`, *optional*):
The height in pixels of the generated image.
@@ -1164,9 +1160,9 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
Pre-generated noisy latents for image generation.
num_inference_steps (`int`):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -1187,12 +1183,12 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
Scale for ControlNet conditioning.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`dict`, *optional*):
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`list`):
images (`List`):
Generated images.
"""
@@ -1201,23 +1197,15 @@ class QwenImageAutoBlocks(SequentialPipelineBlocks):
block_classes = AUTO_BLOCKS.values()
block_names = AUTO_BLOCKS.keys()
# Workflow map defines the trigger conditions for each workflow.
# How to define:
# - Only include required inputs and trigger inputs (inputs that determine which blocks run)
# - currently, only supports `True` means the workflow triggers when the input is not None
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"prompt": True, "image": True},
"inpainting": {"prompt": True, "mask_image": True, "image": True},
"controlnet_text2image": {"prompt": True, "control_image": True},
"controlnet_image2image": {"prompt": True, "image": True, "control_image": True},
"controlnet_inpainting": {"prompt": True, "mask_image": True, "image": True, "control_image": True},
}
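A hypothetical sketch of how the workflow map above is consumed, assuming `get_workflow()` resolves the
sub-blocks for a named workflow (as used by the model-card helper earlier in this diff) in versions that define
`_workflow_map`:
```python
blocks = QwenImageAutoBlocks()
inpaint_blocks = blocks.get_workflow("inpainting")  # triggered by prompt + mask_image + image
print(inpaint_blocks.inputs)  # InputParam list scoped to this workflow
```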
@property
def description(self):
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage."
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using QwenImage.\n"
+ "- for image-to-image generation, you need to provide `image`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop`.\n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
@property
def outputs(self):

View File

@@ -67,7 +67,7 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
@@ -75,7 +75,7 @@ class QwenImageEditVLEncoderStep(SequentialPipelineBlocks):
The prompt or prompts not to guide the image generation.
Outputs:
resized_image (`list`):
resized_image (`List`):
The resized images
prompt_embeds (`Tensor`):
The prompt embeddings.
@@ -115,13 +115,13 @@ class QwenImageEditVaeEncoderStep(SequentialPipelineBlocks):
(`AutoencoderKLQwenImage`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
resized_image (`list`):
resized_image (`List`):
The resized images
processed_image (`Tensor`):
The processed image
@@ -156,7 +156,7 @@ class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
(`AutoencoderKLQwenImage`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
mask_image (`Image`):
Mask image for inpainting.
@@ -166,13 +166,13 @@ class QwenImageEditInpaintVaeEncoderStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
Outputs:
resized_image (`list`):
resized_image (`List`):
The resized images
processed_image (`Tensor`):
The processed image
processed_mask_image (`Tensor`):
The processed mask image
mask_overlay_kwargs (`dict`):
mask_overlay_kwargs (`Dict`):
The kwargs for the postprocess step to apply the mask overlay
image_latents (`Tensor`):
The latent representation of the input image.
@@ -450,9 +450,9 @@ class QwenImageEditCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -526,11 +526,11 @@ class QwenImageEditInpaintCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -627,7 +627,7 @@ class QwenImageEditDecodeStep(SequentialPipelineBlocks):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`list`):
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
@@ -656,12 +656,12 @@ class QwenImageEditInpaintDecodeStep(SequentialPipelineBlocks):
step.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`dict`, *optional*):
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`list`):
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
@@ -718,11 +718,6 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
- for edit inpainting, you need to provide `mask_image` and `image`, optionally you can provide
`padding_mask_crop`
Supported workflows:
- `image_conditioned`: requires `prompt`, `image`
- `image_conditioned_inpainting`: requires `prompt`, `mask_image`, `image`
Components:
image_resize_processor (`VaeImageProcessor`) text_encoder (`Qwen2_5_VLForConditionalGeneration`) processor
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`) image_mask_processor (`InpaintProcessor`) vae
@@ -730,7 +725,7 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
(`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
@@ -756,32 +751,28 @@ class QwenImageEditAutoBlocks(SequentialPipelineBlocks):
Pre-generated noisy latents for image generation.
num_inference_steps (`int`):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
strength (`float`, *optional*, defaults to 0.9):
Strength for img2img/inpainting.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
output_type (`str`, *optional*, defaults to pil):
Output format: 'pil', 'np', 'pt'.
mask_overlay_kwargs (`dict`, *optional*):
mask_overlay_kwargs (`Dict`, *optional*):
The kwargs for the postprocess step to apply the mask overlay. generated in
InpaintProcessImagesInputStep.
Outputs:
images (`list`):
images (`List`):
Generated images.
"""
model_name = "qwenimage-edit"
block_classes = EDIT_AUTO_BLOCKS.values()
block_names = EDIT_AUTO_BLOCKS.keys()
_workflow_map = {
"image_conditioned": {"prompt": True, "image": True},
"image_conditioned_inpainting": {"prompt": True, "mask_image": True, "image": True},
}
@property
def description(self):

View File

@@ -58,7 +58,7 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
(`Qwen2VLProcessor`) guider (`ClassifierFreeGuidance`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
@@ -66,9 +66,9 @@ class QwenImageEditPlusVLEncoderStep(SequentialPipelineBlocks):
The prompt or prompts not to guide the image generation.
Outputs:
resized_image (`list`):
resized_image (`List`):
Images resized to 1024x1024 target area for VAE encoding
resized_cond_image (`list`):
resized_cond_image (`List`):
Images resized to 384x384 target area for VL text encoding
prompt_embeds (`Tensor`):
The prompt embeddings.
@@ -108,15 +108,15 @@ class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
(`AutoencoderKLQwenImage`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
generator (`Generator`, *optional*):
Torch generator for deterministic generation.
Outputs:
resized_image (`list`):
resized_image (`List`):
Images resized to 1024x1024 target area for VAE encoding
resized_cond_image (`list`):
resized_cond_image (`List`):
Images resized to 384x384 target area for VL text encoding
processed_image (`Tensor`):
The processed image
@@ -189,9 +189,9 @@ class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
The negative prompt embeddings. (batch-expanded)
negative_prompt_embeds_mask (`Tensor`):
The negative prompt embeddings mask. (batch-expanded)
image_height (`list`):
image_height (`List`):
The image heights calculated from the image latents dimension
image_width (`list`):
image_width (`List`):
The image widths calculated from the image latents dimension
height (`int`):
if not provided, updated to image height
@@ -253,9 +253,9 @@ class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -315,7 +315,7 @@ class QwenImageEditPlusDecodeStep(SequentialPipelineBlocks):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`list`):
images (`List`):
Generated images. (tensor output of the vae decoder.)
"""
@@ -357,7 +357,7 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
transformer (`QwenImageTransformer2DModel`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
prompt (`str`):
The prompt or prompts to guide image generation.
@@ -375,9 +375,9 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
Pre-generated noisy latents for image generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -385,7 +385,7 @@ class QwenImageEditPlusAutoBlocks(SequentialPipelineBlocks):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`list`):
images (`List`):
Generated images.
"""

View File

@@ -60,7 +60,7 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
(`Qwen2VLProcessor`) tokenizer (`Qwen2Tokenizer`): The tokenizer to use guider (`ClassifierFreeGuidance`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
resolution (`int`, *optional*, defaults to 640):
The target area to resize the image to, can be 1024 or 640
@@ -74,7 +74,7 @@ class QwenImageLayeredTextEncoderStep(SequentialPipelineBlocks):
Maximum sequence length for prompt encoding.
Outputs:
resized_image (`list`):
resized_image (`List`):
The resized images
prompt (`str`):
The prompt or prompts to guide image generation. If not provided, updated using image caption
@@ -117,7 +117,7 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
(`AutoencoderKLQwenImage`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
resolution (`int`, *optional*, defaults to 640):
The target area to resize the image to, can be 1024 or 640
@@ -125,7 +125,7 @@ class QwenImageLayeredVaeEncoderStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
Outputs:
resized_image (`list`):
resized_image (`List`):
The resized images
processed_image (`Tensor`):
The processed image
@@ -250,9 +250,9 @@ class QwenImageLayeredCoreDenoiseStep(SequentialPipelineBlocks):
Torch generator for deterministic generation.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -317,7 +317,7 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
scheduler (`FlowMatchEulerDiscreteScheduler`) transformer (`QwenImageTransformer2DModel`)
Inputs:
image (`Image | list`):
image (`Union[Image, List]`):
Reference image(s) for denoising. Can be a single image or list of images.
resolution (`int`, *optional*, defaults to 640):
The target area to resize the image to, can be 1024 or 640
@@ -339,9 +339,9 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
Number of layers to extract from the image
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps.
sigmas (`list`, *optional*):
sigmas (`List`, *optional*):
Custom sigmas for the denoising process.
attention_kwargs (`dict`, *optional*):
attention_kwargs (`Dict`, *optional*):
Additional kwargs for attention processors.
**denoiser_input_fields (`None`, *optional*):
conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
@@ -349,7 +349,7 @@ class QwenImageLayeredAutoBlocks(SequentialPipelineBlocks):
Output format: 'pil', 'np', 'pt'.
Outputs:
images (`list`):
images (`List`):
Generated images.
"""

View File

@@ -21,7 +21,21 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_blocks_stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks"]
_import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"CONTROLNET_BLOCKS",
"IMAGE2IMAGE_BLOCKS",
"INPAINT_BLOCKS",
"IP_ADAPTER_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"StableDiffusionXLAutoBlocks",
"StableDiffusionXLAutoControlnetStep",
"StableDiffusionXLAutoDecodeStep",
"StableDiffusionXLAutoIPAdapterStep",
"StableDiffusionXLAutoVaeEncoderStep",
]
_import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -31,7 +45,23 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_blocks_stable_diffusion_xl import StableDiffusionXLAutoBlocks
from .encoders import (
StableDiffusionXLTextEncoderStep,
)
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
CONTROLNET_BLOCKS,
IMAGE2IMAGE_BLOCKS,
INPAINT_BLOCKS,
IP_ADAPTER_BLOCKS,
TEXT2IMAGE_BLOCKS,
StableDiffusionXLAutoBlocks,
StableDiffusionXLAutoControlnetStep,
StableDiffusionXLAutoDecodeStep,
StableDiffusionXLAutoIPAdapterStep,
StableDiffusionXLAutoVaeEncoderStep,
)
from .modular_pipeline import StableDiffusionXLModularPipeline
else:
import sys
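For context on how the `_import_structure` entries added above are consumed: the unchanged tail of these `__init__.py` files hands the mapping to `_LazyModule`, which is what lets `TEXT2IMAGE_BLOCKS`, `StableDiffusionXLAutoBlocks`, and friends be imported without eagerly loading the torch-heavy modules. The sketch below reproduces that standard diffusers pattern from memory rather than quoting this exact file, so treat it as representative only.

# Typical lazy-import tail of a diffusers subpackage __init__.py (sketch).
# `_LazyModule`, `_import_structure`, and `_dummy_objects` are already defined
# earlier in the real file.
import sys

sys.modules[__name__] = _LazyModule(
    __name__,
    globals()["__file__"],
    _import_structure,      # the dict extended in the hunk above
    module_spec=__spec__,
)
# When torch/transformers are unavailable, dummy objects are re-attached so the
# public names still resolve and raise a helpful error on use.
for name, value in _dummy_objects.items():
    setattr(sys.modules[__name__], name, value)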

View File

@@ -14,7 +14,7 @@
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
StableDiffusionXLControlNetInputStep,
StableDiffusionXLControlNetUnionInputStep,
@@ -277,161 +277,7 @@ class StableDiffusionXLCoreDenoiseStep(SequentialPipelineBlocks):
# ip-adapter, controlnet, text2img, img2img, inpainting
# auto_docstring
class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion
XL.
Supported workflows:
- `text2image`: requires `prompt`
- `image2image`: requires `image`, `prompt`
- `inpainting`: requires `mask_image`, `image`, `prompt`
- `controlnet_text2image`: requires `control_image`, `prompt`
- `controlnet_image2image`: requires `control_image`, `image`, `prompt`
- `controlnet_inpainting`: requires `control_image`, `mask_image`, `image`, `prompt`
- `controlnet_union_text2image`: requires `control_image`, `control_mode`, `prompt`
- `controlnet_union_image2image`: requires `control_image`, `control_mode`, `image`, `prompt`
- `controlnet_union_inpainting`: requires `control_image`, `control_mode`, `mask_image`, `image`, `prompt`
- `ip_adapter_text2image`: requires `ip_adapter_image`, `prompt`
- `ip_adapter_image2image`: requires `ip_adapter_image`, `image`, `prompt`
- `ip_adapter_inpainting`: requires `ip_adapter_image`, `mask_image`, `image`, `prompt`
- `ip_adapter_controlnet_text2image`: requires `ip_adapter_image`, `control_image`, `prompt`
- `ip_adapter_controlnet_image2image`: requires `ip_adapter_image`, `control_image`, `image`, `prompt`
- `ip_adapter_controlnet_inpainting`: requires `ip_adapter_image`, `control_image`, `mask_image`, `image`,
`prompt`
- `ip_adapter_controlnet_union_text2image`: requires `ip_adapter_image`, `control_image`, `control_mode`,
`prompt`
- `ip_adapter_controlnet_union_image2image`: requires `ip_adapter_image`, `control_image`, `control_mode`,
`image`, `prompt`
- `ip_adapter_controlnet_union_inpainting`: requires `ip_adapter_image`, `control_image`, `control_mode`,
`mask_image`, `image`, `prompt`
Components:
text_encoder (`CLIPTextModel`) text_encoder_2 (`CLIPTextModelWithProjection`) tokenizer (`CLIPTokenizer`)
tokenizer_2 (`CLIPTokenizer`) guider (`ClassifierFreeGuidance`) image_encoder
(`CLIPVisionModelWithProjection`) feature_extractor (`CLIPImageProcessor`) unet (`UNet2DConditionModel`) vae
(`AutoencoderKL`) image_processor (`VaeImageProcessor`) mask_processor (`VaeImageProcessor`) scheduler
(`EulerDiscreteScheduler`) controlnet (`ControlNetUnionModel`) control_image_processor (`VaeImageProcessor`)
Configs:
force_zeros_for_empty_prompt (default: True) requires_aesthetics_score (default: False)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
prompt_2 (`None`, *optional*):
TODO: Add description.
negative_prompt (`None`, *optional*):
TODO: Add description.
negative_prompt_2 (`None`, *optional*):
TODO: Add description.
cross_attention_kwargs (`None`, *optional*):
TODO: Add description.
clip_skip (`None`, *optional*):
TODO: Add description.
ip_adapter_image (`Image | ndarray | Tensor | list | list | list`, *optional*):
The image(s) to be used as ip adapter
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
image (`None`, *optional*):
TODO: Add description.
mask_image (`None`, *optional*):
TODO: Add description.
padding_mask_crop (`None`, *optional*):
TODO: Add description.
dtype (`dtype`, *optional*):
The dtype of the model inputs
generator (`None`, *optional*):
TODO: Add description.
preprocess_kwargs (`dict | NoneType`, *optional*):
A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under
`self.image_processor` in [diffusers.image_processor.VaeImageProcessor]
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
ip_adapter_embeds (`list`, *optional*):
Pre-generated image embeddings for IP-Adapter. Can be generated from ip_adapter step.
negative_ip_adapter_embeds (`list`, *optional*):
Pre-generated negative image embeddings for IP-Adapter. Can be generated from ip_adapter step.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
denoising_end (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.3):
TODO: Add description.
denoising_start (`None`, *optional*):
TODO: Add description.
latents (`None`):
TODO: Add description.
image_latents (`Tensor`, *optional*):
The latents representing the reference image for image-to-image/inpainting generation. Can be generated
in vae_encode step.
mask (`Tensor`, *optional*):
The mask for the inpainting generation. Can be generated in vae_encode step.
masked_image_latents (`Tensor`, *optional*):
The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be
generated in vae_encode step.
original_size (`None`, *optional*):
TODO: Add description.
target_size (`None`, *optional*):
TODO: Add description.
negative_original_size (`None`, *optional*):
TODO: Add description.
negative_target_size (`None`, *optional*):
TODO: Add description.
crops_coords_top_left (`None`, *optional*, defaults to (0, 0)):
TODO: Add description.
negative_crops_coords_top_left (`None`, *optional*, defaults to (0, 0)):
TODO: Add description.
aesthetic_score (`None`, *optional*, defaults to 6.0):
TODO: Add description.
negative_aesthetic_score (`None`, *optional*, defaults to 2.0):
TODO: Add description.
control_image (`None`, *optional*):
TODO: Add description.
control_mode (`None`, *optional*):
TODO: Add description.
control_guidance_start (`None`, *optional*, defaults to 0.0):
TODO: Add description.
control_guidance_end (`None`, *optional*, defaults to 1.0):
TODO: Add description.
controlnet_conditioning_scale (`None`, *optional*, defaults to 1.0):
TODO: Add description.
guess_mode (`None`, *optional*, defaults to False):
TODO: Add description.
crops_coords (`tuple | NoneType`, *optional*):
The crop coordinates to use for preprocess/postprocess the image and mask, for inpainting task only. Can
be generated in vae_encode step.
controlnet_cond (`Tensor`, *optional*):
The control image to use for the denoising process. Can be generated in prepare_controlnet_inputs step.
conditioning_scale (`float`, *optional*):
The controlnet conditioning scale value to use for the denoising process. Can be generated in
prepare_controlnet_inputs step.
controlnet_keep (`list`, *optional*):
The controlnet keep values to use for the denoising process. Can be generated in
prepare_controlnet_inputs step.
**denoiser_input_fields (`None`, *optional*):
All conditional model inputs that need to be prepared with guider. It should contain
prompt_embeds/negative_prompt_embeds, add_time_ids/negative_add_time_ids,
pooled_prompt_embeds/negative_pooled_prompt_embeds, and ip_adapter_embeds/negative_ip_adapter_embeds
(optional).please add `kwargs_type=denoiser_input_fields` to their parameter spec (`OutputParam`) when
they are created and added to the pipeline state
eta (`None`, *optional*, defaults to 0.0):
TODO: Add description.
output_type (`None`, *optional*, defaults to pil):
TODO: Add description.
Outputs:
images (`list`):
Generated images.
"""
block_classes = [
StableDiffusionXLTextEncoderStep,
StableDiffusionXLAutoIPAdapterStep,
@@ -447,66 +293,103 @@ class StableDiffusionXLAutoBlocks(SequentialPipelineBlocks):
"decode",
]
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"image": True, "prompt": True},
"inpainting": {"mask_image": True, "image": True, "prompt": True},
"controlnet_text2image": {"control_image": True, "prompt": True},
"controlnet_image2image": {"control_image": True, "image": True, "prompt": True},
"controlnet_inpainting": {"control_image": True, "mask_image": True, "image": True, "prompt": True},
"controlnet_union_text2image": {"control_image": True, "control_mode": True, "prompt": True},
"controlnet_union_image2image": {"control_image": True, "control_mode": True, "image": True, "prompt": True},
"controlnet_union_inpainting": {
"control_image": True,
"control_mode": True,
"mask_image": True,
"image": True,
"prompt": True,
},
"ip_adapter_text2image": {"ip_adapter_image": True, "prompt": True},
"ip_adapter_image2image": {"ip_adapter_image": True, "image": True, "prompt": True},
"ip_adapter_inpainting": {"ip_adapter_image": True, "mask_image": True, "image": True, "prompt": True},
"ip_adapter_controlnet_text2image": {"ip_adapter_image": True, "control_image": True, "prompt": True},
"ip_adapter_controlnet_image2image": {
"ip_adapter_image": True,
"control_image": True,
"image": True,
"prompt": True,
},
"ip_adapter_controlnet_inpainting": {
"ip_adapter_image": True,
"control_image": True,
"mask_image": True,
"image": True,
"prompt": True,
},
"ip_adapter_controlnet_union_text2image": {
"ip_adapter_image": True,
"control_image": True,
"control_mode": True,
"prompt": True,
},
"ip_adapter_controlnet_union_image2image": {
"ip_adapter_image": True,
"control_image": True,
"control_mode": True,
"image": True,
"prompt": True,
},
"ip_adapter_controlnet_union_inpainting": {
"ip_adapter_image": True,
"control_image": True,
"control_mode": True,
"mask_image": True,
"image": True,
"prompt": True,
},
}
@property
def description(self):
return (
"Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL.\n"
+ "- for image-to-image generation, you need to provide either `image` or `image_latents`\n"
+ "- for inpainting, you need to provide `mask_image` and `image`, optionally you can provide `padding_mask_crop` \n"
+ "- to run the controlnet workflow, you need to provide `control_image`\n"
+ "- to run the controlnet_union workflow, you need to provide `control_image` and `control_mode`\n"
+ "- to run the ip_adapter workflow, you need to provide `ip_adapter_image`\n"
+ "- for text-to-image generation, all you need to provide is `prompt`"
)
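The `_workflow_map` and description above dispatch purely on which inputs the caller passes. A hedged sketch of that dispatch follows; it assumes the generic modular-pipeline entry points (`init_pipeline`, `load_components`, `output=`), and the repo id and image URLs are placeholders. The controlnet paths additionally require a controlnet component to be loaded into the pipeline.

import torch
from diffusers.modular_pipelines import StableDiffusionXLAutoBlocks
from diffusers.utils import load_image

pipe = StableDiffusionXLAutoBlocks().init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")
pipe.load_components(torch_dtype=torch.float16)  # older versions: load_default_components
pipe.to("cuda")

# Only `prompt` -> the `text2image` workflow.
image = pipe(prompt="a castle at dawn", output="images")[0]

# `prompt` + `image` + `control_image` -> the `controlnet_image2image` workflow,
# exactly as listed in `_workflow_map` (assumes a controlnet has been loaded).
init_image = load_image("https://example.com/init.png")
canny_image = load_image("https://example.com/canny.png")
image = pipe(
    prompt="a castle at dawn",
    image=init_image,
    control_image=canny_image,
    output="images",
)[0]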
# controlnet (input + denoise step)
class StableDiffusionXLAutoControlnetStep(SequentialPipelineBlocks):
block_classes = [
StableDiffusionXLAutoControlNetInputStep,
StableDiffusionXLAutoControlNetDenoiseStep,
]
block_names = ["controlnet_input", "controlnet_denoise"]
@property
def description(self):
return "Auto Modular pipeline for text-to-image, image-to-image, inpainting, and controlnet tasks using Stable Diffusion XL."
return (
"Controlnet auto step that prepare the controlnet input and denoise the latents. "
+ "It works for both controlnet and controlnet_union and supports text2img, img2img and inpainting tasks."
+ " (it should be replace at 'denoise' step)"
)
@property
def outputs(self):
return [OutputParam.template("images")]
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLSetTimestepsStep),
("prepare_latents", StableDiffusionXLPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("vae_encoder", StableDiffusionXLVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLImg2ImgPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLDenoiseStep),
("decode", StableDiffusionXLDecodeStep),
]
)
INPAINT_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("vae_encoder", StableDiffusionXLInpaintVaeEncoderStep),
("input", StableDiffusionXLInputStep),
("set_timesteps", StableDiffusionXLImg2ImgSetTimestepsStep),
("prepare_latents", StableDiffusionXLInpaintPrepareLatentsStep),
("prepare_add_cond", StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep),
("denoise", StableDiffusionXLInpaintDenoiseStep),
("decode", StableDiffusionXLInpaintDecodeStep),
]
)
CONTROLNET_BLOCKS = InsertableDict(
[
("denoise", StableDiffusionXLAutoControlnetStep),
]
)
IP_ADAPTER_BLOCKS = InsertableDict(
[
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", StableDiffusionXLTextEncoderStep),
("ip_adapter", StableDiffusionXLAutoIPAdapterStep),
("vae_encoder", StableDiffusionXLAutoVaeEncoderStep),
("denoise", StableDiffusionXLCoreDenoiseStep),
("decode", StableDiffusionXLAutoDecodeStep),
]
)
ALL_BLOCKS = {
"text2img": TEXT2IMAGE_BLOCKS,
"img2img": IMAGE2IMAGE_BLOCKS,
"inpaint": INPAINT_BLOCKS,
"controlnet": CONTROLNET_BLOCKS,
"ip_adapter": IP_ADAPTER_BLOCKS,
"auto": AUTO_BLOCKS,
}
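These presets are `InsertableDict`s of block classes, so a working pipeline can be assembled directly from one. A sketch under the assumption that `SequentialPipelineBlocks.from_blocks_dict` and the `diffusers.modular_pipelines.stable_diffusion_xl` import path behave as in the modular-pipelines docs; the repo id is illustrative.

from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS

# Assemble runnable blocks straight from the preset; the same works for
# IMAGE2IMAGE_BLOCKS, INPAINT_BLOCKS, or any entry of ALL_BLOCKS.
t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS)
print(t2i_blocks)  # ordered sub-blocks: text_encoder -> ... -> decode

# Because the preset is an InsertableDict, individual steps can be swapped or
# inserted by position (e.g. grafting an IP-Adapter step after the text
# encoder) before calling from_blocks_dict.
pipe = t2i_blocks.init_pipeline("stabilityai/stable-diffusion-xl-base-1.0")  # illustrative repo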

View File

@@ -14,7 +14,6 @@
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
@@ -38,45 +37,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
# auto_docstring
class WanCoreDenoiseStep(SequentialPipelineBlocks):
"""
denoise block that takes encoded conditions and runs the denoising process.
Components:
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
Inputs:
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
num_frames (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "wan"
block_classes = [
WanTextInputStep,
@@ -88,11 +49,14 @@ class WanCoreDenoiseStep(SequentialPipelineBlocks):
@property
def description(self):
return "denoise block that takes encoded conditions and runs the denoising process."
@property
def outputs(self):
return [OutputParam.template("latents")]
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanDenoiseStep` is used to denoise the latents\n"
)
# ====================
@@ -100,51 +64,7 @@ class WanCoreDenoiseStep(SequentialPipelineBlocks):
# ====================
# auto_docstring
class WanBlocks(SequentialPipelineBlocks):
"""
Modular pipeline blocks for Wan2.1.
Components:
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer
(`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) vae (`AutoencoderKLWan`) video_processor
(`VideoProcessor`)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
negative_prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`None`, *optional*, defaults to 512):
TODO: Add description.
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
num_frames (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
output_type (`str`, *optional*, defaults to np):
The output type of the decoded videos
Outputs:
videos (`list`):
The generated videos.
"""
model_name = "wan"
block_classes = [
WanTextEncoderStep,
@@ -155,8 +75,9 @@ class WanBlocks(SequentialPipelineBlocks):
@property
def description(self):
return "Modular pipeline blocks for Wan2.1."
@property
def outputs(self):
return [OutputParam.template("videos")]
return (
"Modular pipeline blocks for Wan2.1.\n"
+ "- `WanTextEncoderStep` is used to encode the text\n"
+ "- `WanCoreDenoiseStep` is used to denoise the latents\n"
+ "- `WanVaeDecoderStep` is used to decode the latents to images"
)
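A hedged end-to-end sketch for these Wan2.1 blocks, assuming `WanBlocks` is exposed under `diffusers.modular_pipelines` and the generic `init_pipeline` / `load_components` / `output=` entry points apply; the repo id is one of the published Wan2.1 checkpoints and the prompt is arbitrary.

import torch
from diffusers.modular_pipelines import WanBlocks
from diffusers.utils import export_to_video

pipe = WanBlocks().init_pipeline("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

video = pipe(
    prompt="a corgi surfing a small wave at sunset",
    num_frames=81,
    num_inference_steps=50,
    output="videos",   # the `videos` output produced by WanVaeDecoderStep
)[0]
export_to_video(video, "wan_t2v.mp4", fps=16)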

View File

@@ -14,7 +14,6 @@
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from .before_denoise import (
WanPrepareLatentsStep,
WanSetTimestepsStep,
@@ -39,50 +38,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# inputs(text) -> set_timesteps -> prepare_latents -> denoise
# auto_docstring
class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
"""
denoise block that takes encoded conditions and runs the denoising process.
Components:
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`)
Configs:
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
noise stages.
Inputs:
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
num_frames (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "wan"
block_classes = [
WanTextInputStep,
@@ -94,11 +50,14 @@ class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
@property
def description(self):
return "denoise block that takes encoded conditions and runs the denoising process."
@property
def outputs(self):
return [OutputParam.template("latents")]
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22DenoiseStep` is used to denoise the latents in wan2.2\n"
)
# ====================
@@ -106,55 +65,7 @@ class Wan22CoreDenoiseStep(SequentialPipelineBlocks):
# ====================
# auto_docstring
class Wan22Blocks(SequentialPipelineBlocks):
"""
Modular pipeline for text-to-video using Wan2.2.
Components:
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) transformer
(`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`)
transformer_2 (`WanTransformer3DModel`) vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
Configs:
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
noise stages.
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
negative_prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`None`, *optional*, defaults to 512):
TODO: Add description.
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
num_frames (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
output_type (`str`, *optional*, defaults to np):
The output type of the decoded videos
Outputs:
videos (`list`):
The generated videos.
"""
model_name = "wan"
block_classes = [
WanTextEncoderStep,
@@ -169,8 +80,9 @@ class Wan22Blocks(SequentialPipelineBlocks):
@property
def description(self):
return "Modular pipeline for text-to-video using Wan2.2."
@property
def outputs(self):
return [OutputParam.template("videos")]
return (
"Modular pipeline for text-to-video using Wan2.2.\n"
+ " - `WanTextEncoderStep` encodes the text\n"
+ " - `Wan22CoreDenoiseStep` denoes the latents\n"
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
)
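The `boundary_ratio` config referenced in the removed docstring (default 0.875) is what splits the Wan2.2 loop into a high-noise stage handled by `transformer`/`guider` and a low-noise stage handled by `transformer_2`/`guider_2`. The helper below is purely illustrative (not a diffusers API) and assumes flow-matching timesteps on a 0-1000 scale:

# Illustrative only: which denoiser a timestep is routed to under boundary_ratio.
def pick_denoiser(timestep: float, boundary_ratio: float = 0.875,
                  num_train_timesteps: int = 1000) -> str:
    boundary = boundary_ratio * num_train_timesteps  # 875.0 with the defaults
    # High-noise (early) timesteps go to `transformer`, low-noise (late) ones
    # to `transformer_2`.
    return "transformer" if timestep >= boundary else "transformer_2"

for t in (999, 950, 875, 500, 20):
    print(t, "->", pick_denoiser(t))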

View File

@@ -14,7 +14,6 @@
from ...utils import logging
from ..modular_pipeline import SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
@@ -41,36 +40,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# auto_docstring
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
"""
Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent
representation
Components:
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
Inputs:
image (`Image`):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
num_frames (`int`, *optional*, defaults to 81):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
resized_image (`Image`):
TODO: Add description.
first_frame_latents (`Tensor`):
video latent representation with the first frame image condition
image_condition_latents (`Tensor | NoneType`):
TODO: Add description.
"""
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
@@ -86,52 +56,7 @@ class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
# auto_docstring
class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
"""
denoise block that takes encoded text and image latent conditions and runs the denoising process.
Components:
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`)
Configs:
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
noise stages.
Inputs:
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
num_frames (`None`, *optional*):
TODO: Add description.
image_condition_latents (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
model_name = "wan-i2v"
block_classes = [
WanTextInputStep,
@@ -150,11 +75,15 @@ class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
@property
def description(self):
return "denoise block that takes encoded text and image latent conditions and runs the denoising process."
@property
def outputs(self):
return [OutputParam.template("latents")]
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `Wan22Image2VideoDenoiseStep` is used to denoise the latents in wan2.2\n"
)
# ====================
@@ -162,57 +91,7 @@ class Wan22Image2VideoCoreDenoiseStep(SequentialPipelineBlocks):
# ====================
# auto_docstring
class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
"""
Modular pipeline for image-to-video using Wan2.2.
Components:
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`) vae
(`AutoencoderKLWan`) video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler
(`UniPCMultistepScheduler`) guider_2 (`ClassifierFreeGuidance`) transformer_2 (`WanTransformer3DModel`)
Configs:
boundary_ratio (default: 0.875): The boundary ratio to divide the denoising loop into high noise and low
noise stages.
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
negative_prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`None`, *optional*, defaults to 512):
TODO: Add description.
image (`Image`):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
num_frames (`int`, *optional*, defaults to 81):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
output_type (`str`, *optional*, defaults to np):
The output type of the decoded videos
Outputs:
videos (`list`):
The generated videos.
"""
model_name = "wan-i2v"
block_classes = [
WanTextEncoderStep,
@@ -229,8 +108,10 @@ class Wan22Image2VideoBlocks(SequentialPipelineBlocks):
@property
def description(self):
return "Modular pipeline for image-to-video using Wan2.2."
@property
def outputs(self):
return [OutputParam.template("videos")]
return (
"Modular pipeline for image-to-video using Wan2.2.\n"
+ " - `WanTextEncoderStep` encodes the text\n"
+ " - `WanImage2VideoVaeEncoderStep` encodes the image\n"
+ " - `Wan22Image2VideoCoreDenoiseStep` denoes the latents\n"
+ " - `WanVaeDecoderStep` decodes the latents to video frames\n"
)

View File

@@ -14,7 +14,6 @@
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from .before_denoise import (
WanAdditionalInputsStep,
WanPrepareLatentsStep,
@@ -46,29 +45,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# wan2.1 I2V (first frame only)
# auto_docstring
class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
"""
Image2Video Image Encoder step that resize the image and encode the image to generate the image embeddings
Components:
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`)
Inputs:
image (`Image`):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
Outputs:
resized_image (`Image`):
TODO: Add description.
image_embeds (`Tensor`):
The image embeddings
"""
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanImageEncoderStep]
block_names = ["image_resize", "image_encoder"]
@@ -79,34 +56,7 @@ class WanImage2VideoImageEncoderStep(SequentialPipelineBlocks):
# wan2.1 FLF2V (first and last frame)
# auto_docstring
class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
"""
FLF2V Image Encoder step that resize and encode and encode the first and last frame images to generate the image
embeddings
Components:
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`)
Inputs:
image (`Image`):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
last_image (`Image`):
The last frameimage
Outputs:
resized_image (`Image`):
TODO: Add description.
resized_last_image (`Image`):
TODO: Add description.
image_embeds (`Tensor`):
The image embeddings
"""
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanImageCropResizeStep, WanFirstLastFrameImageEncoderStep]
block_names = ["image_resize", "last_image_resize", "image_encoder"]
@@ -117,36 +67,7 @@ class WanFLF2VImageEncoderStep(SequentialPipelineBlocks):
# wan2.1 Auto Image Encoder
# auto_docstring
class WanAutoImageEncoderStep(AutoPipelineBlocks):
"""
Image Encoder step that encode the image to generate the image embeddingsThis is an auto pipeline block that works
for image2video tasks. - `WanFLF2VImageEncoderStep` (flf2v) is used when `last_image` is provided. -
`WanImage2VideoImageEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is
not provided, step will be skipped.
Components:
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`)
Inputs:
image (`Image`, *optional*):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
last_image (`Image`, *optional*):
The last frameimage
Outputs:
resized_image (`Image`):
TODO: Add description.
resized_last_image (`Image`):
TODO: Add description.
image_embeds (`Tensor`):
The image embeddings
"""
block_classes = [WanFLF2VImageEncoderStep, WanImage2VideoImageEncoderStep]
block_names = ["flf2v_image_encoder", "image2video_image_encoder"]
block_trigger_inputs = ["last_image", "image"]
@@ -169,36 +90,7 @@ class WanAutoImageEncoderStep(AutoPipelineBlocks):
# wan2.1 I2V (first frame only)
# auto_docstring
class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
"""
Image2Video Vae Image Encoder step that resize the image and encode the first frame image to its latent
representation
Components:
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
Inputs:
image (`Image`):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
num_frames (`int`, *optional*, defaults to 81):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
resized_image (`Image`):
TODO: Add description.
first_frame_latents (`Tensor`):
video latent representation with the first frame image condition
image_condition_latents (`Tensor | NoneType`):
TODO: Add description.
"""
model_name = "wan-i2v"
block_classes = [WanImageResizeStep, WanVaeEncoderStep, WanPrepareFirstFrameLatentsStep]
block_names = ["image_resize", "vae_encoder", "prepare_first_frame_latents"]
@@ -209,40 +101,7 @@ class WanImage2VideoVaeEncoderStep(SequentialPipelineBlocks):
# wan2.1 FLF2V (first and last frame)
# auto_docstring
class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
"""
FLF2V Vae Image Encoder step that resize and encode and encode the first and last frame images to generate the
latent conditions
Components:
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
Inputs:
image (`Image`):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
last_image (`Image`):
The last frameimage
num_frames (`int`, *optional*, defaults to 81):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
resized_image (`Image`):
TODO: Add description.
resized_last_image (`Image`):
TODO: Add description.
first_last_frame_latents (`Tensor`):
video latent representation with the first and last frame images condition
image_condition_latents (`Tensor | NoneType`):
TODO: Add description.
"""
model_name = "wan-i2v"
block_classes = [
WanImageResizeStep,
@@ -258,44 +117,7 @@ class WanFLF2VVaeEncoderStep(SequentialPipelineBlocks):
# wan2.1 Auto Vae Encoder
# auto_docstring
class WanAutoVaeEncoderStep(AutoPipelineBlocks):
"""
Vae Image Encoder step that encode the image to generate the image latentsThis is an auto pipeline block that works
for image2video tasks. - `WanFLF2VVaeEncoderStep` (flf2v) is used when `last_image` is provided. -
`WanImage2VideoVaeEncoderStep` (image2video) is used when `image` is provided. - if `last_image` or `image` is not
provided, step will be skipped.
Components:
vae (`AutoencoderKLWan`) video_processor (`VideoProcessor`)
Inputs:
image (`Image`, *optional*):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
last_image (`Image`, *optional*):
The last frameimage
num_frames (`int`, *optional*, defaults to 81):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
resized_image (`Image`):
TODO: Add description.
resized_last_image (`Image`):
TODO: Add description.
first_last_frame_latents (`Tensor`):
video latent representation with the first and last frame images condition
image_condition_latents (`Tensor | NoneType`):
TODO: Add description.
first_frame_latents (`Tensor`):
video latent representation with the first frame image condition
"""
model_name = "wan-i2v"
block_classes = [WanFLF2VVaeEncoderStep, WanImage2VideoVaeEncoderStep]
block_names = ["flf2v_vae_encoder", "image2video_vae_encoder"]
@@ -319,53 +141,7 @@ class WanAutoVaeEncoderStep(AutoPipelineBlocks):
# wan2.1 I2V core denoise (support both I2V and FLF2V)
# inputs (text + image_condition_latents) -> set_timesteps -> prepare_latents -> denoise (latents)
# auto_docstring
class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
"""
denoise block that takes encoded text and image latent conditions and runs the denoising process.
Components:
transformer (`WanTransformer3DModel`) scheduler (`UniPCMultistepScheduler`) guider (`ClassifierFreeGuidance`)
Inputs:
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`Tensor`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`Tensor`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
num_frames (`None`, *optional*):
TODO: Add description.
image_condition_latents (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
image_embeds (`Tensor`):
TODO: Add description.
Outputs:
batch_size (`int`):
Number of prompts, the final batch size of model inputs should be batch_size * num_videos_per_prompt
dtype (`dtype`):
Data type of model tensor inputs (determined by `transformer.dtype`)
latents (`Tensor`):
The initial latents to use for the denoising process
"""
model_name = "wan-i2v"
block_classes = [
WanTextInputStep,
@@ -384,7 +160,15 @@ class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
@property
def description(self):
return "denoise block that takes encoded text and image latent conditions and runs the denoising process."
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `WanTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `WanAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
+ " - `WanImage2VideoDenoiseStep` is used to denoise the latents\n"
)
# ====================
@@ -393,64 +177,7 @@ class WanImage2VideoCoreDenoiseStep(SequentialPipelineBlocks):
# wan2.1 Image2Video Auto Blocks
# auto_docstring
class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for image-to-video using Wan.
Supported workflows:
- `image2video`: requires `image`, `prompt`
- `flf2v`: requires `last_image`, `image`, `prompt`
Components:
text_encoder (`UMT5EncoderModel`) tokenizer (`AutoTokenizer`) guider (`ClassifierFreeGuidance`)
image_processor (`CLIPImageProcessor`) image_encoder (`CLIPVisionModel`) vae (`AutoencoderKLWan`)
video_processor (`VideoProcessor`) transformer (`WanTransformer3DModel`) scheduler
(`UniPCMultistepScheduler`)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
negative_prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`None`, *optional*, defaults to 512):
TODO: Add description.
image (`Image`, *optional*):
TODO: Add description.
height (`int`, *optional*, defaults to 480):
TODO: Add description.
width (`int`, *optional*, defaults to 832):
TODO: Add description.
last_image (`Image`, *optional*):
The last frameimage
num_frames (`int`, *optional*, defaults to 81):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_videos_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
image_condition_latents (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 50):
TODO: Add description.
timesteps (`None`, *optional*):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
attention_kwargs (`None`, *optional*):
TODO: Add description.
image_embeds (`Tensor`):
TODO: Add description.
output_type (`str`, *optional*, defaults to np):
The output type of the decoded videos
Outputs:
videos (`list`):
The generated videos.
"""
model_name = "wan-i2v"
block_classes = [
WanTextEncoderStep,
@@ -467,15 +194,10 @@ class WanImage2VideoAutoBlocks(SequentialPipelineBlocks):
"decode",
]
_workflow_map = {
"image2video": {"image": True, "prompt": True},
"flf2v": {"last_image": True, "image": True, "prompt": True},
}
@property
def description(self):
return "Auto Modular pipeline for image-to-video using Wan."
@property
def outputs(self):
return [OutputParam.template("videos")]
return (
"Auto Modular pipeline for image-to-video using Wan.\n"
+ "- for I2V workflow, all you need to provide is `image`"
+ "- for FLF2V workflow, all you need to provide is `last_image` and `image`"
)
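A usage sketch for the two workflows in `_workflow_map` above. It assumes `WanImage2VideoAutoBlocks` is importable from `diffusers.modular_pipelines` and that the generic `init_pipeline` / `load_components` / `output=` entry points apply; the repo id and image URLs are placeholders.

import torch
from diffusers.modular_pipelines import WanImage2VideoAutoBlocks
from diffusers.utils import export_to_video, load_image

pipe = WanImage2VideoAutoBlocks().init_pipeline("Wan-AI/Wan2.1-I2V-14B-480P-Diffusers")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

first = load_image("https://example.com/first_frame.png")   # placeholder URLs
last = load_image("https://example.com/last_frame.png")

# `image` alone -> the `image2video` workflow.
frames = pipe(image=first, prompt="the camera slowly pans right", output="videos")[0]

# `image` + `last_image` -> the `flf2v` workflow (interpolate between frames).
frames = pipe(image=first, last_image=last, prompt="smooth morph between the two frames", output="videos")[0]
export_to_video(frames, "wan_i2v.mp4", fps=16)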

View File

@@ -21,7 +21,12 @@ except OptionalDependencyNotAvailable:
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["modular_blocks_z_image"] = ["ZImageAutoBlocks"]
_import_structure["decoders"] = ["ZImageVaeDecoderStep"]
_import_structure["encoders"] = ["ZImageTextEncoderStep", "ZImageVaeImageEncoderStep"]
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"ZImageAutoBlocks",
]
_import_structure["modular_pipeline"] = ["ZImageModularPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
@@ -31,7 +36,12 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .modular_blocks_z_image import ZImageAutoBlocks
from .decoders import ZImageVaeDecoderStep
from .encoders import ZImageTextEncoderStep
from .modular_blocks import (
ALL_BLOCKS,
ZImageAutoBlocks,
)
from .modular_pipeline import ZImageModularPipeline
else:
import sys

View File

@@ -0,0 +1,191 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import InsertableDict
from .before_denoise import (
ZImageAdditionalInputsStep,
ZImagePrepareLatentsStep,
ZImagePrepareLatentswithImageStep,
ZImageSetTimestepsStep,
ZImageSetTimestepsWithStrengthStep,
ZImageTextInputStep,
)
from .decoders import ZImageVaeDecoderStep
from .denoise import (
ZImageDenoiseStep,
)
from .encoders import (
ZImageTextEncoderStep,
ZImageVaeImageEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# z-image
# text2image
class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
ZImageTextInputStep,
ZImagePrepareLatentsStep,
ZImageSetTimestepsStep,
ZImageDenoiseStep,
]
block_names = ["input", "prepare_latents", "set_timesteps", "denoise"]
@property
def description(self):
return (
"denoise block that takes encoded conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
+ " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
+ " - `ZImageDenoiseStep` is used to denoise the latents\n"
)
# z-image: image2image
## denoise
class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks):
block_classes = [
ZImageTextInputStep,
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
ZImagePrepareLatentsStep,
ZImageSetTimestepsStep,
ZImageSetTimestepsWithStrengthStep,
ZImagePrepareLatentswithImageStep,
ZImageDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"prepare_latents",
"set_timesteps",
"set_timesteps_with_strength",
"prepare_latents_with_image",
"denoise",
]
@property
def description(self):
return (
"denoise block that takes encoded text and image latent conditions and runs the denoising process.\n"
+ "This is a sequential pipeline blocks:\n"
+ " - `ZImageTextInputStep` is used to adjust the batch size of the model inputs\n"
+ " - `ZImageAdditionalInputsStep` is used to adjust the batch size of the latent conditions\n"
+ " - `ZImagePrepareLatentsStep` is used to prepare the latents\n"
+ " - `ZImageSetTimestepsStep` is used to set the timesteps\n"
+ " - `ZImageSetTimestepsWithStrengthStep` is used to set the timesteps with strength\n"
+ " - `ZImagePrepareLatentswithImageStep` is used to prepare the latents with image\n"
+ " - `ZImageDenoiseStep` is used to denoise the latents\n"
)
## auto blocks
class ZImageAutoDenoiseStep(AutoPipelineBlocks):
block_classes = [
ZImageImage2ImageCoreDenoiseStep,
ZImageCoreDenoiseStep,
]
block_names = ["image2image", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image and image2image tasks."
" - `ZImageCoreDenoiseStep` (text2image) for text2image tasks."
" - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks."
+ " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n"
+ " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n"
)
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
block_classes = [ZImageVaeImageEncoderStep]
block_names = ["vae_encoder"]
block_trigger_inputs = ["image"]
@property
def description(self) -> str:
return "Vae Image Encoder step that encode the image to generate the image latents"
+"This is an auto pipeline block that works for image2image tasks."
+" - `ZImageVaeImageEncoderStep` is used when `image` is provided."
+" - if `image` is not provided, step will be skipped."
class ZImageAutoBlocks(SequentialPipelineBlocks):
block_classes = [
ZImageTextEncoderStep,
ZImageAutoVaeImageEncoderStep,
ZImageAutoDenoiseStep,
ZImageVaeDecoderStep,
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
@property
def description(self) -> str:
return "Auto Modular pipeline for text-to-image and image-to-image using ZImage.\n"
+" - for text-to-image generation, all you need to provide is `prompt`\n"
+" - for image-to-image generation, you need to provide `image`\n"
+" - if `image` is not provided, step will be skipped."
# presets
TEXT2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("input", ZImageTextInputStep),
("prepare_latents", ZImagePrepareLatentsStep),
("set_timesteps", ZImageSetTimestepsStep),
("denoise", ZImageDenoiseStep),
("decode", ZImageVaeDecoderStep),
]
)
IMAGE2IMAGE_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("vae_encoder", ZImageVaeImageEncoderStep),
("input", ZImageTextInputStep),
("additional_inputs", ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"])),
("prepare_latents", ZImagePrepareLatentsStep),
("set_timesteps", ZImageSetTimestepsStep),
("set_timesteps_with_strength", ZImageSetTimestepsWithStrengthStep),
("prepare_latents_with_image", ZImagePrepareLatentswithImageStep),
("denoise", ZImageDenoiseStep),
("decode", ZImageVaeDecoderStep),
]
)
AUTO_BLOCKS = InsertableDict(
[
("text_encoder", ZImageTextEncoderStep),
("vae_encoder", ZImageAutoVaeImageEncoderStep),
("denoise", ZImageAutoDenoiseStep),
("decode", ZImageVaeDecoderStep),
]
)
ALL_BLOCKS = {
"text2image": TEXT2IMAGE_BLOCKS,
"image2image": IMAGE2IMAGE_BLOCKS,
"auto": AUTO_BLOCKS,
}
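A sketch of driving the image-to-image preset defined above (the auto blocks take the same path whenever `image` is supplied, via the `image_latents` trigger on `ZImageAutoDenoiseStep`). The `diffusers.modular_pipelines.z_image` import path, the repo id, the image URL, and the generic `init_pipeline` / `load_components` / `output=` entry points are assumptions; the parameter values are illustrative.

import torch
from diffusers.modular_pipelines import SequentialPipelineBlocks
from diffusers.modular_pipelines.z_image import ALL_BLOCKS
from diffusers.utils import load_image

blocks = SequentialPipelineBlocks.from_blocks_dict(ALL_BLOCKS["image2image"])
pipe = blocks.init_pipeline("Tongyi-MAI/Z-Image-Turbo")   # illustrative repo id
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

out = pipe(
    image=load_image("https://example.com/sketch.png"),   # placeholder URL
    prompt="turn the pencil sketch into a watercolor painting",
    strength=0.6,               # consumed by ZImageSetTimestepsWithStrengthStep
    num_inference_steps=9,
    output="images",
)
out[0].save("z_image_img2img.png")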

View File

@@ -1,334 +0,0 @@
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ...utils import logging
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
from ..modular_pipeline_utils import OutputParam
from .before_denoise import (
ZImageAdditionalInputsStep,
ZImagePrepareLatentsStep,
ZImagePrepareLatentswithImageStep,
ZImageSetTimestepsStep,
ZImageSetTimestepsWithStrengthStep,
ZImageTextInputStep,
)
from .decoders import ZImageVaeDecoderStep
from .denoise import (
ZImageDenoiseStep,
)
from .encoders import (
ZImageTextEncoderStep,
ZImageVaeImageEncoderStep,
)
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
# ====================
# 1. DENOISE
# ====================
# text2image: inputs(text) -> set_timesteps -> prepare_latents -> denoise
# auto_docstring
class ZImageCoreDenoiseStep(SequentialPipelineBlocks):
"""
denoise block that takes encoded conditions and runs the denoising process.
Components:
transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`list`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`list`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`int`, *optional*):
TODO: Add description.
width (`int`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 9):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
**denoiser_input_fields (`None`, *optional*):
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
block_classes = [
ZImageTextInputStep,
ZImagePrepareLatentsStep,
ZImageSetTimestepsStep,
ZImageDenoiseStep,
]
block_names = ["input", "prepare_latents", "set_timesteps", "denoise"]
@property
def description(self):
return "denoise block that takes encoded conditions and runs the denoising process."
@property
def outputs(self):
return [OutputParam.template("latents")]
# image2image: inputs(text + image_latents) -> prepare_latents -> set_timesteps -> set_timesteps_with_strength -> prepare_latents_with_image -> denoise
# auto_docstring
class ZImageImage2ImageCoreDenoiseStep(SequentialPipelineBlocks):
"""
denoise block that takes encoded text and image latent conditions and runs the denoising process.
Components:
transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`list`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`list`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`, *optional*, defaults to 9):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.6):
TODO: Add description.
**denoiser_input_fields (`None`, *optional*):
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
block_classes = [
ZImageTextInputStep,
ZImageAdditionalInputsStep(image_latent_inputs=["image_latents"]),
ZImagePrepareLatentsStep,
ZImageSetTimestepsStep,
ZImageSetTimestepsWithStrengthStep,
ZImagePrepareLatentswithImageStep,
ZImageDenoiseStep,
]
block_names = [
"input",
"additional_inputs",
"prepare_latents",
"set_timesteps",
"set_timesteps_with_strength",
"prepare_latents_with_image",
"denoise",
]
@property
def description(self):
return "denoise block that takes encoded text and image latent conditions and runs the denoising process."
@property
def outputs(self):
return [OutputParam.template("latents")]
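# Illustrative sketch only (not part of this PR): the usual way a "set timesteps with strength"
# step truncates the schedule for image2image in diffusers-style pipelines. Whether
# ZImageSetTimestepsWithStrengthStep uses exactly this formula is an assumption.
def _truncate_timesteps_by_strength(timesteps, num_inference_steps, strength):
    # Keep only the final `strength` fraction of the denoising schedule, so strength close to 1.0
    # denoises almost from pure noise while small values stay close to the input image latents.
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start:]

# e.g. _truncate_timesteps_by_strength(list(range(9, 0, -1)), num_inference_steps=9, strength=0.6)
# keeps the last 5 of the 9 steps.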
# auto_docstring
class ZImageAutoDenoiseStep(AutoPipelineBlocks):
"""
Denoise step that iteratively denoises the latents. This is an auto pipeline block that works for text2image and
image2image tasks.
- `ZImageCoreDenoiseStep` (text2image) is used for text2image tasks.
- `ZImageImage2ImageCoreDenoiseStep` (image2image) is used for image2image tasks.
- if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.
- if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.
Components:
transformer (`ZImageTransformer2DModel`) scheduler (`FlowMatchEulerDiscreteScheduler`) guider
(`ClassifierFreeGuidance`)
Inputs:
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
prompt_embeds (`list`):
Pre-generated text embeddings. Can be generated from text_encoder step.
negative_prompt_embeds (`list`, *optional*):
Pre-generated negative text embeddings. Can be generated from text_encoder step.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_inference_steps (`None`):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.6):
TODO: Add description.
**denoiser_input_fields (`None`, *optional*):
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
Outputs:
latents (`Tensor`):
Denoised latents.
"""
block_classes = [
ZImageImage2ImageCoreDenoiseStep,
ZImageCoreDenoiseStep,
]
block_names = ["image2image", "text2image"]
block_trigger_inputs = ["image_latents", None]
@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. "
"This is a auto pipeline block that works for text2image and image2image tasks."
" - `ZImageCoreDenoiseStep` (text2image) for text2image tasks."
" - `ZImageImage2ImageCoreDenoiseStep` (image2image) for image2image tasks."
+ " - if `image_latents` is provided, `ZImageImage2ImageCoreDenoiseStep` will be used.\n"
+ " - if `image_latents` is not provided, `ZImageCoreDenoiseStep` will be used.\n"
)
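# Illustrative sketch only (not part of this PR): how the selection declared by `block_names` and
# `block_trigger_inputs` above can be reasoned about. Plain-Python illustration, not an actual
# AutoPipelineBlocks API.
def _select_block(call_kwargs, block_names, block_trigger_inputs):
    # The first block whose trigger input is provided wins; a `None` trigger marks the default branch.
    for name, trigger in zip(block_names, block_trigger_inputs):
        if trigger is None or call_kwargs.get(trigger) is not None:
            return name
    return None

# _select_block({"image_latents": ...}, ["image2image", "text2image"], ["image_latents", None]) -> "image2image"
# _select_block({"prompt_embeds": ...}, ["image2image", "text2image"], ["image_latents", None]) -> "text2image"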
# auto_docstring
class ZImageAutoVaeImageEncoderStep(AutoPipelineBlocks):
"""
Vae Image Encoder step that encodes the image to generate the image latents
Components:
vae (`AutoencoderKL`) image_processor (`VaeImageProcessor`)
Inputs:
image (`Image`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
Outputs:
image_latents (`Tensor`):
Latent representation of the encoded input image.
"""
block_classes = [ZImageVaeImageEncoderStep]
block_names = ["vae_encoder"]
block_trigger_inputs = ["image"]
@property
def description(self) -> str:
return "Vae Image Encoder step that encode the image to generate the image latents"
+"This is an auto pipeline block that works for image2image tasks."
+" - `ZImageVaeImageEncoderStep` is used when `image` is provided."
+" - if `image` is not provided, step will be skipped."
# auto_docstring
class ZImageAutoBlocks(SequentialPipelineBlocks):
"""
Auto Modular pipeline for text-to-image and image-to-image using ZImage.
Supported workflows:
- `text2image`: requires `prompt`
- `image2image`: requires `image`, `prompt`
Components:
text_encoder (`Qwen3Model`) tokenizer (`Qwen2Tokenizer`) guider (`ClassifierFreeGuidance`) vae
(`AutoencoderKL`) image_processor (`VaeImageProcessor`) transformer (`ZImageTransformer2DModel`) scheduler
(`FlowMatchEulerDiscreteScheduler`)
Inputs:
prompt (`None`, *optional*):
TODO: Add description.
negative_prompt (`None`, *optional*):
TODO: Add description.
max_sequence_length (`None`, *optional*, defaults to 512):
TODO: Add description.
image (`Image`, *optional*):
TODO: Add description.
height (`None`, *optional*):
TODO: Add description.
width (`None`, *optional*):
TODO: Add description.
generator (`None`, *optional*):
TODO: Add description.
num_images_per_prompt (`None`, *optional*, defaults to 1):
TODO: Add description.
image_latents (`None`, *optional*):
TODO: Add description.
latents (`Tensor | NoneType`):
TODO: Add description.
num_inference_steps (`None`):
TODO: Add description.
sigmas (`None`, *optional*):
TODO: Add description.
strength (`None`, *optional*, defaults to 0.6):
TODO: Add description.
**denoiser_input_fields (`None`, *optional*):
The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.
output_type (`str`, *optional*, defaults to pil):
The type of the output images, can be 'pil', 'np', 'pt'
Outputs:
images (`list`):
Generated images.
"""
block_classes = [
ZImageTextEncoderStep,
ZImageAutoVaeImageEncoderStep,
ZImageAutoDenoiseStep,
ZImageVaeDecoderStep,
]
block_names = ["text_encoder", "vae_encoder", "denoise", "decode"]
_workflow_map = {
"text2image": {"prompt": True},
"image2image": {"image": True, "prompt": True},
}
@property
def description(self) -> str:
return "Auto Modular pipeline for text-to-image and image-to-image using ZImage."
@property
def outputs(self):
return [OutputParam.template("images")]
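# Illustrative usage sketch only (not part of this PR). The exact modular-pipelines entry points
# (`init_pipeline`, `load_default_components`) and the repo id are assumptions and may differ:
#
#   blocks = ZImageAutoBlocks()
#   pipe = blocks.init_pipeline("<z-image-repo-id>")            # hypothetical repo id
#   pipe.load_default_components(torch_dtype=torch.bfloat16)    # assumed loader name
#   # text2image: only `prompt` is provided
#   image = pipe(prompt="a red fox in the snow", output_type="pil").images[0]
#   # image2image: providing `image` triggers the image2image workflow declared in `_workflow_map`
#   edited = pipe(prompt="make it snowy", image=image, strength=0.6, output_type="pil").images[0]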

View File

@@ -6,6 +6,7 @@ from ..utils import (
_LazyModule,
get_objects_from_module,
is_flax_available,
is_k_diffusion_available,
is_librosa_available,
is_note_seq_available,
is_onnx_available,
@@ -237,7 +238,6 @@ else:
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
_import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"]
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = [
@@ -466,6 +466,21 @@ else:
]
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils import (
dummy_torch_and_transformers_and_k_diffusion_objects,
)
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_k_diffusion_objects))
else:
_import_structure["stable_diffusion_k_diffusion"] = [
"StableDiffusionKDiffusionPipeline",
"StableDiffusionXLKDiffusionPipeline",
]
try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()
@@ -668,7 +683,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (
@@ -887,6 +901,17 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionOnnxPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ..utils.dummy_torch_and_transformers_and_k_diffusion_objects import *
else:
from .stable_diffusion_k_diffusion import (
StableDiffusionKDiffusionPipeline,
StableDiffusionXLKDiffusionPipeline,
)
try:
if not (is_torch_available() and is_transformers_available() and is_sentencepiece_available()):
raise OptionalDependencyNotAvailable()

View File

@@ -502,10 +502,6 @@ class AudioLDM2Pipeline(DiffusionPipeline):
text_input_ids,
attention_mask=attention_mask,
)
# Extract the pooler output if it's a BaseModelOutputWithPooling (Transformers v5+)
# otherwise use it directly (Transformers v4)
if hasattr(prompt_embeds, "pooler_output"):
prompt_embeds = prompt_embeds.pooler_output
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
prompt_embeds = prompt_embeds[:, None, :]
# make sure that we attend to this single hidden-state
@@ -614,10 +610,6 @@ class AudioLDM2Pipeline(DiffusionPipeline):
uncond_input_ids,
attention_mask=negative_attention_mask,
)
# Extract the pooler output if it's a BaseModelOutputWithPooling (Transformers v5+)
# otherwise use it directly (Transformers v4)
if hasattr(negative_prompt_embeds, "pooler_output"):
negative_prompt_embeds = negative_prompt_embeds.pooler_output
# append the seq-len dim: (bs, hidden_size) -> (bs, seq_len, hidden_size)
negative_prompt_embeds = negative_prompt_embeds[:, None, :]
# make sure that we attend to this single hidden-state

View File

@@ -54,7 +54,6 @@ from .flux import (
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -175,8 +174,6 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("glm_image", GlmImagePipeline),
("helios", HeliosPipeline),
("helios-pyramid", HeliosPyramidPipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),

View File

@@ -287,9 +287,6 @@ class Cosmos2_5_PredictBasePipeline(DiffusionPipeline):
truncation=True,
padding="max_length",
)
input_ids = (
input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids
)
input_ids = torch.LongTensor(input_ids)
input_ids_batch.append(input_ids)

View File

@@ -17,6 +17,9 @@ from typing import Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -51,13 +54,11 @@ else:
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def _maybe_pad_or_trim_video(video: torch.Tensor, num_frames: int):
def _maybe_pad_video(video: torch.Tensor, num_frames: int):
n_pad_frames = num_frames - video.shape[2]
if n_pad_frames > 0:
last_frame = video[:, :, -1:, :, :]
video = torch.cat((video, last_frame.repeat(1, 1, n_pad_frames, 1, 1)), dim=2)
elif num_frames < video.shape[2]:
video = video[:, :, :num_frames, :, :]
return video
@@ -133,8 +134,8 @@ EXAMPLE_DOC_STRING = """
>>> controls = [Image.fromarray(x.numpy()) for x in controls.permute(1, 2, 3, 0)]
>>> export_to_video(controls, "edge_controlled_video_edge.mp4", fps=30)
>>> # Transfer inference with controls.
>>> video = pipe(
... video=input_video[:num_frames],
... controls=controls,
... controls_conditioning_scale=1.0,
... prompt=prompt,
@@ -148,7 +149,7 @@ EXAMPLE_DOC_STRING = """
class Cosmos2_5_TransferPipeline(DiffusionPipeline):
r"""
Pipeline for Cosmos Transfer2.5, supporting auto-regressive inference.
Pipeline for Cosmos Transfer2.5 base model.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
@@ -165,14 +166,12 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
controlnet ([`CosmosControlNetModel`]):
ControlNet used to condition generation on control inputs.
"""
model_cpu_offload_seq = "text_encoder->transformer->controlnet->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
# We mark safety_checker as optional here to get around some test failures, but it is not really optional
_optional_components = ["safety_checker"]
_optional_components = ["safety_checker", "controlnet"]
_exclude_from_cpu_offload = ["safety_checker"]
def __init__(
@@ -182,8 +181,8 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
transformer: CosmosTransformer3DModel,
vae: AutoencoderKLWan,
scheduler: UniPCMultistepScheduler,
controlnet: CosmosControlNetModel,
safety_checker: Optional[CosmosSafetyChecker] = None,
controlnet: Optional[CosmosControlNetModel],
safety_checker: CosmosSafetyChecker = None,
):
super().__init__()
@@ -263,9 +262,6 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
truncation=True,
padding="max_length",
)
input_ids = (
input_ids["input_ids"] if not isinstance(input_ids, list) and "input_ids" in input_ids else input_ids
)
input_ids = torch.LongTensor(input_ids)
input_ids_batch.append(input_ids)
@@ -385,11 +381,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
num_frames_in: int = 93,
num_frames_out: int = 93,
do_classifier_free_guidance: bool = True,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
num_cond_latent_frames: int = 0,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
) -> torch.Tensor:
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
@@ -404,14 +399,10 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
W = width // self.vae_scale_factor_spatial
shape = (B, C, T, H, W)
if latents is not None:
if latents.shape[1:] != shape[1:]:
raise ValueError(f"Unexpected `latents` shape, got {latents.shape}, expected {shape}.")
latents = latents.to(device=device, dtype=dtype)
else:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
if num_frames_in == 0:
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
cond_mask = torch.zeros((B, 1, T, H, W), dtype=latents.dtype, device=latents.device)
cond_indicator = torch.zeros((B, 1, T, 1, 1), dtype=latents.dtype, device=latents.device)
@@ -441,12 +432,16 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
latents_std = self.latents_std.to(device=device, dtype=dtype)
cond_latents = (cond_latents - latents_mean) / latents_std
if latents is None:
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device=device, dtype=dtype)
padding_shape = (B, 1, T, H, W)
ones_padding = latents.new_ones(padding_shape)
zeros_padding = latents.new_zeros(padding_shape)
cond_indicator = latents.new_zeros(B, 1, latents.size(2), 1, 1)
cond_indicator[:, :, 0:num_cond_latent_frames, :, :] = 1.0
cond_indicator = latents.new_zeros(1, 1, latents.size(2), 1, 1)
cond_mask = cond_indicator * ones_padding + (1 - cond_indicator) * zeros_padding
return (
@@ -456,7 +451,34 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
cond_indicator,
)
# Modified from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def _encode_controls(
self,
controls: Optional[torch.Tensor],
height: int,
width: int,
num_frames: int,
dtype: torch.dtype,
device: torch.device,
generator: torch.Generator | list[torch.Generator] | None,
) -> Optional[torch.Tensor]:
if controls is None:
return None
control_video = self.video_processor.preprocess_video(controls, height, width)
control_video = _maybe_pad_video(control_video, num_frames)
control_video = control_video.to(device=device, dtype=self.vae.dtype)
control_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator) for vid in control_video
]
control_latents = torch.cat(control_latents, dim=0).to(dtype)
latents_mean = self.latents_mean.to(device=device, dtype=dtype)
latents_std = self.latents_std.to(device=device, dtype=dtype)
control_latents = (control_latents - latents_mean) / latents_std
return control_latents
# Copied from diffusers.pipelines.cosmos.pipeline_cosmos_text2world.CosmosTextToWorldPipeline.check_inputs
def check_inputs(
self,
prompt,
@@ -464,25 +486,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width,
prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
num_ar_conditional_frames=None,
num_ar_latent_conditional_frames=None,
num_frames_per_chunk=None,
num_frames=None,
conditional_frame_timestep=0.1,
):
if width <= 0 or height <= 0 or height % 16 != 0 or width % 16 != 0:
raise ValueError(
f"`height` and `width` have to be divisible by 16 (& positive) but are {height} and {width}."
)
if num_frames is not None and num_frames <= 0:
raise ValueError(f"`num_frames` has to be a positive integer when provided but is {num_frames}.")
if conditional_frame_timestep < 0 or conditional_frame_timestep > 1:
raise ValueError(
"`conditional_frame_timestep` has to be a float in the [0, 1] interval but is "
f"{conditional_frame_timestep}."
)
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -503,46 +509,6 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if num_ar_latent_conditional_frames is not None and num_ar_conditional_frames is not None:
raise ValueError(
"Provide only one of `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`, not both."
)
if num_ar_latent_conditional_frames is None and num_ar_conditional_frames is None:
raise ValueError("Provide either `num_ar_conditional_frames` or `num_ar_latent_conditional_frames`.")
if num_ar_latent_conditional_frames is not None and num_ar_latent_conditional_frames < 0:
raise ValueError("`num_ar_latent_conditional_frames` must be >= 0.")
if num_ar_conditional_frames is not None and num_ar_conditional_frames < 0:
raise ValueError("`num_ar_conditional_frames` must be >= 0.")
if num_ar_latent_conditional_frames is not None:
num_ar_conditional_frames = max(
0, (num_ar_latent_conditional_frames - 1) * self.vae_scale_factor_temporal + 1
)
min_chunk_len = self.vae_scale_factor_temporal + 1
if num_frames_per_chunk < min_chunk_len:
logger.warning(f"{num_frames_per_chunk=} must be larger than {min_chunk_len=}, setting to min_chunk_len")
num_frames_per_chunk = min_chunk_len
max_frames_by_rope = None
if getattr(self.transformer.config, "max_size", None) is not None:
max_frames_by_rope = max(
size // patch
for size, patch in zip(self.transformer.config.max_size, self.transformer.config.patch_size)
)
if num_frames_per_chunk > max_frames_by_rope:
raise ValueError(
f"{num_frames_per_chunk=} is too large for RoPE setting ({max_frames_by_rope=}). "
"Please reduce `num_frames_per_chunk`."
)
if num_ar_conditional_frames >= num_frames_per_chunk:
raise ValueError(
f"{num_ar_conditional_frames=} must be smaller than {num_frames_per_chunk=} for chunked generation."
)
return num_frames_per_chunk
@property
def guidance_scale(self):
return self._guidance_scale
@@ -567,22 +533,23 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
controls: PipelineImageInput | List[PipelineImageInput],
controls_conditioning_scale: Union[float, List[float]] = 1.0,
image: PipelineImageInput | None = None,
video: List[PipelineImageInput] | None = None,
prompt: Union[str, List[str]] | None = None,
negative_prompt: Union[str, List[str]] = DEFAULT_NEGATIVE_PROMPT,
height: int = 704,
width: Optional[int] = None,
num_frames: Optional[int] = None,
num_frames_per_chunk: int = 93,
width: int | None = None,
num_frames: int = 93,
num_inference_steps: int = 36,
guidance_scale: float = 3.0,
num_videos_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
num_videos_per_prompt: Optional[int] = 1,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
controls: Optional[PipelineImageInput | List[PipelineImageInput]] = None,
controls_conditioning_scale: float | list[float] = 1.0,
prompt_embeds: torch.Tensor | None = None,
negative_prompt_embeds: torch.Tensor | None = None,
output_type: str = "pil",
return_dict: bool = True,
callback_on_step_end: Optional[
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -590,26 +557,24 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 512,
conditional_frame_timestep: float = 0.1,
num_ar_conditional_frames: Optional[int] = 1,
num_ar_latent_conditional_frames: Optional[int] = None,
):
r"""
`controls` drive the conditioning through ControlNet. Controls are assumed to be pre-processed, e.g. edge maps
are pre-computed.
The call function to the pipeline for generation. Supports three modes:
Setting `num_frames` will restrict the total number of frames output, if not provided or assigned to None
(default) then the number of output frames will match the input `controls`.
- **Text2World**: `image=None`, `video=None`, `prompt` provided. Generates a world clip.
- **Image2World**: `image` provided, `video=None`, `prompt` provided. Conditions on a single frame.
- **Video2World**: `video` provided, `image=None`, `prompt` provided. Conditions on an input clip.
Auto-regressive inference is supported and thus a sliding window of `num_frames_per_chunk` frames are used per
denoising loop. In addition, when auto-regressive inference is performed, the previous
`num_ar_latent_conditional_frames` or `num_ar_conditional_frames` are used to condition the following denoising
inference loops.
Set `num_frames=93` (default) to produce a world video, or `num_frames=1` to produce a single image frame (turning
the modes above into their *2Image counterparts).
Outputs follow `output_type` (e.g., `"pil"` returns a list of `num_frames` PIL images per prompt).
Args:
controls (`PipelineImageInput`, `List[PipelineImageInput]`):
Control image or video input used by the ControlNet.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional single image for Image2World conditioning. Must be `None` when `video` is provided.
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
Optional input video for Video2World conditioning. Must be `None` when `image` is provided.
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide generation. Required unless `prompt_embeds` is supplied.
height (`int`, defaults to `704`):
@@ -617,10 +582,9 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
width (`int`, *optional*):
The width in pixels of the generated image. If not provided, this will be determined based on the
aspect ratio of the input and the provided height.
num_frames (`int`, *optional*):
Number of output frames. Defaults to `None` to output the same number of frames as the input
`controls`.
num_inference_steps (`int`, defaults to `36`):
num_frames (`int`, defaults to `93`):
Number of output frames. Use `93` for world (video) generation; set to `1` to return a single frame.
num_inference_steps (`int`, defaults to `36`):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, defaults to `3.0`):
@@ -634,9 +598,13 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs. Can be used to
tweak the same generation with different prompts. If not provided, a latents tensor is generated by
sampling using the supplied random `generator`.
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor is generated by sampling using the supplied random `generator`.
controls (`PipelineImageInput`, `List[PipelineImageInput]`, *optional*):
Control image or video input used by the ControlNet. If `None`, ControlNet is skipped.
controls_conditioning_scale (`float` or `List[float]`, *optional*, defaults to `1.0`):
The scale factor(s) for the ControlNet outputs. A single float is broadcast to all control blocks.
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
@@ -659,18 +627,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
max_sequence_length (`int`, defaults to `512`):
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
the prompt is shorter than this length, it will be padded.
num_ar_conditional_frames (`int`, *optional*, defaults to `1`):
Number of frames to condition on subsequent inference loops in auto-regressive inference, i.e. for the
second chunk and onwards. Only used if `num_ar_latent_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
num_ar_latent_conditional_frames (`int`, *optional*):
Number of latent frames to condition on subsequent inference loops in auto-regressive inference, i.e.
for the second chunk and onwards. Only used if `num_ar_conditional_frames` is `None`.
This is only used when auto-regressive inference is performed, i.e. when the number of frames in
controls is > num_frames_per_chunk
Examples:
Returns:
@@ -690,40 +647,21 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if width is None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, list):
frame = frame[0]
if isinstance(frame, (torch.Tensor, np.ndarray)):
if frame.ndim == 5:
frame = frame[0, 0]
elif frame.ndim == 4:
frame = frame[0]
frame = image or video[0] if image or video else None
if frame is None and controls is not None:
frame = controls[0] if isinstance(controls, list) else controls
if isinstance(frame, (torch.Tensor, np.ndarray)) and len(frame.shape) == 4:
frame = controls[0]
if isinstance(frame, PIL.Image.Image):
if frame is None:
width = int((height + 16) * (1280 / 720))
elif isinstance(frame, PIL.Image.Image):
width = int((height + 16) * (frame.width / frame.height))
else:
if frame.ndim != 3:
raise ValueError("`controls` must contain 3D frames in CHW format.")
width = int((height + 16) * (frame.shape[2] / frame.shape[1])) # NOTE: assuming C H W
num_frames_per_chunk = self.check_inputs(
prompt,
height,
width,
prompt_embeds,
callback_on_step_end_tensor_inputs,
num_ar_conditional_frames,
num_ar_latent_conditional_frames,
num_frames_per_chunk,
num_frames,
conditional_frame_timestep,
)
if num_ar_latent_conditional_frames is not None:
num_cond_latent_frames = num_ar_latent_conditional_frames
num_ar_conditional_frames = max(0, (num_cond_latent_frames - 1) * self.vae_scale_factor_temporal + 1)
else:
num_cond_latent_frames = max(0, (num_ar_conditional_frames - 1) // self.vae_scale_factor_temporal + 1)
# Check inputs. Raise error if not correct
self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)
self._guidance_scale = guidance_scale
self._current_timestep = None
@@ -768,137 +706,102 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
vae_dtype = self.vae.dtype
transformer_dtype = self.transformer.dtype
if getattr(self.transformer.config, "img_context_dim_in", None):
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
img_context = torch.zeros(
batch_size,
self.transformer.config.img_context_num_tokens,
self.transformer.config.img_context_dim_in,
device=prompt_embeds.device,
dtype=transformer_dtype,
)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
num_frames_in = None
if image is not None:
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for image input (given {batch_size})")
image = torchvision.transforms.functional.to_tensor(image).unsqueeze(0)
video = torch.cat([image, torch.zeros_like(image).repeat(num_frames - 1, 1, 1, 1)], dim=0)
video = video.unsqueeze(0)
num_frames_in = 1
elif video is None:
video = torch.zeros(batch_size, num_frames, 3, height, width, dtype=torch.uint8)
num_frames_in = 0
else:
num_frames_in = len(video)
if batch_size != 1:
raise ValueError(f"batch_size must be 1 for video input (given {batch_size})")
assert video is not None
video = self.video_processor.preprocess_video(video, height, width)
# pad with last frame (for video2world)
num_frames_out = num_frames
video = _maybe_pad_video(video, num_frames_out)
assert num_frames_in <= num_frames_out, f"expected ({num_frames_in=}) <= ({num_frames_out=})"
video = video.to(device=device, dtype=vae_dtype)
num_channels_latents = self.transformer.config.in_channels - 1
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=num_channels_latents,
height=height,
width=width,
num_frames_in=num_frames_in,
num_frames_out=num_frames,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
latents=latents,
)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
cond_mask = cond_mask.to(transformer_dtype)
controls_latents = None
if controls is not None:
controls_latents = self._encode_controls(
controls,
height=height,
width=width,
num_frames=num_frames,
dtype=transformer_dtype,
device=device,
generator=generator,
)
if num_videos_per_prompt > 1:
img_context = img_context.repeat_interleave(num_videos_per_prompt, dim=0)
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
encoder_hidden_states = (prompt_embeds, img_context)
neg_encoder_hidden_states = (negative_prompt_embeds, img_context)
else:
encoder_hidden_states = prompt_embeds
neg_encoder_hidden_states = negative_prompt_embeds
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
control_video = self.video_processor.preprocess_video(controls, height, width)
if control_video.shape[0] != batch_size:
if control_video.shape[0] == 1:
control_video = control_video.repeat(batch_size, 1, 1, 1, 1)
else:
raise ValueError(
f"Expected controls batch size {batch_size} to match prompt batch size, but got {control_video.shape[0]}."
gt_velocity = (latents - cond_latent) * cond_mask
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
num_frames_out = control_video.shape[2]
if num_frames is not None:
num_frames_out = min(num_frames_out, num_frames)
control_video = _maybe_pad_or_trim_video(control_video, num_frames_out)
# chunk information
num_latent_frames_per_chunk = (num_frames_per_chunk - 1) // self.vae_scale_factor_temporal + 1
chunk_stride = num_frames_per_chunk - num_ar_conditional_frames
chunk_idxs = [
(start_idx, min(start_idx + num_frames_per_chunk, num_frames_out))
for start_idx in range(0, num_frames_out - num_ar_conditional_frames, chunk_stride)
]
video_chunks = []
latents_mean = self.latents_mean.to(dtype=vae_dtype, device=device)
latents_std = self.latents_std.to(dtype=vae_dtype, device=device)
def decode_latents(latents):
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(dtype=self.vae.dtype, device=device), return_dict=False)[0]
return video
latents_arg = latents
initial_num_cond_latent_frames = 0
latent_chunks = []
num_chunks = len(chunk_idxs)
total_steps = num_inference_steps * num_chunks
with self.progress_bar(total=total_steps) as progress_bar:
for chunk_idx, (start_idx, end_idx) in enumerate(chunk_idxs):
if chunk_idx == 0:
prev_output = torch.zeros((batch_size, num_frames_per_chunk, 3, height, width), dtype=vae_dtype)
prev_output = self.video_processor.preprocess_video(prev_output, height, width)
else:
prev_output = video_chunks[-1].clone()
if num_ar_conditional_frames > 0:
prev_output[:, :, :num_ar_conditional_frames] = prev_output[:, :, -num_ar_conditional_frames:]
prev_output[:, :, num_ar_conditional_frames:] = -1 # -1 == 0 in processed video space
else:
prev_output.fill_(-1)
chunk_video = prev_output.to(device=device, dtype=vae_dtype)
chunk_video = _maybe_pad_or_trim_video(chunk_video, num_frames_per_chunk)
latents, cond_latent, cond_mask, cond_indicator = self.prepare_latents(
video=chunk_video,
batch_size=batch_size * num_videos_per_prompt,
num_channels_latents=self.transformer.config.in_channels - 1,
height=height,
width=width,
num_frames_in=chunk_video.shape[2],
num_frames_out=num_frames_per_chunk,
do_classifier_free_guidance=self.do_classifier_free_guidance,
dtype=torch.float32,
device=device,
generator=generator,
num_cond_latent_frames=initial_num_cond_latent_frames
if chunk_idx == 0
else num_cond_latent_frames,
latents=latents_arg,
)
cond_mask = cond_mask.to(transformer_dtype)
cond_timestep = torch.ones_like(cond_indicator) * conditional_frame_timestep
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
chunk_control_video = control_video[:, :, start_idx:end_idx, ...].to(
device=device, dtype=self.vae.dtype
)
chunk_control_video = _maybe_pad_or_trim_video(chunk_control_video, num_frames_per_chunk)
if isinstance(generator, list):
controls_latents = [
retrieve_latents(self.vae.encode(chunk_control_video[i].unsqueeze(0)), generator=generator[i])
for i in range(chunk_control_video.shape[0])
]
else:
controls_latents = [
retrieve_latents(self.vae.encode(vid.unsqueeze(0)), generator=generator)
for vid in chunk_control_video
]
controls_latents = torch.cat(controls_latents, dim=0).to(transformer_dtype)
controls_latents = (controls_latents - latents_mean) / latents_std
# Denoising loop
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self._num_timesteps = len(timesteps)
gt_velocity = (latents - cond_latent) * cond_mask
for i, t in enumerate(timesteps):
if self.interrupt:
continue
self._current_timestep = t.cpu().item()
# NOTE: assumes sigma(t) \in [0, 1]
sigma_t = (
torch.tensor(self.scheduler.sigmas[i].item())
.unsqueeze(0)
.to(device=device, dtype=transformer_dtype)
)
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
in_latents = cond_mask * cond_latent + (1 - cond_mask) * latents
in_latents = in_latents.to(transformer_dtype)
in_timestep = cond_indicator * cond_timestep + (1 - cond_indicator) * sigma_t
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -911,18 +814,20 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
noise_pred = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=encoder_hidden_states,
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
noise_pred = gt_velocity + noise_pred * (1 - cond_mask)
if self.do_classifier_free_guidance:
if self.do_classifier_free_guidance:
control_blocks = None
if controls_latents is not None and self.controlnet is not None:
control_output = self.controlnet(
controls_latents=controls_latents,
latents=in_latents,
@@ -935,50 +840,46 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
)
control_blocks = control_output[0]
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
noise_pred_neg = self.transformer(
hidden_states=in_latents,
timestep=in_timestep,
encoder_hidden_states=neg_encoder_hidden_states, # NOTE: negative prompt
block_controlnet_hidden_states=control_blocks,
condition_mask=cond_mask,
padding_mask=padding_mask,
return_dict=False,
)[0]
# NOTE: replace velocity (noise_pred_neg) with gt_velocity for conditioning inputs only
noise_pred_neg = gt_velocity + noise_pred_neg * (1 - cond_mask)
noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_neg)
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
# call the callback, if provided
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if i == total_steps - 1 or ((i + 1) % self.scheduler.order == 0):
progress_bar.update()
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if XLA_AVAILABLE:
xm.mark_step()
video_chunks.append(decode_latents(latents).detach().cpu())
latent_chunks.append(latents.detach().cpu())
if XLA_AVAILABLE:
xm.mark_step()
self._current_timestep = None
if not output_type == "latent":
video_chunks = [
chunk[:, :, num_ar_conditional_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(video_chunks)
]
video = torch.cat(video_chunks, dim=2)
video = video[:, :, :num_frames_out, ...]
latents_mean = self.latents_mean.to(latents.device, latents.dtype)
latents_std = self.latents_std.to(latents.device, latents.dtype)
latents = latents * latents_std + latents_mean
video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
video = self._match_num_frames(video, num_frames)
assert self.safety_checker is not None
self.safety_checker.to(device)
@@ -995,13 +896,7 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
video = self.video_processor.postprocess_video(video, output_type=output_type)
else:
latent_T = (num_frames_out - 1) // self.vae_scale_factor_temporal + 1
latent_chunks = [
chunk[:, :, num_cond_latent_frames:, ...] if chunk_idx != 0 else chunk
for chunk_idx, chunk in enumerate(latent_chunks)
]
video = torch.cat(latent_chunks, dim=2)
video = video[:, :, :latent_T, ...]
video = latents
# Offload all models
self.maybe_free_model_hooks()
@@ -1010,3 +905,19 @@ class Cosmos2_5_TransferPipeline(DiffusionPipeline):
return (video,)
return CosmosPipelineOutput(frames=video)
def _match_num_frames(self, video: torch.Tensor, target_num_frames: int) -> torch.Tensor:
if target_num_frames <= 0 or video.shape[2] == target_num_frames:
return video
frames_per_latent = max(self.vae_scale_factor_temporal, 1)
video = torch.repeat_interleave(video, repeats=frames_per_latent, dim=2)
current_frames = video.shape[2]
if current_frames < target_num_frames:
pad = video[:, :, -1:, :, :].repeat(1, 1, target_num_frames - current_frames, 1, 1)
video = torch.cat([video, pad], dim=2)
elif current_frames > target_num_frames:
video = video[:, :, :target_num_frames]
return video
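# Illustrative call sketch only (not part of this PR), matching the updated signature above.
# The repo id is a placeholder; loading follows the standard DiffusionPipeline pattern, and
# `export_to_video` is the helper already used in the example docstring.
#
#   pipe = Cosmos2_5_TransferPipeline.from_pretrained("<cosmos-transfer-2.5-repo>", torch_dtype=torch.bfloat16)
#   pipe.to("cuda")
#   out = pipe(
#       prompt="a robot arm assembling a wooden toy",
#       controls=control_frames,              # optional; ControlNet is skipped when None
#       controls_conditioning_scale=1.0,
#       num_frames=93,                        # 93 for a video clip, 1 for a single frame
#       num_inference_steps=36,
#   )
#   export_to_video(out.frames[0], "transfer_output.mp4", fps=30)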

View File

@@ -1,48 +0,0 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_helios"] = ["HeliosPipeline"]
_import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .pipeline_helios import HeliosPipeline
from .pipeline_helios_pyramid import HeliosPyramidPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

Some files were not shown because too many files have changed in this diff.