Compare commits

..

82 Commits

Author SHA1 Message Date
sayakpaul
403d3f20f7 small nits. 2026-03-05 14:05:53 +05:30
Sayak Paul
441224ac00 Merge branch 'main' into rae 2026-03-05 12:27:28 +05:30
Christopher
20364fe5a2 adding lora support to z-image controlnet pipelines (#13200)
adding lora to z-image controlnet pipelines
2026-03-05 10:05:53 +05:30
Sayak Paul
3902145b38 [lora] fix zimage lora conversion to support for more lora. (#13209)
fix zimage lora conversion to support for more lora.
2026-03-05 08:24:20 +05:30
dg845
af0bed007a Merge branch 'main' into rae 2026-03-04 17:04:49 -08:00
Álvaro Somoza
5570f817da [Z-Image] Fix more do_classifier_free_guidance thresholds (#13212)
fix
2026-03-04 10:11:55 -10:00
dg845
33f785b444 Add Helios-14B Video Generation Pipelines (#13208)
* [1/N] add helios

* fix test

* make fix-copies

* change script path

* fix cus script

* update docs

* fix documented check

* update links for docs and examples

* change default config

* small refactor

* add test

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove register_buffer for _scale_cache

* fix non-cuda devices error

* remove "handle the case when timestep is 2D"

* refactor HeliosMultiTermMemoryPatch and process_input_hidden_states

* Update src/diffusers/pipelines/helios/pipeline_helios.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/helios/pipeline_helios.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix calculate_shift

* Update src/diffusers/pipelines/helios/pipeline_helios.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* rewritten `einops` in pure `torch`

* fix: pass patch_size to apply_schedule_shift instead of hardcoding

* remove the logics of 'vae_decode_type'

* move some validation into check_inputs()

* rename helios scheduler & merge all into one step()

* add some details to doc

* move dmd  step() logics from pipeline to scheduler

* change to Python 3.9+ style type

* fix NoneType error

* refactor DMD scheduler's set_timestep

* change rope related vars name

* fix stage2 sample

* fix dmd sample

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* remove redundant & refactor norm_out

* Update src/diffusers/pipelines/helios/pipeline_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* change "is_keep_x0" to "keep_first_frame"

* use a more intuitive name

* refactor dynamic_time_shifting

* remove use_dynamic_shifting args

* remove usage of UniPCMultistepScheduler

* separate stage2 sample to HeliosPyramidPipeline

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/models/transformers/transformer_helios.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* fix transformer

* use a more intuitive name

* update example script

* fix requirements

* remove redudant attention mask

* fix

* optimize pipelines

* make style .

* update TYPE_CHECKING

* change to use torch.split

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* derive memory patch sizes from patch_size multiples

* remove some hardcoding

* move some checks into check_inputs

* refactor sample_block_noise

* optimize encoding chunks logits for v2v

* use num_history_latent_frames = sum(history_sizes)

* Update src/diffusers/pipelines/helios/pipeline_helios.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* remove redudant optimized_scale

* Update src/diffusers/pipelines/helios/pipeline_helios_pyramid.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* use more descriptive name

* optimize history_latents

* remove not used "num_inference_steps"

* removed redudant "pyramid_num_stages"

* add "is_cfg_zero_star" and "is_distilled" to HeliosPyramidPipeline

* remove redudant

* change example scripts name

* change example scripts name

* correct docs

* update example

* update docs

* Update tests/models/transformers/test_models_transformer_helios.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update tests/models/transformers/test_models_transformer_helios.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* separate HeliosDMDScheduler

* fix numerical stability issue:

* Update src/diffusers/schedulers/scheduling_helios_dmd.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_helios_dmd.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_helios_dmd.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_helios_dmd.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_helios_dmd.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* remove redudant

* small refactor

* remove use_interpolate_prompt logits

* simplified model test

* fallbackt to BaseModelTesterConfig

* remove _maybe_expand_t2v_lora_for_i2v

* fix HeliosLoraLoaderMixin

* update docs

* use randn_tensor for test

* fix doc typo

* optimize code

* mark torch.compile xfail

* change paper name

* Make get_dummy_inputs deterministic using self.generator

* Set less strict threshold for test_save_load_float16 test for Helios pipeline

* make style and make quality

* Preparation for merging

* add torch.Generator

* Fix HeliosPipelineOutput doc path

* Fix Helios related (optimize docs & remove redudant) (#13210)

* fix docs

* remove redudant

* remove redudant

* fix group offload

* Removed fixes for group offload

---------

Co-authored-by: yuanshenghai <yuanshenghai@bytedance.com>
Co-authored-by: Shenghai Yuan <140951558+SHYuanBest@users.noreply.github.com>
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: SHYuanBest <shyuan-cs@hotmail.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-04 21:31:43 +05:30
Shenghai Yuan
06ccde9490 Fix group-offloading bug (#13211)
* Implement synchronous onload for offloaded parameters

Add fallback synchronous onload for conditionally-executed modules.

* add test for new code path about group-offloading

* Update tests/hooks/test_group_offloading.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* use unittest.skipIf and update the comment

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-04 20:34:13 +05:30
Ando
ed9bcfd7a9 Merge branch 'huggingface:main' into rae 2026-03-04 19:21:12 +08:00
Kashif Rasul
05d3edca66 use randn_tensor 2026-03-04 10:16:07 +00:00
Kashif Rasul
f4ec0f1443 remove unittest 2026-03-04 10:12:40 +00:00
Kashif Rasul
fa016b196c rename 2026-03-04 09:55:54 +00:00
Kashif Rasul
33d98a85da fix api 2026-03-04 09:55:25 +00:00
jiqing-feng
88798242bc cogvideo example: Distribute VAE video encoding across processes in CogVideoX LoRA training (#13207)
* Distribute VAE video encoding across processes in CogVideoX LoRA training

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Apply style fixes

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
2026-03-04 15:09:01 +05:30
Kashif Rasul
14d918ee88 Merge branch 'main' into rae 2026-03-04 10:18:06 +01:00
Kashif Rasul
bc59324a2f Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 10:12:50 +01:00
Kashif Rasul
b9a5266cec _noising takes a generator 2026-03-04 09:12:19 +00:00
Kashif Rasul
876e930780 remove optional 2026-03-04 09:09:09 +00:00
Kashif Rasul
df1af7d907 Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 10:04:23 +01:00
Kashif Rasul
af75d8b9e2 inline 2026-03-04 09:03:37 +00:00
Kashif Rasul
e805be989e use buffer 2026-03-04 09:00:09 +00:00
Kashif Rasul
3958fda3bf Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-04 09:53:33 +01:00
Kashif Rasul
196f8a36c7 error out as soon as possible and add comments 2026-03-04 08:52:08 +00:00
Sayak Paul
4a2833c1c2 [Modular] implement requirements validation for custom blocks (#12196)
* feat: implement requirements validation for custom blocks.

* up

* unify.

* up

* add tests

* Apply suggestions from code review

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* reviewer feedback.

* [docs] validation for custom blocks (#13156)

validation

* move to tmp_path fixture.

* propagate to conditional and loopsequential blocks.

* up

* remove collected tests

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2026-03-04 12:19:08 +05:30
YiYi Xu
1fe688a651 [modular] not pass trust_remote_code to external repos (#13204)
* add

* update warn

* add a test

* updaqte

* update_component with custom model

* add more tests

* Apply suggestion from @DN6

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* up

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
2026-03-03 02:36:36 -10:00
Sayak Paul
9c0f96b303 Merge branch 'main' into rae 2026-03-03 17:06:14 +05:30
Kashif Rasul
bc71889852 update training script 2026-03-03 09:10:58 +00:00
Kashif Rasul
3a6689518f add dispatch forward and update conversion script 2026-03-03 09:03:28 +00:00
YiYi Xu
bbbcdd87bd [modular]Update model card to include workflow (#13195)
* up

* up

* update

* remove test

---------

Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-161-123.ec2.internal>
Co-authored-by: yiyi@huggingface.co <yiyi@ip-26-0-160-103.ec2.internal>
2026-03-02 20:50:07 -10:00
Dhruv Nair
47e8faf3b9 Clean up accidental files (#13202)
update
2026-03-03 00:35:58 +05:30
David El Malih
c2fdd2d048 docs: improve docstring scheduling_ipndm.py (#13198)
Improve docstring scheduling ipndm
2026-03-02 09:42:55 -08:00
Dhruv Nair
84ff061b1d [Modular] Save Modular Pipeline weights to Hub (#13168)
* update

* update

* update

* update

* update

* update
2026-03-02 22:20:42 +05:30
Dhruv Nair
3fd14f1acf [AutoModel] Allow registering auto_map to model config (#13186)
* update

* update
2026-03-02 22:13:25 +05:30
Dhruv Nair
e7fe4ce92f [AutoModel] Fix bug with subfolders and local model paths when loading custom code (#13197)
* update

* update
2026-03-02 17:44:25 +05:30
Sayak Paul
3d9085565b remove db utils from benchmarking (#13199) 2026-03-02 16:39:56 +05:30
Sayak Paul
5b54496131 [tests] enable cpu offload test in torchao without compilation. (#12704)
enable cpu offload test in torchao without compilation.
2026-03-02 15:03:58 +05:30
Sayak Paul
fcdd759e39 [chore] updates in the pypi publication workflow. (#12805)
* updates in the pypi publication workflow.

* change to 3.10
2026-03-02 14:34:49 +05:30
Kashif Rasul
5817416a19 fix test 2026-03-02 08:11:31 +00:00
Kashif Rasul
e834e498b2 _strip_final_layernorm_affine for training script 2026-02-28 19:40:19 +00:00
Kashif Rasul
f15873af72 strip final layernorm when converting 2026-02-28 19:35:21 +00:00
Sayak Paul
bff48d317e Merge branch 'main' into rae 2026-02-28 22:01:01 +05:30
Kashif Rasul
cd86873ea6 make quality 2026-02-28 16:28:04 +00:00
Kashif Rasul
34787e5b9b use ModelTesterMixin and AutoencoderTesterMixin 2026-02-28 16:22:47 +00:00
Kashif Rasul
9ada5768e5 remove config 2026-02-28 16:05:19 +00:00
Kashif Rasul
8861a8082a fix slow test 2026-02-28 15:57:10 +00:00
Kashif Rasul
03e757ca73 Encoder is frozen 2026-02-28 15:35:28 +00:00
Kashif Rasul
c717498fa3 use image url 2026-02-28 15:08:56 +00:00
Kashif Rasul
1b4a43f59d Update src/diffusers/models/autoencoders/autoencoder_rae.py
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-27 11:43:20 +01:00
Kashif Rasul
6a78767864 Update examples/research_projects/autoencoder_rae/README.md
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-02-27 11:42:45 +01:00
Kashif Rasul
663b580418 latebt normalization buffers are now always registered with no-op defaults 2026-02-26 10:45:30 +00:00
Kashif Rasul
d965cabe79 fix conversion script review 2026-02-26 10:44:27 +00:00
Kashif Rasul
5c85781519 fix train script to use pretrained 2026-02-26 10:38:47 +00:00
Kashif Rasul
c71cb44299 Merge branch 'rae' of https://github.com/Ando233/diffusers into rae 2026-02-26 10:30:32 +00:00
Kashif Rasul
dca59233f6 address reviews 2026-02-26 10:30:26 +00:00
Kashif Rasul
b3ffd6344a cleanups 2026-02-26 10:26:30 +00:00
Kashif Rasul
7debd07541 Merge branch 'main' into rae 2026-02-26 11:08:08 +01:00
Kashif Rasul
b297868201 fixes from pretrained weights 2026-02-25 13:38:22 +00:00
Kashif Rasul
28a02eb226 undo last change 2026-02-23 10:05:24 +00:00
Kashif Rasul
61885f37e3 added encoder_image_size config 2026-02-23 09:59:26 +00:00
Kashif Rasul
c68b812cb0 fix entrypoint for instantiating the AutoencoderRAE 2026-02-23 09:40:18 +00:00
Kashif Rasul
d8b2983b9e Merge branch 'main' into rae 2026-02-17 10:10:40 +01:00
Kashif Rasul
d06b501850 fix training script 2026-02-16 13:00:00 +00:00
Kashif Rasul
a4fc9f64b2 simplify mixins 2026-02-16 12:52:20 +00:00
Kashif Rasul
fc5295951a cleanup 2026-02-16 12:40:36 +00:00
Kashif Rasul
96520c4ff1 move loss to training script 2026-02-16 12:35:18 +00:00
Kashif Rasul
d3cbd5a60b fix argument 2026-02-16 00:03:54 +00:00
Kashif Rasul
906d79a432 input and ground truth sizes have to be the same 2026-02-16 00:02:27 +00:00
Kashif Rasul
9522e68a5b example traiing script 2026-02-15 23:56:19 +00:00
Kashif Rasul
6a9bde6964 remove unneeded class 2026-02-15 23:55:06 +00:00
Kashif Rasul
e6d449933d use attention 2026-02-15 23:50:52 +00:00
Kashif Rasul
7cbbf271f3 use imports 2026-02-15 23:33:30 +00:00
Kashif Rasul
202b14f6a4 add rae to diffusers script 2026-02-15 23:19:53 +00:00
Kashif Rasul
0d59b22732 cleanup 2026-02-15 23:19:13 +00:00
Kashif Rasul
d7cb12470b use mean and std convention 2026-02-15 22:57:02 +00:00
Kashif Rasul
f06ea7a901 fix latent_mean / latent_var init types to accept config-friendly inputs 2026-02-15 22:51:36 +00:00
Kashif Rasul
25bc9e334c initial doc 2026-02-15 22:44:46 +00:00
Kashif Rasul
24acab0bcc make fix-copies 2026-02-15 22:44:16 +00:00
Kashif Rasul
0850c8cdc9 fix formatting 2026-02-15 22:39:59 +00:00
Kashif Rasul
3ecf89d044 Merge branch 'main' into rae 2026-02-15 23:05:44 +01:00
Ando
a3926d77d7 Merge branch 'main' into rae 2026-01-28 20:31:20 +08:00
wangyuqi
f82cecc298 feat: finish first version of autoencoder_rae 2026-01-28 20:19:31 +08:00
wangyuqi
382aad0a6c feat: implement three RAE encoders(dinov2, siglip2, mae) 2026-01-25 02:54:35 +08:00
66 changed files with 8989 additions and 501 deletions

View File

@@ -62,20 +62,6 @@ jobs:
with:
name: benchmark_test_reports
path: benchmarks/${{ env.BASE_PATH }}
# TODO: enable this once the connection problem has been resolved.
- name: Update benchmarking results to DB
env:
PGDATABASE: metrics
PGHOST: ${{ secrets.DIFFUSERS_BENCHMARKS_PGHOST }}
PGUSER: transformers_benchmarks
PGPASSWORD: ${{ secrets.DIFFUSERS_BENCHMARKS_PGPASSWORD }}
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
run: |
git config --global --add safe.directory /__w/diffusers/diffusers
commit_id=$GITHUB_SHA
commit_msg=$(git show -s --format=%s "$commit_id" | cut -c1-70)
cd benchmarks && python populate_into_db.py "$BRANCH_NAME" "$commit_id" "$commit_msg"
- name: Report success status
if: ${{ success() }}

View File

@@ -54,7 +54,6 @@ jobs:
python -m pip install --upgrade pip
pip install -U setuptools wheel twine
pip install -U torch --index-url https://download.pytorch.org/whl/cpu
pip install -U transformers
- name: Build the dist files
run: python setup.py bdist_wheel && python setup.py sdist
@@ -69,6 +68,8 @@ jobs:
run: |
pip install diffusers && pip uninstall diffusers -y
pip install -i https://test.pypi.org/simple/ diffusers
pip install -U transformers
python utils/print_env.py
python -c "from diffusers import __version__; print(__version__)"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('fusing/unet-ldm-dummy-update'); pipe()"
python -c "from diffusers import DiffusionPipeline; pipe = DiffusionPipeline.from_pretrained('hf-internal-testing/tiny-stable-diffusion-pipe', safety_checker=None); pipe('ah suh du')"

View File

@@ -1,166 +0,0 @@
import argparse
import os
import sys
import gpustat
import pandas as pd
import psycopg2
import psycopg2.extras
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
register_adapter(dict, Json)
FINAL_CSV_FILENAME = "collated_results.csv"
# https://github.com/huggingface/transformers/blob/593e29c5e2a9b17baec010e8dc7c1431fed6e841/benchmark/init_db.sql#L27
BENCHMARKS_TABLE_NAME = "benchmarks"
MEASUREMENTS_TABLE_NAME = "model_measurements"
def _init_benchmark(conn, branch, commit_id, commit_msg):
    """Register a new benchmark run and return its generated id.

    Inserts one row into the benchmarks table carrying the repository,
    branch, commit details, and the name of the first visible GPU as
    JSON metadata (the dict is adapted via the module-level Json adapter).
    """
    gpu_collection = gpustat.GPUStatCollection.new_query()
    run_metadata = {"gpu_name": gpu_collection[0]["name"]}
    repo_name = "huggingface/diffusers"
    insert_stmt = (
        f"INSERT INTO {BENCHMARKS_TABLE_NAME} "
        "(repository, branch, commit_id, commit_message, metadata) "
        "VALUES (%s, %s, %s, %s, %s) RETURNING benchmark_id"
    )
    with conn.cursor() as cur:
        cur.execute(insert_stmt, (repo_name, branch, commit_id, commit_msg, run_metadata))
        new_id = cur.fetchone()[0]
    print(f"Initialised benchmark #{new_id}")
    return new_id
def parse_args():
    """Parse the three positional CLI arguments for the DB upload.

    Returns:
        argparse.Namespace with `branch`, `commit_id`, and `commit_msg`.
    """
    parser = argparse.ArgumentParser()
    positional_specs = (
        ("branch", "The branch name on which the benchmarking is performed."),
        ("commit_id", "The commit hash on which the benchmarking is performed."),
        ("commit_msg", "The commit message associated with the commit, truncated to 70 characters."),
    )
    for arg_name, arg_help in positional_specs:
        parser.add_argument(arg_name, type=str, help=arg_help)
    return parser.parse_args()
if __name__ == "__main__":
    # Entry point: read collated benchmark results from CSV and upload them
    # into the Postgres metrics DB as one `benchmarks` row plus one
    # `model_measurements` row per CSV record.
    args = parse_args()
    # Connect using the standard PG* environment variables; abort
    # immediately if the DB is unreachable.
    try:
        conn = psycopg2.connect(
            host=os.getenv("PGHOST"),
            database=os.getenv("PGDATABASE"),
            user=os.getenv("PGUSER"),
            password=os.getenv("PGPASSWORD"),
        )
        print("DB connection established successfully.")
    except Exception as e:
        print(f"Problem during DB init: {e}")
        sys.exit(1)
    # Create the parent benchmark row; its id links every measurement row.
    try:
        benchmark_id = _init_benchmark(
            conn=conn,
            branch=args.branch,
            commit_id=args.commit_id,
            commit_msg=args.commit_msg,
        )
    except Exception as e:
        print(f"Problem during initializing benchmark: {e}")
        sys.exit(1)
    cur = conn.cursor()
    df = pd.read_csv(FINAL_CSV_FILENAME)
    # Helper to cast values (or None) given a dtype
    def _cast_value(val, dtype: str):
        # NaN/NA cells become SQL NULL regardless of the requested dtype.
        if pd.isna(val):
            return None
        if dtype == "text":
            return str(val).strip()
        if dtype == "float":
            try:
                return float(val)
            except ValueError:
                return None
        if dtype == "bool":
            # Accept common textual spellings first, then numeric 0/1.
            s = str(val).strip().lower()
            if s in ("true", "t", "yes", "1"):
                return True
            if s in ("false", "f", "no", "0"):
                return False
            if val in (1, 1.0):
                return True
            if val in (0, 0.0):
                return False
            # Unrecognised boolean token -> NULL.
            return None
        # Unknown dtype: pass the value through unchanged.
        return val
    try:
        rows_to_insert = []
        for _, row in df.iterrows():
            # Cast each CSV column to the type expected in the JSON payload.
            scenario = _cast_value(row.get("scenario"), "text")
            model_cls = _cast_value(row.get("model_cls"), "text")
            num_params_B = _cast_value(row.get("num_params_B"), "float")
            flops_G = _cast_value(row.get("flops_G"), "float")
            time_plain_s = _cast_value(row.get("time_plain_s"), "float")
            mem_plain_GB = _cast_value(row.get("mem_plain_GB"), "float")
            time_compile_s = _cast_value(row.get("time_compile_s"), "float")
            mem_compile_GB = _cast_value(row.get("mem_compile_GB"), "float")
            fullgraph = _cast_value(row.get("fullgraph"), "bool")
            mode = _cast_value(row.get("mode"), "text")
            # If "github_sha" column exists in the CSV, cast it; else default to None
            if "github_sha" in df.columns:
                github_sha = _cast_value(row.get("github_sha"), "text")
            else:
                github_sha = None
            # The measurements dict is stored as a single JSON column
            # (adapted via the module-level register_adapter(dict, Json)).
            measurements = {
                "scenario": scenario,
                "model_cls": model_cls,
                "num_params_B": num_params_B,
                "flops_G": flops_G,
                "time_plain_s": time_plain_s,
                "mem_plain_GB": mem_plain_GB,
                "time_compile_s": time_compile_s,
                "mem_compile_GB": mem_compile_GB,
                "fullgraph": fullgraph,
                "mode": mode,
                "github_sha": github_sha,
            }
            rows_to_insert.append((benchmark_id, measurements))
        # Batch-insert all rows
        insert_sql = f"""
            INSERT INTO {MEASUREMENTS_TABLE_NAME} (
                benchmark_id,
                measurements
            )
            VALUES (%s, %s);
        """
        psycopg2.extras.execute_batch(cur, insert_sql, rows_to_insert)
        conn.commit()
        cur.close()
        conn.close()
    except Exception as e:
        print(f"Exception: {e}")
        sys.exit(1)

View File

@@ -194,6 +194,8 @@
title: Model accelerators and hardware
- isExpanded: false
sections:
- local: using-diffusers/helios
title: Helios
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/sdxl
@@ -350,6 +352,8 @@
title: FluxTransformer2DModel
- local: api/models/glm_image_transformer2d
title: GlmImageTransformer2DModel
- local: api/models/helios_transformer3d
title: HeliosTransformer3DModel
- local: api/models/hidream_image_transformer
title: HiDreamImageTransformer2DModel
- local: api/models/hunyuan_transformer2d
@@ -456,6 +460,8 @@
title: AutoencoderKLQwenImage
- local: api/models/autoencoder_kl_wan
title: AutoencoderKLWan
- local: api/models/autoencoder_rae
title: AutoencoderRAE
- local: api/models/consistency_decoder_vae
title: ConsistencyDecoderVAE
- local: api/models/autoencoder_oobleck
@@ -625,7 +631,6 @@
title: Image-to-image
- local: api/pipelines/stable_diffusion/inpaint
title: Inpainting
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
@@ -674,6 +679,8 @@
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/helios
title: Helios
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/hunyuan_video15
@@ -745,6 +752,10 @@
title: FlowMatchEulerDiscreteScheduler
- local: api/schedulers/flow_match_heun_discrete
title: FlowMatchHeunDiscreteScheduler
- local: api/schedulers/helios_dmd
title: HeliosDMDScheduler
- local: api/schedulers/helios
title: HeliosScheduler
- local: api/schedulers/heun
title: HeunDiscreteScheduler
- local: api/schedulers/ipndm

View File

@@ -23,6 +23,7 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
- [`AuraFlowLoraLoaderMixin`] provides similar functions for [AuraFlow](https://huggingface.co/fal/AuraFlow).
- [`LTXVideoLoraLoaderMixin`] provides similar functions for [LTX-Video](https://huggingface.co/docs/diffusers/main/en/api/pipelines/ltx_video).
- [`SanaLoraLoaderMixin`] provides similar functions for [Sana](https://huggingface.co/docs/diffusers/main/en/api/pipelines/sana).
- [`HeliosLoraLoaderMixin`] provides similar functions for [Helios](https://huggingface.co/docs/diffusers/main/en/api/pipelines/helios).
- [`HunyuanVideoLoraLoaderMixin`] provides similar functions for [HunyuanVideo](https://huggingface.co/docs/diffusers/main/en/api/pipelines/hunyuan_video).
- [`Lumina2LoraLoaderMixin`] provides similar functions for [Lumina2](https://huggingface.co/docs/diffusers/main/en/api/pipelines/lumina2).
- [`WanLoraLoaderMixin`] provides similar functions for [Wan](https://huggingface.co/docs/diffusers/main/en/api/pipelines/wan).
@@ -86,6 +87,10 @@ LoRA is a fast and lightweight training method that inserts and trains a signifi
[[autodoc]] loaders.lora_pipeline.SanaLoraLoaderMixin
## HeliosLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HeliosLoraLoaderMixin
## HunyuanVideoLoraLoaderMixin
[[autodoc]] loaders.lora_pipeline.HunyuanVideoLoraLoaderMixin

View File

@@ -0,0 +1,89 @@
<!-- Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# AutoencoderRAE
The Representation Autoencoder (RAE) model introduced in [Diffusion Transformers with Representation Autoencoders](https://huggingface.co/papers/2510.11690) by Boyang Zheng, Nanye Ma, Shengbang Tong, Saining Xie from NYU VISIONx.
RAE combines a frozen pretrained vision encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT-MAE-style decoder. In the two-stage RAE training recipe, the autoencoder is trained in stage 1 (reconstruction), and then a diffusion model is trained on the resulting latent space in stage 2 (generation).
The following RAE models are released and supported in Diffusers:
| Model | Encoder | Latent shape (224px input) |
|:------|:--------|:---------------------------|
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08) | DINOv2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08-i512) | DINOv2-base (512px) | 768 x 32 x 32 |
| [`nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-small-ViTXL-n08) | DINOv2-small | 384 x 16 x 16 |
| [`nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-dinov2-wReg-large-ViTXL-n08) | DINOv2-large | 1024 x 16 x 16 |
| [`nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-siglip2-base-p16-i256-ViTXL-n08) | SigLIP2-base | 768 x 16 x 16 |
| [`nyu-visionx/RAE-mae-base-p16-ViTXL-n08`](https://huggingface.co/nyu-visionx/RAE-mae-base-p16-ViTXL-n08) | MAE-base | 768 x 16 x 16 |
## Loading a pretrained model
```python
from diffusers import AutoencoderRAE
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
```
## Encoding and decoding a real image
```python
import torch
from diffusers import AutoencoderRAE
from diffusers.utils import load_image
from torchvision.transforms.functional import to_tensor, to_pil_image
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png")
image = image.convert("RGB").resize((224, 224))
x = to_tensor(image).unsqueeze(0).to("cuda") # (1, 3, 224, 224), values in [0, 1]
with torch.no_grad():
latents = model.encode(x).latent # (1, 768, 16, 16)
recon = model.decode(latents).sample # (1, 3, 256, 256)
recon_image = to_pil_image(recon[0].clamp(0, 1).cpu())
recon_image.save("recon.png")
```
## Latent normalization
Some pretrained checkpoints include per-channel `latents_mean` and `latents_std` statistics for normalizing the latent space. When present, `encode` and `decode` automatically apply the normalization and denormalization, respectively.
```python
model = AutoencoderRAE.from_pretrained(
"nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08"
).to("cuda").eval()
# Latent normalization is handled automatically inside encode/decode
# when the checkpoint config includes latents_mean/latents_std.
with torch.no_grad():
latents = model.encode(x).latent # normalized latents
recon = model.decode(latents).sample
```
## AutoencoderRAE
[[autodoc]] AutoencoderRAE
- encode
- decode
- all
## DecoderOutput
[[autodoc]] models.autoencoders.vae.DecoderOutput

View File

@@ -0,0 +1,35 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License. -->
# HeliosTransformer3DModel
A 14B Real-Time Autoregressive Diffusion Transformer model (supporting T2V, I2V and V2V) for 3D video-like data from [Helios](https://github.com/PKU-YuanGroup/Helios) was introduced in [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) by Peking University, ByteDance, and others.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import HeliosTransformer3DModel
# Best Quality
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="transformer", torch_dtype=torch.bfloat16)
# Intermediate Weight
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="transformer", torch_dtype=torch.bfloat16)
# Best Efficiency
transformer = HeliosTransformer3DModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="transformer", torch_dtype=torch.bfloat16)
```
## HeliosTransformer3DModel
[[autodoc]] HeliosTransformer3DModel
## Transformer2DModelOutput
[[autodoc]] models.modeling_outputs.Transformer2DModelOutput

View File

@@ -0,0 +1,465 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<a href="https://huggingface.co/docs/diffusers/main/en/tutorials/using_peft_for_inference" target="_blank" rel="noopener">
<img alt="LoRA" src="https://img.shields.io/badge/LoRA-d8b4fe?style=flat"/>
</a>
</div>
</div>
# Helios
[Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/) from Peking University, ByteDance, and others, by Shenghai Yuan, Yuanyang Yin, Zongjian Li, Xinwei Huang, Xiao Yang, Li Yuan.
* <u>We introduce Helios, the first 14B video generation model that runs at 17 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching a strong baseline in quality.</u> We make breakthroughs along three key dimensions: (1) robustness to long-video drifting without commonly used anti-drift heuristics such as self-forcing, error banks, or keyframe sampling; (2) real-time generation without standard acceleration techniques such as KV-cache, causal masking, or sparse attention; and (3) training without parallelism or sharding frameworks, enabling image-diffusion-scale batch sizes while fitting up to four 14B models within 80 GB of GPU memory. Specifically, Helios is a 14B autoregressive diffusion model with a unified input representation that natively supports T2V, I2V, and V2V tasks. To mitigate drifting in long-video generation, we characterize its typical failure modes and propose simple yet effective training strategies that explicitly simulate drifting during training, while eliminating repetitive motion at its source. For efficiency, we heavily compress the historical and noisy context and reduce the number of sampling steps, yielding computational costs comparable to—or lower than—those of 1.3B video generative models. Moreover, we introduce infrastructure-level optimizations that accelerate both inference and training while reducing memory consumption. Extensive experiments demonstrate that Helios consistently outperforms prior methods on both short- and long-video generation. All the code and models are available at [this https URL](https://pku-yuangroup.github.io/Helios-Page).
The following Helios models are supported in Diffusers:
- [Helios-Base](https://huggingface.co/BestWishYsh/Helios-Base): Best Quality, with v-prediction, standard CFG and custom HeliosScheduler.
- [Helios-Mid](https://huggingface.co/BestWishYsh/Helios-Mid): Intermediate Weight, with v-prediction, CFG-Zero* and custom HeliosScheduler.
- [Helios-Distilled](https://huggingface.co/BestWishYsh/Helios-Distilled): Best Efficiency, with x0-prediction and custom HeliosDMDScheduler.
> [!TIP]
> Click on the Helios models in the right sidebar for more examples of video generation.
### Optimizing Memory and Inference Speed
The example below demonstrates how to generate a video from text optimized for memory or inference speed.
<hfoptions id="optimization">
<hfoption id="memory">
Refer to the [Reduce memory usage](../../optimization/memory) guide for more details about the various memory saving techniques.
The Helios model below requires ~19GB of VRAM.
```py
import torch
from diffusers import AutoModel, HeliosPipeline
from diffusers.hooks.group_offloading import apply_group_offloading
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
# group-offloading
pipeline = HeliosPipeline.from_pretrained(
"BestWishYsh/Helios-Base",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.enable_group_offload(
onload_device=torch.device("cuda"),
offload_device=torch.device("cpu"),
offload_type="block_level",
num_blocks_per_group=1,
use_stream=True,
record_stream=True,
)
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
```
</hfoption>
<hfoption id="inference speed">
[Compilation](../../optimization/fp16#torchcompile) is slow the first time but subsequent calls to the pipeline are faster. [Attention Backends](../../optimization/attention_backends) such as FlashAttention and SageAttention can significantly increase speed by optimizing the computation of the attention mechanism. [Caching](../../optimization/cache) may also speed up inference by storing and reusing intermediate outputs.
```py
import torch
from diffusers import AutoModel, HeliosPipeline
from diffusers.utils import export_to_video
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPipeline.from_pretrained(
"BestWishYsh/Helios-Base",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
# attention backend
# pipeline.transformer.set_attention_backend("flash")
pipeline.transformer.set_attention_backend("_flash_3_hub") # For Hopper GPUs
# torch.compile
torch.backends.cudnn.benchmark = True
pipeline.text_encoder.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
pipeline.vae.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
pipeline.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
### Generation with Helios-Base
The example below demonstrates how to use Helios-Base to generate video based on text, image or video.
<hfoptions id="Helios-Base usage">
<hfoption id="usage">
```python
import torch
from diffusers import AutoModel, HeliosPipeline
from diffusers.utils import export_to_video, load_video, load_image
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Base", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPipeline.from_pretrained(
"BestWishYsh/Helios-Base",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
# For Text-to-Video
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_t2v_output.mp4", fps=24)
# For Image-to-Video
prompt = """
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
illuminating the intricate textures and deep green hues within the waves body. A thick spray erupts from the breaking crest,
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the oceans untamed beauty and
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
respect for natures might.
"""
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=load_image(image_path).resize((640, 384)),
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_i2v_output.mp4", fps=24)
# For Video-to-Video
prompt = """
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
"""
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
video=load_video(video_path),
num_frames=99,
num_inference_steps=50,
guidance_scale=5.0,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_base_v2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
### Generation with Helios-Mid
The example below demonstrates how to use Helios-Mid to generate video based on text, image or video.
<hfoptions id="Helios-Mid usage">
<hfoption id="usage">
```python
import torch
from diffusers import AutoModel, HeliosPyramidPipeline
from diffusers.utils import export_to_video, load_video, load_image
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Mid", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPyramidPipeline.from_pretrained(
"BestWishYsh/Helios-Mid",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
# For Text-to-Video
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=99,
pyramid_num_inference_steps_list=[20, 20, 20],
guidance_scale=5.0,
use_zero_init=True,
zero_steps=1,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_pyramid_t2v_output.mp4", fps=24)
# For Image-to-Video
prompt = """
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
illuminating the intricate textures and deep green hues within the waves body. A thick spray erupts from the breaking crest,
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the oceans untamed beauty and
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
respect for natures might.
"""
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=load_image(image_path).resize((640, 384)),
num_frames=99,
pyramid_num_inference_steps_list=[20, 20, 20],
guidance_scale=5.0,
use_zero_init=True,
zero_steps=1,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_pyramid_i2v_output.mp4", fps=24)
# For Video-to-Video
prompt = """
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
"""
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
video=load_video(video_path),
num_frames=99,
pyramid_num_inference_steps_list=[20, 20, 20],
guidance_scale=5.0,
use_zero_init=True,
zero_steps=1,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_pyramid_v2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
### Generation with Helios-Distilled
The example below demonstrates how to use Helios-Distilled to generate video based on text, image or video.
<hfoptions id="Helios-Distilled usage">
<hfoption id="usage">
```python
import torch
from diffusers import AutoModel, HeliosPyramidPipeline
from diffusers.utils import export_to_video, load_video, load_image
vae = AutoModel.from_pretrained("BestWishYsh/Helios-Distilled", subfolder="vae", torch_dtype=torch.float32)
pipeline = HeliosPyramidPipeline.from_pretrained(
"BestWishYsh/Helios-Distilled",
vae=vae,
torch_dtype=torch.bfloat16
)
pipeline.to("cuda")
negative_prompt = """
Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality,
low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured,
misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards
"""
# For Text-to-Video
prompt = """
A vibrant tropical fish swimming gracefully among colorful coral reefs in a clear, turquoise ocean. The fish has bright blue
and yellow scales with a small, distinctive orange spot on its side, its fins moving fluidly. The coral reefs are alive with
a variety of marine life, including small schools of colorful fish and sea turtles gliding by. The water is crystal clear,
allowing for a view of the sandy ocean floor below. The reef itself is adorned with a mix of hard and soft corals in shades
of red, orange, and green. The photo captures the fish from a slightly elevated angle, emphasizing its lively movements and
the vivid colors of its surroundings. A close-up shot with dynamic movement.
"""
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
num_frames=240,
pyramid_num_inference_steps_list=[2, 2, 2],
guidance_scale=1.0,
is_amplify_first_chunk=True,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_distilled_t2v_output.mp4", fps=24)
# For Image-to-Video
prompt = """
A towering emerald wave surges forward, its crest curling with raw power and energy. Sunlight glints off the translucent water,
illuminating the intricate textures and deep green hues within the waves body. A thick spray erupts from the breaking crest,
casting a misty veil that dances above the churning surface. As the perspective widens, the immense scale of the wave becomes
apparent, revealing the restless expanse of the ocean stretching beyond. The scene captures the oceans untamed beauty and
relentless force, with every droplet and ripple shimmering in the light. The dynamic motion and vivid colors evoke both awe and
respect for natures might.
"""
image_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/wave.jpg"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
image=load_image(image_path).resize((640, 384)),
num_frames=240,
pyramid_num_inference_steps_list=[2, 2, 2],
guidance_scale=1.0,
is_amplify_first_chunk=True,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_distilled_i2v_output.mp4", fps=24)
# For Video-to-Video
prompt = """
A bright yellow Lamborghini Huracn Tecnica speeds along a curving mountain road, surrounded by lush green trees
under a partly cloudy sky. The car's sleek design and vibrant color stand out against the natural backdrop,
emphasizing its dynamic movement. The road curves gently, with a guardrail visible on one side, adding depth to
the scene. The motion blur captures the sense of speed and energy, creating a thrilling and exhilarating atmosphere.
A front-facing shot from a slightly elevated angle, highlighting the car's aggressive stance and the surrounding greenery.
"""
video_path = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/car.mp4"
output = pipeline(
prompt=prompt,
negative_prompt=negative_prompt,
video=load_video(video_path),
num_frames=240,
pyramid_num_inference_steps_list=[2, 2, 2],
guidance_scale=1.0,
is_amplify_first_chunk=True,
generator=torch.Generator("cuda").manual_seed(42),
).frames[0]
export_to_video(output, "helios_distilled_v2v_output.mp4", fps=24)
```
</hfoption>
</hfoptions>
## HeliosPipeline
[[autodoc]] HeliosPipeline
- all
- __call__
## HeliosPyramidPipeline
[[autodoc]] HeliosPyramidPipeline
- all
- __call__
## HeliosPipelineOutput
[[autodoc]] pipelines.helios.pipeline_output.HeliosPipelineOutput

View File

@@ -0,0 +1,20 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# HeliosScheduler
`HeliosScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
## HeliosScheduler
[[autodoc]] HeliosScheduler
scheduling_helios

View File

@@ -0,0 +1,20 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# HeliosDMDScheduler
`HeliosDMDScheduler` is based on the pyramidal flow-matching sampling introduced in [Helios](https://huggingface.co/papers).
## HeliosDMDScheduler
[[autodoc]] HeliosDMDScheduler
scheduling_helios_dmd

View File

@@ -332,4 +332,49 @@ Make your custom block work with Mellon's visual interface. See the [Mellon Cust
Browse the [Modular Diffusers Custom Blocks](https://huggingface.co/collections/diffusers/modular-diffusers-custom-blocks) collection for inspiration and ready-to-use blocks.
</hfoption>
</hfoptions>
</hfoptions>
## Dependencies
Declaring package dependencies in custom blocks prevents runtime import errors later on. Diffusers validates the dependencies and returns a warning if a package is missing or incompatible.
Set a `_requirements` attribute in your block class, mapping package names to version specifiers.
```py
from diffusers.modular_pipelines import PipelineBlock
class MyCustomBlock(PipelineBlock):
_requirements = {
"transformers": ">=4.44.0",
"sentencepiece": ">=0.2.0"
}
```
When there are blocks with different requirements, Diffusers merges their requirements.
```py
from diffusers.modular_pipelines import SequentialPipelineBlocks
class BlockA(PipelineBlock):
_requirements = {"transformers": ">=4.44.0"}
# ...
class BlockB(PipelineBlock):
_requirements = {"sentencepiece": ">=0.2.0"}
# ...
pipe = SequentialPipelineBlocks.from_blocks_dict({
"block_a": BlockA,
"block_b": BlockB,
})
```
When this block is saved with [`~ModularPipeline.save_pretrained`], the requirements are saved to the `modular_config.json` file. When this block is loaded, Diffusers checks each requirement against the current environment. If there is a mismatch or a package isn't found, Diffusers returns the following warning.
```md
# missing package
xyz-package was specified in the requirements but wasn't found in the current environment.
# version mismatch
xyz requirement 'specific-version' is not satisfied by the installed version 'actual-version'. Things might work unexpected.
```

View File

@@ -97,5 +97,32 @@ If the custom model inherits from the [`ModelMixin`] class, it gets access to th
> )
> ```
### Saving custom models
Use [`~ConfigMixin.register_for_auto_class`] to add the `auto_map` entry to `config.json` automatically when saving. This avoids having to manually edit the config file.
```py
# my_model.py
from diffusers import ModelMixin, ConfigMixin
class MyCustomModel(ModelMixin, ConfigMixin):
...
MyCustomModel.register_for_auto_class("AutoModel")
model = MyCustomModel(...)
model.save_pretrained("./my_model")
```
The saved `config.json` will include the `auto_map` field.
```json
{
"auto_map": {
"AutoModel": "my_model.MyCustomModel"
}
}
```
> [!NOTE]
> Learn more about implementing custom models in the [Community components](../using-diffusers/custom_pipeline_overview#community-components) guide.

View File

@@ -60,7 +60,7 @@ export_to_video(video.frames[0], "output.mp4", fps=8)
<tr>
<th style="text-align: center;">Face Image</th>
<th style="text-align: center;">Video</th>
<th style="text-align: center;">Description</th>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/consisid/consisid_image_0.png?download=true" style="height: auto; width: 600px;"></td>

View File

@@ -0,0 +1,133 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Helios
[Helios](https://github.com/PKU-YuanGroup/Helios) is the first 14B video generation model that runs at 19.5 FPS on a single NVIDIA H100 GPU and supports minute-scale generation while matching the quality of a strong baseline, natively integrating T2V, I2V, and V2V tasks within a unified architecture. The main features of Helios are:
- Without commonly used anti-drifting strategies (e.g., self-forcing, error-banks, keyframe sampling, or inverted sampling), Helios generates minute-scale videos with high quality and strong coherence.
- Without standard acceleration techniques (e.g., KV-cache, causal masking, sparse/linear attention, TinyVAE, progressive noise schedules, hidden-state caching, or quantization), Helios achieves 19.5 FPS in end-to-end inference for a 14B video generation model on a single H100 GPU.
- Introducing optimizations that improve both training and inference throughput while reducing memory consumption. These changes enable training a 14B video generation model without parallelism or sharding infrastructure, with batch sizes comparable to image models.
This guide will walk you through using Helios for use cases.
## Load Model Checkpoints
Model weights may be stored in separate subfolders on the Hub or locally, in which case, you should use the [`~DiffusionPipeline.from_pretrained`] method.
```python
import torch
from diffusers import HeliosPipeline, HeliosPyramidPipeline
from huggingface_hub import snapshot_download
# For Best Quality
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Intermediate Weight
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# For Best Efficiency
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```
## Text-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Image-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Image</th>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Interactive-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Resources
Learn more about Helios with the following resources.
- Watch [video1](https://www.youtube.com/watch?v=vd_AgHtOUFQ) and [video2](https://www.youtube.com/watch?v=1GeIU2Dn7UY) for a demonstration of Helios's key features.
- The research paper, [Helios: Real-Time Long Video Generation Model](https://huggingface.co/papers/) for more details.

View File

@@ -132,6 +132,8 @@
sections:
- local: using-diffusers/consisid
title: ConsisID
- local: using-diffusers/helios
title: Helios
- title: Resources
isExpanded: false

View File

@@ -26,6 +26,14 @@ http://www.apache.org/licenses/LICENSE-2.0
<th>项目名称</th>
<th>描述</th>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/PKU-YuanGroup/Helios"> helios </a></td>
<td>Helios比1.3B更低开销、更快且更强的14B的实时长视频生成模型</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/PKU-YuanGroup/ConsisID"> consisid </a></td>
<td>ConsisID零样本身份保持的文本到视频生成模型</td>
</tr>
<tr style="border-top: 2px solid black">
<td><a href="https://github.com/carson-katri/dream-textures"> dream-textures </a></td>
<td>Stable Diffusion内置到Blender</td>

View File

@@ -0,0 +1,134 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Helios
[Helios](https://github.com/PKU-YuanGroup/Helios) 是首个能够在单张 NVIDIA H100 GPU 上以 19.5 FPS 运行的 14B 视频生成模型。它在支持分钟级视频生成的同时,拥有媲美强大基线模型的生成质量,并在统一架构下原生集成了文生视频(T2V)、图生视频(I2V)和视频生视频(V2V)任务。Helios 的主要特性包括:
- 无需常用的防漂移策略(例如:自强制/self-forcing、误差库/error-banks、关键帧采样或逆采样),我们的模型即可生成高质量且高度连贯的分钟级视频。
- 无需标准的加速技术(例如:KV 缓存、因果掩码、稀疏/线性注意力机制、TinyVAE、渐进式噪声调度、隐藏状态缓存或量化),作为一款 14B 规模的视频生成模型,我们在单张 H100 GPU 上的端到端推理速度便达到了 19.5 FPS。
- 引入了多项优化方案,在降低显存消耗的同时显著提升了训练与推理的吞吐量。这些改进使得我们无需借助并行或分片(sharding)等基础设施,即可使用与图像模型相当的批大小(batch sizes)来训练 14B 的视频生成模型。
本指南将引导您完成 Helios 在不同场景下的使用。
## Load Model Checkpoints
模型权重可以存储在Hub上或本地的单独子文件夹中在这种情况下您应该使用 [`~DiffusionPipeline.from_pretrained`] 方法。
```python
import torch
from diffusers import HeliosPipeline, HeliosPyramidPipeline
from huggingface_hub import snapshot_download
# For Best Quality
snapshot_download(repo_id="BestWishYsh/Helios-Base", local_dir="BestWishYsh/Helios-Base")
pipe = HeliosPipeline.from_pretrained("BestWishYsh/Helios-Base", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# Intermediate Weight
snapshot_download(repo_id="BestWishYsh/Helios-Mid", local_dir="BestWishYsh/Helios-Mid")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Mid", torch_dtype=torch.bfloat16)
pipe.to("cuda")
# For Best Efficiency
snapshot_download(repo_id="BestWishYsh/Helios-Distilled", local_dir="BestWishYsh/Helios-Distilled")
pipe = HeliosPyramidPipeline.from_pretrained("BestWishYsh/Helios-Distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")
```
## Text-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>A Viking warrior driving a modern city bus filled with passengers. The Viking has long blonde hair tied back, a beard, and is adorned with a fur-lined helmet and armor. He wears a traditional tunic and trousers, but also sports a seatbelt as he focuses on navigating the busy streets. The interior of the bus is typical, with rows of seats occupied by diverse passengers going about their daily routines. The exterior shots show the bustling urban environment, including tall buildings and traffic. Medium shot focusing on the Viking at the wheel, with occasional close-ups of his determined expression.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>A documentary-style nature photography shot from a camera truck moving to the left, capturing a crab quickly scurrying into its burrow. The crab has a hard, greenish-brown shell and long claws, moving with determined speed across the sandy ground. Its body is slightly arched as it burrows into the sand, leaving a small trail behind. The background shows a shallow beach with scattered rocks and seashells, and the horizon features a gentle curve of the coastline. The photo has a natural and realistic texture, emphasizing the crab's natural movement and the texture of the sand. A close-up shot from a slightly elevated angle.
</small></td>
<td>
<video width="4000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/t2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Image-to-Video Showcases
<table>
<tr>
<th style="text-align: center;">Image</th>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.jpg" style="height: auto; width: 300px;"></td>
<td><small>A sleek red Kia car speeds along a rural road under a cloudy sky, its modern design and dynamic movement emphasized by the blurred motion of the surrounding fields and trees stretching into the distance. The car's glossy exterior reflects the overcast sky, highlighting its aerodynamic shape and sporty stance. The license plate reads "KIA 626," and the vehicle's headlights are on, adding to the sense of motion and energy. The road curves gently, with the car positioned slightly off-center, creating a sense of forward momentum. A dynamic front three-quarter view captures the car's powerful presence against the serene backdrop of rolling hills and scattered trees.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.jpg" style="height: auto; width: 300px;"></td>
<td><small>A close-up captures a fluffy orange cat with striking green eyes and white whiskers, gazing intently towards the camera. The cat's fur is soft and well-groomed, with a mix of warm orange and cream tones. Its large, expressive eyes are a vivid green, reflecting curiosity and alertness. The cat's nose is small and pink, and its mouth is slightly open, revealing a hint of its pink tongue. The background is softly blurred, suggesting a cozy indoor setting with neutral tones. The photo has a shallow depth of field, focusing sharply on the cat's face while the background remains out of focus. A close-up shot from a slightly elevated perspective.
</small></td>
<td>
<video width="2000" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/i2v_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Interactive-Video Showcases
<table>
<tr>
<th style="text-align: center;">Prompt</th>
<th style="text-align: center;">Generated Video</th>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases1.mp4" type="video/mp4">
</video>
</td>
</tr>
<tr>
<td><small>The prompt can be found <a href="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.txt">here</a></small></td>
<td>
<video width="680" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/helios/interactive_showcases2.mp4" type="video/mp4">
</video>
</td>
</tr>
</table>
## Resources
通过以下资源了解有关 Helios 的更多信息:
- [视频1](https://www.youtube.com/watch?v=vd_AgHtOUFQ)和[视频2](https://www.youtube.com/watch?v=1GeIU2Dn7UY)演示了 Helios 的主要功能;
- 有关更多详细信息,请参阅研究论文 [Helios: Real Real-Time Long Video Generation Model](https://huggingface.co/papers/)。

View File

@@ -1232,22 +1232,49 @@ def main(args):
id_token=args.id_token,
)
def encode_video(video, bar):
bar.update(1)
def encode_video(video):
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist
# Distribute video encoding across processes: each process only encodes its own shard
num_videos = len(train_dataset.instance_videos)
num_procs = accelerator.num_processes
local_rank = accelerator.process_index
local_count = len(range(local_rank, num_videos, num_procs))
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
range(local_count),
desc="Encoding videos",
disable=not accelerator.is_local_main_process,
)
train_dataset.instance_videos = [
encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
]
encoded_videos = [None] * num_videos
for i, video in enumerate(train_dataset.instance_videos):
if i % num_procs == local_rank:
encoded_videos[i] = encode_video(video)
progress_encode_bar.update(1)
progress_encode_bar.close()
# Broadcast encoded latent distributions so every process has the full set
if num_procs > 1:
import torch.distributed as dist
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
ref_params = next(v for v in encoded_videos if v is not None).parameters
for i in range(num_videos):
src = i % num_procs
if encoded_videos[i] is not None:
params = encoded_videos[i].parameters.contiguous()
else:
params = torch.empty_like(ref_params)
dist.broadcast(params, src=src)
encoded_videos[i] = DiagonalGaussianDistribution(params)
train_dataset.instance_videos = encoded_videos
def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
prompts = [example["instance_prompt"] for example in examples]

View File

@@ -0,0 +1,66 @@
# Training AutoencoderRAE
This example trains the decoder of `AutoencoderRAE` (stage-1 style), while keeping the representation encoder frozen.
It follows the same high-level training recipe as the official RAE stage-1 setup:
- frozen encoder
- train decoder
- pixel reconstruction loss
- optional encoder feature consistency loss
## Quickstart
### Resume or finetune from pretrained weights
```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--pretrained_model_name_or_path nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08 \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```
### Train from scratch with a pretrained encoder
The following command launches RAE training with "facebook/dinov2-with-registers-base" as the base.
```bash
accelerate launch examples/research_projects/autoencoder_rae/train_autoencoder_rae.py \
--train_data_dir /path/to/imagenet_like_folder \
--output_dir /tmp/autoencoder-rae \
--resolution 256 \
--encoder_type dinov2 \
--encoder_name_or_path facebook/dinov2-with-registers-base \
--encoder_input_size 224 \
--patch_size 16 \
--image_size 256 \
--decoder_hidden_size 1152 \
--decoder_num_hidden_layers 28 \
--decoder_num_attention_heads 16 \
--decoder_intermediate_size 4096 \
--train_batch_size 8 \
--learning_rate 1e-4 \
--num_train_epochs 10 \
--report_to wandb \
--reconstruction_loss_type l1 \
--use_encoder_loss \
--encoder_loss_weight 0.1
```
Note: stage-1 reconstruction loss assumes matching target/output spatial size, so `--resolution` must equal `--image_size`.
Dataset format is expected to be `ImageFolder`-compatible:
```text
train_data_dir/
class_a/
img_0001.jpg
class_b/
img_0002.jpg
```

View File

@@ -0,0 +1,405 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import logging
import math
import os
from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm.auto import tqdm
from diffusers import AutoencoderRAE
from diffusers.optimization import get_scheduler
logger = get_logger(__name__)
def parse_args():
    """Parse command-line arguments for stage-1 AutoencoderRAE decoder training.

    Returns:
        argparse.Namespace with all training, model, and loss configuration.
    """
    parser = argparse.ArgumentParser(description="Train a stage-1 Representation Autoencoder (RAE) decoder.")
    # Data and output locations.
    parser.add_argument(
        "--train_data_dir",
        type=str,
        required=True,
        help="Path to an ImageFolder-style dataset root.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="autoencoder-rae", help="Directory to save checkpoints/model."
    )
    parser.add_argument("--logging_dir", type=str, default="logs", help="Accelerate logging directory.")
    parser.add_argument("--seed", type=int, default=42)
    # Input preprocessing (resolution must equal --image_size, enforced in main()).
    parser.add_argument("--resolution", type=int, default=256)
    parser.add_argument("--center_crop", action="store_true")
    parser.add_argument("--random_flip", action="store_true")
    # Training schedule and optimizer (AdamW) hyperparameters.
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--dataloader_num_workers", type=int, default=4)
    parser.add_argument("--num_train_epochs", type=int, default=10)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--adam_beta1", type=float, default=0.9)
    parser.add_argument("--adam_beta2", type=float, default=0.999)
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2)
    parser.add_argument("--adam_epsilon", type=float, default=1e-8)
    parser.add_argument("--lr_scheduler", type=str, default="cosine")
    parser.add_argument("--lr_warmup_steps", type=int, default=500)
    parser.add_argument("--checkpointing_steps", type=int, default=1000)
    parser.add_argument("--validation_steps", type=int, default=500)
    # Model source: either resume a full AutoencoderRAE, or build a fresh one
    # (below) and optionally load only encoder weights from an HF checkpoint.
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to a pretrained AutoencoderRAE model (or HF Hub id) to resume training from.",
    )
    parser.add_argument(
        "--encoder_name_or_path",
        type=str,
        default=None,
        help=(
            "HF Hub id or local path of the pretrained encoder (e.g. 'facebook/dinov2-with-registers-base'). "
            "When --pretrained_model_name_or_path is not set, the encoder weights are loaded from this path "
            "into a freshly constructed AutoencoderRAE. Ignored when --pretrained_model_name_or_path is set."
        ),
    )
    # Architecture of a freshly constructed AutoencoderRAE (frozen encoder + trainable ViT decoder).
    parser.add_argument("--encoder_type", type=str, choices=["dinov2", "siglip2", "mae"], default="dinov2")
    parser.add_argument("--encoder_hidden_size", type=int, default=768)
    parser.add_argument("--encoder_patch_size", type=int, default=14)
    parser.add_argument("--encoder_num_hidden_layers", type=int, default=12)
    parser.add_argument("--encoder_input_size", type=int, default=224)
    parser.add_argument("--patch_size", type=int, default=16)
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--num_channels", type=int, default=3)
    parser.add_argument("--decoder_hidden_size", type=int, default=1152)
    parser.add_argument("--decoder_num_hidden_layers", type=int, default=28)
    parser.add_argument("--decoder_num_attention_heads", type=int, default=16)
    parser.add_argument("--decoder_intermediate_size", type=int, default=4096)
    parser.add_argument("--noise_tau", type=float, default=0.0)
    parser.add_argument("--scaling_factor", type=float, default=1.0)
    parser.add_argument("--reshape_to_2d", action=argparse.BooleanOptionalAction, default=True)
    # Loss configuration: pixel reconstruction plus optional encoder-feature consistency.
    parser.add_argument(
        "--reconstruction_loss_type",
        type=str,
        choices=["l1", "mse"],
        default="l1",
        help="Pixel reconstruction loss.",
    )
    parser.add_argument(
        "--encoder_loss_weight",
        type=float,
        default=0.0,
        help="Weight for encoder feature consistency loss in the training loop.",
    )
    parser.add_argument(
        "--use_encoder_loss",
        action="store_true",
        help="Enable encoder feature consistency loss term in the training loop.",
    )
    parser.add_argument("--report_to", type=str, default="tensorboard")
    return parser.parse_args()
def build_transforms(args):
    """Build the torchvision preprocessing pipeline from parsed CLI args.

    Resizes to ``args.resolution`` (bicubic), center- or random-crops to the
    same size, optionally applies a random horizontal flip, and converts to a
    [0, 1] float tensor.
    """
    crop_cls = transforms.CenterCrop if args.center_crop else transforms.RandomCrop
    pipeline = [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BICUBIC),
        crop_cls(args.resolution),
    ]
    if args.random_flip:
        pipeline.append(transforms.RandomHorizontalFlip())
    pipeline.append(transforms.ToTensor())
    return transforms.Compose(pipeline)
def compute_losses(
    model, pixel_values, reconstruction_loss_type: str, use_encoder_loss: bool, encoder_loss_weight: float
):
    """Compute the stage-1 RAE training loss for a batch of images.

    Args:
        model: The autoencoder (possibly wrapped by DDP/accelerate); calling it
            must return an object exposing the reconstruction as ``.sample``.
        pixel_values: Target image batch of shape (B, C, H, W).
        reconstruction_loss_type: "l1" for mean absolute error; anything else uses MSE.
        use_encoder_loss: Whether to add the encoder feature consistency term.
        encoder_loss_weight: Weight of the encoder consistency term.

    Returns:
        Tuple of (decoded, total_loss, reconstruction_loss, encoder_loss).

    Raises:
        ValueError: If the reconstruction's spatial size differs from the target's.
    """
    decoded = model(pixel_values).sample
    if decoded.shape[-2:] != pixel_values.shape[-2:]:
        raise ValueError(
            "Training requires matching reconstruction and target sizes, got "
            f"decoded={tuple(decoded.shape[-2:])}, target={tuple(pixel_values.shape[-2:])}."
        )
    if reconstruction_loss_type == "l1":
        reconstruction_loss = F.l1_loss(decoded.float(), pixel_values.float())
    else:
        reconstruction_loss = F.mse_loss(decoded.float(), pixel_values.float())
    encoder_loss = torch.zeros_like(reconstruction_loss)
    if use_encoder_loss and encoder_loss_weight > 0:
        # Unwrap DDP/accelerate wrappers to reach the underlying AutoencoderRAE.
        base_model = model.module if hasattr(model, "module") else model
        target_encoder_input = base_model._resize_and_normalize(pixel_values)
        reconstructed_encoder_input = base_model._resize_and_normalize(decoded)
        encoder_forward_kwargs = {"model": base_model.encoder}
        if base_model.config.encoder_type == "mae":
            encoder_forward_kwargs["patch_size"] = base_model.config.encoder_patch_size
        # Only the *target* features are constants (frozen encoder on ground-truth
        # pixels). The reconstructed features must stay inside the autograd graph:
        # wrapping them in no_grad() too would cut the gradient path through
        # `decoded` and make the encoder consistency term a training no-op.
        with torch.no_grad():
            target_tokens = base_model._encoder_forward_fn(images=target_encoder_input, **encoder_forward_kwargs)
        reconstructed_tokens = base_model._encoder_forward_fn(
            images=reconstructed_encoder_input, **encoder_forward_kwargs
        )
        encoder_loss = F.mse_loss(reconstructed_tokens.float(), target_tokens.float())
    loss = reconstruction_loss + float(encoder_loss_weight) * encoder_loss
    return decoded, loss, reconstruction_loss, encoder_loss
def _strip_final_layernorm_affine(state_dict, prefix=""):
"""Remove final layernorm weight/bias so the model keeps its default init (identity)."""
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
def _load_pretrained_encoder_weights(model, encoder_type, encoder_name_or_path):
"""Load pretrained HF transformers encoder weights into the model's encoder."""
if encoder_type == "dinov2":
from transformers import Dinov2WithRegistersModel
hf_encoder = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
state_dict = hf_encoder.state_dict()
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
elif encoder_type == "siglip2":
from transformers import SiglipModel
hf_encoder = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
state_dict = {f"vision_model.{k}": v for k, v in hf_encoder.state_dict().items()}
state_dict = _strip_final_layernorm_affine(state_dict, prefix="vision_model.post_layernorm.")
elif encoder_type == "mae":
from transformers import ViTMAEForPreTraining
hf_encoder = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
state_dict = hf_encoder.state_dict()
state_dict = _strip_final_layernorm_affine(state_dict, prefix="layernorm.")
else:
raise ValueError(f"Unknown encoder_type: {encoder_type}")
model.encoder.load_state_dict(state_dict, strict=False)
def main():
    """Entry point: train the AutoencoderRAE decoder (stage-1) with a frozen encoder."""
    args = parse_args()
    # Stage-1 reconstruction loss compares decoder output directly against the
    # input pixels, so the dataloader resolution must match the model image size.
    if args.resolution != args.image_size:
        raise ValueError(
            f"`--resolution` ({args.resolution}) must match `--image_size` ({args.image_size}) "
            "for stage-1 reconstruction loss."
        )
    logging_dir = Path(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        project_config=accelerator_project_config,
        log_with=args.report_to,
    )
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if args.seed is not None:
        set_seed(args.seed)
    # Only the main process creates the output directory; others wait for it.
    if accelerator.is_main_process:
        os.makedirs(args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()
    dataset = ImageFolder(args.train_data_dir, transform=build_transforms(args))

    def collate_fn(examples):
        # ImageFolder yields (image, label) pairs; labels are unused for reconstruction.
        pixel_values = torch.stack([example[0] for example in examples]).float()
        return {"pixel_values": pixel_values}

    train_dataloader = DataLoader(
        dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
        pin_memory=True,
        drop_last=True,
    )
    # Either resume a full pretrained AutoencoderRAE, or construct a fresh one and
    # optionally seed just its encoder from a pretrained HF checkpoint.
    if args.pretrained_model_name_or_path is not None:
        model = AutoencoderRAE.from_pretrained(args.pretrained_model_name_or_path)
        logger.info(f"Loaded pretrained AutoencoderRAE from {args.pretrained_model_name_or_path}")
    else:
        model = AutoencoderRAE(
            encoder_type=args.encoder_type,
            encoder_hidden_size=args.encoder_hidden_size,
            encoder_patch_size=args.encoder_patch_size,
            encoder_num_hidden_layers=args.encoder_num_hidden_layers,
            decoder_hidden_size=args.decoder_hidden_size,
            decoder_num_hidden_layers=args.decoder_num_hidden_layers,
            decoder_num_attention_heads=args.decoder_num_attention_heads,
            decoder_intermediate_size=args.decoder_intermediate_size,
            patch_size=args.patch_size,
            encoder_input_size=args.encoder_input_size,
            image_size=args.image_size,
            num_channels=args.num_channels,
            noise_tau=args.noise_tau,
            reshape_to_2d=args.reshape_to_2d,
            use_encoder_loss=args.use_encoder_loss,
            scaling_factor=args.scaling_factor,
        )
        if args.encoder_name_or_path is not None:
            _load_pretrained_encoder_weights(model, args.encoder_type, args.encoder_name_or_path)
            logger.info(f"Loaded pretrained encoder weights from {args.encoder_name_or_path}")
    # Stage-1 recipe: encoder frozen, only the decoder is trained.
    model.encoder.requires_grad_(False)
    model.decoder.requires_grad_(True)
    model.train()
    # Optimize only the trainable (decoder) parameters.
    optimizer = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
    )
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
    # The prepared dataloader may be sharded across processes, so recompute the
    # schedule lengths after accelerator.prepare().
    if overrode_max_train_steps:
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
    if accelerator.is_main_process:
        accelerator.init_trackers("train_autoencoder_rae", config=vars(args))
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")
    global_step = 0
    for epoch in range(args.num_train_epochs):
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                pixel_values = batch["pixel_values"]
                _, loss, reconstruction_loss, encoder_loss = compute_losses(
                    model,
                    pixel_values,
                    reconstruction_loss_type=args.reconstruction_loss_type,
                    use_encoder_loss=args.use_encoder_loss,
                    encoder_loss_weight=args.encoder_loss_weight,
                )
                accelerator.backward(loss)
                # Clip only when gradients are synchronized (i.e. on a real
                # optimizer step under gradient accumulation).
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            # One "global step" corresponds to one synchronized optimizer update.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                logs = {
                    "loss": loss.detach().item(),
                    "reconstruction_loss": reconstruction_loss.detach().item(),
                    "encoder_loss": encoder_loss.detach().item(),
                    "lr": lr_scheduler.get_last_lr()[0],
                }
                progress_bar.set_postfix(**logs)
                accelerator.log(logs, step=global_step)
                if global_step % args.validation_steps == 0:
                    # NOTE(review): "validation" re-evaluates the *current training
                    # batch* without gradients — a smoke check, not a held-out eval.
                    with torch.no_grad():
                        _, val_loss, val_reconstruction_loss, val_encoder_loss = compute_losses(
                            model,
                            pixel_values,
                            reconstruction_loss_type=args.reconstruction_loss_type,
                            use_encoder_loss=args.use_encoder_loss,
                            encoder_loss_weight=args.encoder_loss_weight,
                        )
                    accelerator.log(
                        {
                            "val/loss": val_loss.detach().item(),
                            "val/reconstruction_loss": val_reconstruction_loss.detach().item(),
                            "val/encoder_loss": val_encoder_loss.detach().item(),
                        },
                        step=global_step,
                    )
                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        unwrapped_model = accelerator.unwrap_model(model)
                        unwrapped_model.save_pretrained(save_path)
                        logger.info(f"Saved checkpoint to {save_path}")
            if global_step >= args.max_train_steps:
                break
        if global_step >= args.max_train_steps:
            break
    # Final save on the main process after all ranks finish.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(args.output_dir)
    accelerator.end_training()


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,403 @@
import argparse
from pathlib import Path
from typing import Any
import torch
from huggingface_hub import HfApi, hf_hub_download
from diffusers import AutoencoderRAE
# ViT decoder architecture presets keyed by the official RAE decoder variant name.
DECODER_CONFIGS = {
    "ViTB": {
        "decoder_hidden_size": 768,
        "decoder_intermediate_size": 3072,
        "decoder_num_attention_heads": 12,
        "decoder_num_hidden_layers": 12,
    },
    "ViTL": {
        "decoder_hidden_size": 1024,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 24,
    },
    "ViTXL": {
        "decoder_hidden_size": 1152,
        "decoder_intermediate_size": 4096,
        "decoder_num_attention_heads": 16,
        "decoder_num_hidden_layers": 28,
    },
}
# Default HF Hub checkpoint used for each supported encoder family.
ENCODER_DEFAULT_NAME_OR_PATH = {
    "dinov2": "facebook/dinov2-with-registers-base",
    "siglip2": "google/siglip2-base-patch16-256",
    "mae": "facebook/vit-mae-base",
}
# Hidden width of each encoder family's base checkpoint.
ENCODER_HIDDEN_SIZE = {
    "dinov2": 768,
    "siglip2": 768,
    "mae": 768,
}
# Patch size of each encoder family's base checkpoint.
ENCODER_PATCH_SIZE = {
    "dinov2": 14,
    "siglip2": 16,
    "mae": 16,
}
# Repo subdirectory where the official RAE release stores decoder weights, per encoder family.
DEFAULT_DECODER_SUBDIR = {
    "dinov2": "decoders/dinov2/wReg_base",
    "mae": "decoders/mae/base_p16",
    "siglip2": "decoders/siglip2/base_p16_i256",
}
# Repo subdirectory where the official RAE release stores latent statistics, per encoder family.
DEFAULT_STATS_SUBDIR = {
    "dinov2": "stats/dinov2/wReg_base",
    "mae": "stats/mae/base_p16",
    "siglip2": "stats/siglip2/base_p16_i256",
}
# Checkpoint filenames probed (in order) when resolving decoder / stats files.
DECODER_FILE_CANDIDATES = ("dinov2_decoder.pt", "model.pt")
STATS_FILE_CANDIDATES = ("stat.pt",)
def dataset_case_candidates(name: str) -> tuple[str, ...]:
    """Return dataset-directory name candidates to probe, in priority order.

    Official RAE repos are inconsistent about dataset-folder casing, so several
    case variants of ``name`` plus the known ImageNet spellings are tried.
    Duplicates (e.g. when ``name`` is already lowercase, making ``name`` and
    ``name.lower()`` identical) are collapsed while keeping first-seen order,
    avoiding redundant repo-existence checks.
    """
    candidates = (name, name.lower(), name.upper(), name.title(), "imagenet1k", "ImageNet1k")
    return tuple(dict.fromkeys(candidates))
class RepoAccessor:
    """Uniform file access over either a local directory or an HF Hub repo.

    If ``repo_or_path`` is an existing local directory, files are resolved
    relative to it. Otherwise it is treated as a Hub repo id: the file listing
    is fetched once up front and individual files are downloaded on demand.
    """

    def __init__(self, repo_or_path: str, cache_dir: str | None = None):
        self.repo_or_path = repo_or_path
        self.cache_dir = cache_dir
        self.local_root: Path | None = None
        self.repo_id: str | None = None
        self.repo_files: set[str] | None = None
        candidate = Path(repo_or_path)
        if candidate.is_dir():
            self.local_root = candidate
        else:
            self.repo_id = repo_or_path
            self.repo_files = set(HfApi().list_repo_files(repo_or_path))

    def exists(self, relative_path: str) -> bool:
        # Normalize Windows-style separators to repo-style forward slashes.
        normalized = relative_path.replace("\\", "/")
        if self.local_root is None:
            return normalized in self.repo_files
        return (self.local_root / normalized).is_file()

    def fetch(self, relative_path: str) -> Path:
        normalized = relative_path.replace("\\", "/")
        if self.local_root is not None:
            return self.local_root / normalized
        downloaded = hf_hub_download(repo_id=self.repo_id, filename=normalized, cache_dir=self.cache_dir)
        return Path(downloaded)
def unwrap_state_dict(maybe_wrapped: dict[str, Any]) -> dict[str, Any]:
    """Peel common checkpoint wrappers and uniform key prefixes off a state dict.

    Descends through "model"/"module"/"state_dict" wrapper keys (checked once
    each, in that order), then strips a uniform "module." prefix (DDP) and a
    uniform "decoder." prefix when *every* key carries it.
    """
    inner = maybe_wrapped
    for wrapper_key in ("model", "module", "state_dict"):
        if isinstance(inner, dict) and isinstance(inner.get(wrapper_key), dict):
            inner = inner[wrapper_key]
    result = dict(inner)
    for uniform_prefix in ("module.", "decoder."):
        if result and all(key.startswith(uniform_prefix) for key in result):
            result = {key[len(uniform_prefix) :]: value for key, value in result.items()}
    return result
def remap_decoder_attention_keys_for_diffusers(state_dict: dict[str, Any]) -> dict[str, Any]:
    """
    Map official RAE decoder attention key layout to the diffusers Attention layout
    used by the AutoencoderRAE decoder.

    Example mappings:
    - `...attention.attention.query.*` -> `...attention.to_q.*`
    - `...attention.attention.key.*` -> `...attention.to_k.*`
    - `...attention.attention.value.*` -> `...attention.to_v.*`
    - `...attention.output.dense.*` -> `...attention.to_out.0.*`
    """
    substitutions = (
        (".attention.attention.query.", ".attention.to_q."),
        (".attention.attention.key.", ".attention.to_k."),
        (".attention.attention.value.", ".attention.to_v."),
        (".attention.output.dense.", ".attention.to_out.0."),
    )

    def _rename(key: str) -> str:
        for official, diffusers in substitutions:
            key = key.replace(official, diffusers)
        return key

    return {_rename(key): value for key, value in state_dict.items()}
def resolve_decoder_file(
    accessor: RepoAccessor, encoder_type: str, variant: str, decoder_checkpoint: str | None
) -> str:
    """Resolve the decoder checkpoint path inside the repo.

    An explicitly supplied ``decoder_checkpoint`` is validated and returned
    as-is. Otherwise the default subdirectory for ``encoder_type``/``variant``
    is probed with each known checkpoint filename.

    Raises:
        FileNotFoundError: If no candidate path exists in the repo.
    """
    if decoder_checkpoint is not None:
        if not accessor.exists(decoder_checkpoint):
            raise FileNotFoundError(f"Decoder checkpoint not found: {decoder_checkpoint}")
        return decoder_checkpoint
    base = f"{DEFAULT_DECODER_SUBDIR[encoder_type]}/{variant}"
    for filename in DECODER_FILE_CANDIDATES:
        path = f"{base}/{filename}"
        if accessor.exists(path):
            return path
    raise FileNotFoundError(
        f"Could not find decoder checkpoint under `{base}`. Tried: {list(DECODER_FILE_CANDIDATES)}"
    )
def resolve_stats_file(
    accessor: RepoAccessor,
    encoder_type: str,
    dataset_name: str,
    stats_checkpoint: str | None,
) -> str | None:
    """Resolve the latent-stats checkpoint path, or None when none exists.

    An explicitly supplied ``stats_checkpoint`` is validated and returned as-is.
    Otherwise the default stats subdirectory is probed with several casings of
    ``dataset_name``. Stats are optional, so an unsuccessful probe returns None
    rather than raising.

    Raises:
        FileNotFoundError: Only when an explicit ``stats_checkpoint`` is missing.
    """
    if stats_checkpoint is not None:
        if not accessor.exists(stats_checkpoint):
            raise FileNotFoundError(f"Stats checkpoint not found: {stats_checkpoint}")
        return stats_checkpoint
    base = DEFAULT_STATS_SUBDIR[encoder_type]
    probe_paths = (
        f"{base}/{dataset_dir}/{filename}"
        for dataset_dir in dataset_case_candidates(dataset_name)
        for filename in STATS_FILE_CANDIDATES
    )
    return next((path for path in probe_paths if accessor.exists(path)), None)
def extract_latent_stats(stats_obj: Any) -> tuple[Any | None, Any | None]:
    """Pull ``(latents_mean, latents_std)`` out of an official RAE stats checkpoint.

    Two layouts are supported: direct "latents_mean"/"latents_std" keys, or raw
    "mean"/"var" moments, in which case the variance is converted to a standard
    deviation with a small epsilon for numerical stability. Returns
    ``(None, None)`` when the object is not a recognizable stats dict.
    """
    if not isinstance(stats_obj, dict):
        return None, None
    if "latents_mean" in stats_obj or "latents_std" in stats_obj:
        return stats_obj.get("latents_mean", None), stats_obj.get("latents_std", None)
    mean = stats_obj.get("mean", None)
    var = stats_obj.get("var", None)
    if mean is None and var is None:
        return None, None
    if var is None:
        return mean, None
    variance = var if isinstance(var, torch.Tensor) else torch.tensor(var)
    return mean, torch.sqrt(variance + 1e-5)
def _strip_final_layernorm_affine(state_dict: dict[str, Any], prefix: str = "") -> dict[str, Any]:
"""Remove final layernorm weight/bias from encoder state dict.
RAE uses non-affine layernorm (weight=1, bias=0 is the default identity).
Stripping these keys means the model keeps its default init values, which
is functionally equivalent to setting elementwise_affine=False.
"""
keys_to_strip = {f"{prefix}weight", f"{prefix}bias"}
return {k: v for k, v in state_dict.items() if k not in keys_to_strip}
def _load_hf_encoder_state_dict(encoder_type: str, encoder_name_or_path: str) -> dict[str, Any]:
"""Download the HF encoder and extract the state dict for the inner model."""
if encoder_type == "dinov2":
from transformers import Dinov2WithRegistersModel
hf_model = Dinov2WithRegistersModel.from_pretrained(encoder_name_or_path)
sd = hf_model.state_dict()
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
elif encoder_type == "siglip2":
from transformers import SiglipModel
# SiglipModel.vision_model is a SiglipVisionTransformer.
# Our Siglip2Encoder wraps it inside SiglipVisionModel which nests it
# under .vision_model, so we add the prefix to match the diffusers key layout.
hf_model = SiglipModel.from_pretrained(encoder_name_or_path).vision_model
sd = {f"vision_model.{k}": v for k, v in hf_model.state_dict().items()}
return _strip_final_layernorm_affine(sd, prefix="vision_model.post_layernorm.")
elif encoder_type == "mae":
from transformers import ViTMAEForPreTraining
hf_model = ViTMAEForPreTraining.from_pretrained(encoder_name_or_path).vit
sd = hf_model.state_dict()
return _strip_final_layernorm_affine(sd, prefix="layernorm.")
else:
raise ValueError(f"Unknown encoder_type: {encoder_type}")
def convert(args: argparse.Namespace) -> None:
    """Convert an original RAE checkpoint into a diffusers `AutoencoderRAE`.

    Steps: resolve the decoder/stats files inside the source repo, load and
    remap the decoder state dict, pull encoder weights and normalization
    stats from Hugging Face, build the model on the meta device, load the
    assembled state dict with `assign=True`, and `save_pretrained` the result.
    With `--dry_run`, only the file resolution is performed and printed.
    """
    accessor = RepoAccessor(args.repo_or_path, cache_dir=args.cache_dir)
    encoder_name_or_path = args.encoder_name_or_path or ENCODER_DEFAULT_NAME_OR_PATH[args.encoder_type]
    decoder_relpath = resolve_decoder_file(accessor, args.encoder_type, args.variant, args.decoder_checkpoint)
    stats_relpath = resolve_stats_file(accessor, args.encoder_type, args.dataset_name, args.stats_checkpoint)
    print(f"Using decoder checkpoint: {decoder_relpath}")
    if stats_relpath is not None:
        print(f"Using stats checkpoint: {stats_relpath}")
    else:
        # Stats are optional: the model is still built, just without latent stats.
        print("No stats checkpoint found; conversion will proceed without latent stats.")
    if args.dry_run:
        return
    decoder_path = accessor.fetch(decoder_relpath)
    # NOTE(review): torch.load without weights_only=True unpickles arbitrary
    # objects — only run this on checkpoints from trusted sources.
    decoder_obj = torch.load(decoder_path, map_location="cpu")
    decoder_state_dict = unwrap_state_dict(decoder_obj)
    decoder_state_dict = remap_decoder_attention_keys_for_diffusers(decoder_state_dict)
    latents_mean, latents_std = None, None
    if stats_relpath is not None:
        stats_path = accessor.fetch(stats_relpath)
        stats_obj = torch.load(stats_path, map_location="cpu")
        latents_mean, latents_std = extract_latent_stats(stats_obj)
    decoder_cfg = DECODER_CONFIGS[args.decoder_config_name]
    # Read encoder normalization stats from the HF image processor (only place that downloads encoder info)
    from transformers import AutoConfig, AutoImageProcessor

    proc = AutoImageProcessor.from_pretrained(encoder_name_or_path)
    encoder_norm_mean = list(proc.image_mean)
    encoder_norm_std = list(proc.image_std)
    # Read encoder hidden size and patch size from HF config
    encoder_hidden_size = ENCODER_HIDDEN_SIZE[args.encoder_type]
    encoder_patch_size = ENCODER_PATCH_SIZE[args.encoder_type]
    try:
        hf_config = AutoConfig.from_pretrained(encoder_name_or_path)
        # For models like SigLIP that nest vision config
        if hasattr(hf_config, "vision_config"):
            hf_config = hf_config.vision_config
        encoder_hidden_size = hf_config.hidden_size
        encoder_patch_size = hf_config.patch_size
    except Exception:
        # Best-effort: fall back to the hard-coded per-encoder defaults above.
        pass
    # Load the actual encoder weights from HF to include in the saved model
    encoder_state_dict = _load_hf_encoder_state_dict(args.encoder_type, encoder_name_or_path)
    # Build model on meta device to avoid double init overhead
    with torch.device("meta"):
        model = AutoencoderRAE(
            encoder_type=args.encoder_type,
            encoder_hidden_size=encoder_hidden_size,
            encoder_patch_size=encoder_patch_size,
            encoder_input_size=args.encoder_input_size,
            patch_size=args.patch_size,
            image_size=args.image_size,
            num_channels=args.num_channels,
            encoder_norm_mean=encoder_norm_mean,
            encoder_norm_std=encoder_norm_std,
            decoder_hidden_size=decoder_cfg["decoder_hidden_size"],
            decoder_num_hidden_layers=decoder_cfg["decoder_num_hidden_layers"],
            decoder_num_attention_heads=decoder_cfg["decoder_num_attention_heads"],
            decoder_intermediate_size=decoder_cfg["decoder_intermediate_size"],
            latents_mean=latents_mean,
            latents_std=latents_std,
            scaling_factor=args.scaling_factor,
        )
    # Assemble full state dict and load with assign=True
    full_state_dict = {}
    # Encoder weights (prefixed with "encoder.")
    for k, v in encoder_state_dict.items():
        full_state_dict[f"encoder.{k}"] = v
    # Decoder weights (prefixed with "decoder.")
    for k, v in decoder_state_dict.items():
        full_state_dict[f"decoder.{k}"] = v
    # Buffers from config
    full_state_dict["encoder_mean"] = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
    full_state_dict["encoder_std"] = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
    if latents_mean is not None:
        latents_mean_t = latents_mean if isinstance(latents_mean, torch.Tensor) else torch.tensor(latents_mean)
        full_state_dict["_latents_mean"] = latents_mean_t
    else:
        full_state_dict["_latents_mean"] = torch.zeros(1)
    if latents_std is not None:
        latents_std_t = latents_std if isinstance(latents_std, torch.Tensor) else torch.tensor(latents_std)
        full_state_dict["_latents_std"] = latents_std_t
    else:
        full_state_dict["_latents_std"] = torch.ones(1)
    # assign=True is required because the model lives on the meta device:
    # tensors are assigned in place of the meta parameters rather than copied.
    model.load_state_dict(full_state_dict, strict=False, assign=True)
    # Verify no critical keys are missing
    model_keys = {name for name, _ in model.named_parameters()}
    model_keys |= {name for name, _ in model.named_buffers()}
    loaded_keys = set(full_state_dict.keys())
    missing = model_keys - loaded_keys
    # trainable_cls_token and decoder_pos_embed are initialized, not loaded from original checkpoint
    allowed_missing = {"decoder.trainable_cls_token", "decoder.decoder_pos_embed"}
    if missing - allowed_missing:
        print(f"Warning: missing keys after conversion: {sorted(missing - allowed_missing)}")
    output_path = Path(args.output_path)
    output_path.mkdir(parents=True, exist_ok=True)
    model.save_pretrained(output_path)
    if args.verify_load:
        print("Verifying converted checkpoint with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False)...")
        loaded_model = AutoencoderRAE.from_pretrained(output_path, low_cpu_mem_usage=False)
        if not isinstance(loaded_model, AutoencoderRAE):
            raise RuntimeError("Verification failed: loaded object is not AutoencoderRAE.")
        print("Verification passed.")
    print(f"Saved converted AutoencoderRAE to: {output_path}")
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI arguments for the RAE conversion script."""
    p = argparse.ArgumentParser(description="Convert RAE decoder checkpoints to diffusers AutoencoderRAE format")
    # Required source/destination.
    p.add_argument(
        "--repo_or_path", type=str, required=True, help="Hub repo id (e.g. nyu-visionx/RAE-collections) or local path"
    )
    p.add_argument("--output_path", type=str, required=True, help="Directory to save converted model")
    p.add_argument("--encoder_type", type=str, choices=["dinov2", "mae", "siglip2"], required=True)
    # Optional overrides for checkpoint resolution.
    p.add_argument(
        "--encoder_name_or_path", type=str, default=None, help="Optional encoder HF model id or local path override"
    )
    p.add_argument("--variant", type=str, default="ViTXL_n08", help="Decoder variant folder name")
    p.add_argument("--dataset_name", type=str, default="imagenet1k", help="Stats dataset folder name")
    p.add_argument(
        "--decoder_checkpoint", type=str, default=None, help="Relative path to decoder checkpoint inside repo/path"
    )
    p.add_argument(
        "--stats_checkpoint", type=str, default=None, help="Relative path to stats checkpoint inside repo/path"
    )
    # Architecture knobs forwarded into AutoencoderRAE.
    p.add_argument("--decoder_config_name", type=str, choices=list(DECODER_CONFIGS.keys()), default="ViTXL")
    p.add_argument("--encoder_input_size", type=int, default=224)
    p.add_argument("--patch_size", type=int, default=16)
    p.add_argument("--image_size", type=int, default=None)
    p.add_argument("--num_channels", type=int, default=3)
    p.add_argument("--scaling_factor", type=float, default=1.0)
    # Misc behavior flags.
    p.add_argument("--cache_dir", type=str, default=None)
    p.add_argument("--dry_run", action="store_true", help="Only resolve and print selected files")
    p.add_argument(
        "--verify_load",
        action="store_true",
        help="After conversion, load back with AutoencoderRAE.from_pretrained(low_cpu_mem_usage=False).",
    )
    return p.parse_args()
def main() -> None:
    """Script entry point: parse CLI arguments, then run the conversion."""
    convert(parse_args())
if __name__ == "__main__":
main()

View File

@@ -202,6 +202,7 @@ else:
"AutoencoderKLTemporalDecoder",
"AutoencoderKLWan",
"AutoencoderOobleck",
"AutoencoderRAE",
"AutoencoderTiny",
"AutoModel",
"BriaFiboTransformer2DModel",
@@ -227,6 +228,7 @@ else:
"FluxMultiControlNetModel",
"FluxTransformer2DModel",
"GlmImageTransformer2DModel",
"HeliosTransformer3DModel",
"HiDreamImageTransformer2DModel",
"HunyuanDiT2DControlNetModel",
"HunyuanDiT2DModel",
@@ -359,6 +361,8 @@ else:
"FlowMatchEulerDiscreteScheduler",
"FlowMatchHeunDiscreteScheduler",
"FlowMatchLCMScheduler",
"HeliosDMDScheduler",
"HeliosScheduler",
"HeunDiscreteScheduler",
"IPNDMScheduler",
"KarrasVeScheduler",
@@ -515,6 +519,8 @@ else:
"FluxPipeline",
"FluxPriorReduxPipeline",
"GlmImagePipeline",
"HeliosPipeline",
"HeliosPyramidPipeline",
"HiDreamImagePipeline",
"HunyuanDiTControlNetPipeline",
"HunyuanDiTPAGPipeline",
@@ -969,6 +975,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderRAE,
AutoencoderTiny,
AutoModel,
BriaFiboTransformer2DModel,
@@ -994,6 +1001,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxMultiControlNetModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HeliosTransformer3DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DControlNetModel,
HunyuanDiT2DModel,
@@ -1122,6 +1130,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FlowMatchEulerDiscreteScheduler,
FlowMatchHeunDiscreteScheduler,
FlowMatchLCMScheduler,
HeliosDMDScheduler,
HeliosScheduler,
HeunDiscreteScheduler,
IPNDMScheduler,
KarrasVeScheduler,
@@ -1257,6 +1267,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
FluxPipeline,
FluxPriorReduxPipeline,
GlmImagePipeline,
HeliosPipeline,
HeliosPyramidPipeline,
HiDreamImagePipeline,
HunyuanDiTControlNetPipeline,
HunyuanDiTPAGPipeline,

View File

@@ -89,8 +89,6 @@ class CustomBlocksCommand(BaseDiffusersCLICommand):
# automap = self._create_automap(parent_class=parent_class, child_class=child_class)
# with open(CONFIG, "w") as f:
# json.dump(automap, f)
with open("requirements.txt", "w") as f:
f.write("")
def _choose_block(self, candidates, chosen=None):
for cls, base in candidates:

View File

@@ -107,6 +107,38 @@ class ConfigMixin:
has_compatibles = False
_deprecated_kwargs = []
_auto_class = None
@classmethod
def register_for_auto_class(cls, auto_class="AutoModel"):
"""
Register this class with the given auto class so that it can be loaded with `AutoModel.from_pretrained(...,
trust_remote_code=True)`.
When the config is saved, the resulting `config.json` will include an `auto_map` entry mapping the auto class
to this class's module and class name.
Args:
auto_class (`str` or type, *optional*, defaults to `"AutoModel"`):
The auto class to register this class with. Can be a string (e.g. `"AutoModel"`) or the class itself.
Currently only `"AutoModel"` is supported.
Example:
```python
from diffusers import ModelMixin, ConfigMixin
class MyCustomModel(ModelMixin, ConfigMixin): ...
MyCustomModel.register_for_auto_class("AutoModel")
```
"""
if auto_class != "AutoModel":
raise ValueError(f"Only 'AutoModel' is supported, got '{auto_class}'.")
cls._auto_class = auto_class
def register_to_config(self, **kwargs):
if self.config_name is None:
@@ -621,6 +653,12 @@ class ConfigMixin:
# pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
_ = config_dict.pop("_pre_quantization_dtype", None)
if getattr(self, "_auto_class", None) is not None:
module = self.__class__.__module__.split(".")[-1]
auto_map = config_dict.get("auto_map", {})
auto_map[self._auto_class] = f"{module}.{self.__class__.__name__}"
config_dict["auto_map"] = auto_map
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: str | os.PathLike):
@@ -648,28 +686,6 @@ class ConfigMixin:
)
return config_file
@classmethod
def _get_dataclass_from_config(cls, config_dict: dict[str, Any]):
sig = inspect.signature(cls.__init__)
fields = []
for name, param in sig.parameters.items():
if name == "self" or name == "kwargs" or name in cls.ignore_for_config:
continue
annotation = param.annotation if param.annotation is not inspect.Parameter.empty else Any
if param.default is not inspect.Parameter.empty:
fields.append((name, annotation, dataclasses.field(default=param.default)))
else:
fields.append((name, annotation))
dc_cls = dataclasses.make_dataclass(
f"{cls.__name__}Config",
fields,
frozen=True,
)
valid_fields = {f.name for f in dataclasses.fields(dc_cls)}
init_kwargs = {k: v for k, v in config_dict.items() if k in valid_fields}
return dc_cls(**init_kwargs)
def register_to_config(init):
r"""

View File

@@ -307,6 +307,17 @@ class GroupOffloadingHook(ModelHook):
if self.group.onload_leader == module:
if self.group.onload_self:
self.group.onload_()
else:
# onload_self=False means this group relies on prefetching from a previous group.
# However, for conditionally-executed modules (e.g. patch_short/patch_mid/patch_long in Helios),
# the prefetch chain may not cover them if they were absent during the first forward pass
# when the execution order was traced. In that case, their weights remain on offload_device,
# so we fall back to a synchronous onload here.
params = [p for m in self.group.modules for p in m.parameters()] + list(self.group.parameters)
if params and params[0].device == self.group.offload_device:
self.group.onload_()
if self.group.stream is not None:
self.group.stream.synchronize()
should_onload_next_group = self.next_group is not None and not self.next_group.onload_self
if should_onload_next_group:

View File

@@ -78,6 +78,7 @@ if is_torch_available():
"SanaLoraLoaderMixin",
"Lumina2LoraLoaderMixin",
"WanLoraLoaderMixin",
"HeliosLoraLoaderMixin",
"KandinskyLoraLoaderMixin",
"HiDreamImageLoraLoaderMixin",
"SkyReelsV2LoraLoaderMixin",
@@ -118,6 +119,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
CogView4LoraLoaderMixin,
Flux2LoraLoaderMixin,
FluxLoraLoaderMixin,
HeliosLoraLoaderMixin,
HiDreamImageLoraLoaderMixin,
HunyuanVideoLoraLoaderMixin,
KandinskyLoraLoaderMixin,

View File

@@ -2519,6 +2519,13 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
if has_default:
state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()}
# Normalize ZImage-specific dot-separated module names to underscore form so they
# match the diffusers model parameter names (context_refiner, noise_refiner).
state_dict = {
k.replace("context.refiner.", "context_refiner.").replace("noise.refiner.", "noise_refiner."): v
for k, v in state_dict.items()
}
converted_state_dict = {}
all_keys = list(state_dict.keys())
down_key = ".lora_down.weight"
@@ -2529,19 +2536,18 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
if has_non_diffusers_lora_id:
def get_alpha_scales(down_weight, alpha_key):
rank = down_weight.shape[0]
alpha = state_dict.pop(alpha_key).item()
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
return scale_down, scale_up
for k in all_keys:
if k.endswith(down_key):
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
@@ -2554,13 +2560,69 @@ def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
converted_state_dict[diffusers_down_key] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
# Already in diffusers format (lora_A/lora_B), just pop
# Already in diffusers format (lora_A/lora_B), apply alpha scaling and pop.
elif has_diffusers_lora_id:
for k in all_keys:
if a_key in k or b_key in k:
converted_state_dict[k] = state_dict.pop(k)
elif ".alpha" in k:
if k.endswith(a_key):
diffusers_up_key = k.replace(a_key, b_key)
alpha_key = k.replace(a_key, ".alpha")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(diffusers_up_key)
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[k] = down_weight * scale_down
converted_state_dict[diffusers_up_key] = up_weight * scale_up
# Handle dot-format LoRA keys: ".lora.down.weight" / ".lora.up.weight".
# Some external ZImage trainers (e.g. Anime-Z) use dots instead of underscores in
# lora weight names and also include redundant keys:
# - "qkv.lora.*" duplicates individual "to.q/k/v.lora.*" keys → skip qkv
# - "out.lora.*" duplicates "to_out.0.lora.*" keys → skip bare out
# - "to.q/k/v.lora.*" → normalise to "to_q/k/v.lora_A/B.weight"
lora_dot_down_key = ".lora.down.weight"
lora_dot_up_key = ".lora.up.weight"
has_lora_dot_format = any(lora_dot_down_key in k for k in state_dict)
if has_lora_dot_format:
dot_keys = list(state_dict.keys())
for k in dot_keys:
if lora_dot_down_key not in k:
continue
if k not in state_dict:
continue # already popped by a prior iteration
base = k[: -len(lora_dot_down_key)]
# Skip combined "qkv" projection — individual to.q/k/v keys are also present.
if base.endswith(".qkv"):
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
state_dict.pop(base + ".alpha", None)
continue
# Skip bare "out.lora.*" — "to_out.0.lora.*" covers the same projection.
if re.search(r"\.out$", base) and ".to_out" not in base:
state_dict.pop(k)
state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key), None)
continue
# Normalise "to.q/k/v" → "to_q/k/v" for the diffusers output key.
norm_k = re.sub(
r"\.to\.([qkv])" + re.escape(lora_dot_down_key) + r"$",
r".to_\1" + lora_dot_down_key,
k,
)
norm_base = norm_k[: -len(lora_dot_down_key)]
alpha_key = norm_base + ".alpha"
diffusers_down = norm_k.replace(lora_dot_down_key, ".lora_A.weight")
diffusers_up = norm_k.replace(lora_dot_down_key, ".lora_B.weight")
down_weight = state_dict.pop(k)
up_weight = state_dict.pop(k.replace(lora_dot_down_key, lora_dot_up_key))
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
converted_state_dict[diffusers_down] = down_weight * scale_down
converted_state_dict[diffusers_up] = up_weight * scale_up
if len(state_dict) > 0:
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

View File

@@ -3440,6 +3440,207 @@ class SanaLoraLoaderMixin(LoraBaseMixin):
super().unfuse_lora(components=components, **kwargs)
class HeliosLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HeliosTransformer3DModel`]. Specific to [`HeliosPipeline`] and [`HeliosPyramidPipeline`].
"""
_lora_loadable_modules = ["transformer"]
transformer_name = TRANSFORMER_NAME
@classmethod
@validate_hf_hub_args
def lora_state_dict(
cls,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details.
"""
# Load the main state dict first which has the LoRA layers for either of
# transformer and text encoder or both.
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", None)
token = kwargs.pop("token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
return_lora_metadata = kwargs.pop("return_lora_metadata", False)
allow_pickle = False
if use_safetensors is None:
use_safetensors = True
allow_pickle = True
user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
state_dict, metadata = _fetch_state_dict(
pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
weight_name=weight_name,
use_safetensors=use_safetensors,
local_files_only=local_files_only,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
token=token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
allow_pickle=allow_pickle,
)
if any(k.startswith("diffusion_model.") for k in state_dict):
state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
elif any(k.startswith("lora_unet_") for k in state_dict):
state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)
is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
def load_lora_weights(
self,
pretrained_model_name_or_path_or_dict: str | dict[str, torch.Tensor],
adapter_name: str | None = None,
hotswap: bool = False,
**kwargs,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for more details.
"""
if not USE_PEFT_BACKEND:
raise ValueError("PEFT backend is required for this method.")
low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# if a dict is passed, copy it instead of modifying it inplace
if isinstance(pretrained_model_name_or_path_or_dict, dict):
pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
# First, ensure that the checkpoint is a compatible one and can be successfully loaded.
kwargs["return_lora_metadata"] = True
state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
is_correct_format = all("lora" in key for key in state_dict.keys())
if not is_correct_format:
raise ValueError("Invalid LoRA checkpoint. Make sure all LoRA param names contain `'lora'` substring.")
self.load_lora_into_transformer(
state_dict,
transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=self,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
def load_lora_into_transformer(
cls,
state_dict,
transformer,
adapter_name=None,
_pipeline=None,
low_cpu_mem_usage=False,
hotswap: bool = False,
metadata=None,
):
"""
See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_unet`] for more details.
"""
if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
raise ValueError(
"`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
)
# Load the layers corresponding to transformer.
logger.info(f"Loading {cls.transformer_name}.")
transformer.load_lora_adapter(
state_dict,
network_alphas=None,
adapter_name=adapter_name,
metadata=metadata,
_pipeline=_pipeline,
low_cpu_mem_usage=low_cpu_mem_usage,
hotswap=hotswap,
)
@classmethod
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
def save_lora_weights(
cls,
save_directory: str | os.PathLike,
transformer_lora_layers: dict[str, torch.nn.Module | torch.Tensor] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = True,
transformer_lora_adapter_metadata: dict | None = None,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for more information.
"""
lora_layers = {}
lora_metadata = {}
if transformer_lora_layers:
lora_layers[cls.transformer_name] = transformer_lora_layers
lora_metadata[cls.transformer_name] = transformer_lora_adapter_metadata
if not lora_layers:
raise ValueError("You must pass at least one of `transformer_lora_layers` or `text_encoder_lora_layers`.")
cls._save_lora_weights(
save_directory=save_directory,
lora_layers=lora_layers,
lora_metadata=lora_metadata,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
def fuse_lora(
self,
components: list[str] = ["transformer"],
lora_scale: float = 1.0,
safe_fusing: bool = False,
adapter_names: list[str] | None = None,
**kwargs,
):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.fuse_lora`] for more details.
"""
super().fuse_lora(
components=components,
lora_scale=lora_scale,
safe_fusing=safe_fusing,
adapter_names=adapter_names,
**kwargs,
)
# Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
def unfuse_lora(self, components: list[str] = ["transformer"], **kwargs):
r"""
See [`~loaders.StableDiffusionLoraLoaderMixin.unfuse_lora`] for more details.
"""
super().unfuse_lora(components=components, **kwargs)
class HunyuanVideoLoraLoaderMixin(LoraBaseMixin):
r"""
Load LoRA layers into [`HunyuanVideoTransformer3DModel`]. Specific to [`HunyuanVideoPipeline`].

View File

@@ -51,6 +51,7 @@ _SET_ADAPTER_SCALE_FN_MAPPING = {
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,

View File

@@ -49,6 +49,7 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_rae"] = ["AutoencoderRAE"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
@@ -100,6 +101,7 @@ if is_torch_available():
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
_import_structure["transformers.transformer_flux2"] = ["Flux2Transformer2DModel"]
_import_structure["transformers.transformer_glm_image"] = ["GlmImageTransformer2DModel"]
_import_structure["transformers.transformer_helios"] = ["HeliosTransformer3DModel"]
_import_structure["transformers.transformer_hidream_image"] = ["HiDreamImageTransformer2DModel"]
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
_import_structure["transformers.transformer_hunyuan_video15"] = ["HunyuanVideo15Transformer3DModel"]
@@ -167,6 +169,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
AutoencoderKLTemporalDecoder,
AutoencoderKLWan,
AutoencoderOobleck,
AutoencoderRAE,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,
@@ -212,6 +215,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
Flux2Transformer2DModel,
FluxTransformer2DModel,
GlmImageTransformer2DModel,
HeliosTransformer3DModel,
HiDreamImageTransformer2DModel,
HunyuanDiT2DModel,
HunyuanImageTransformer2DModel,

View File

@@ -18,6 +18,7 @@ from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_kl_wan import AutoencoderKLWan
from .autoencoder_oobleck import AutoencoderOobleck
from .autoencoder_rae import AutoencoderRAE
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .vq_model import VQModel

View File

@@ -0,0 +1,692 @@
# Copyright 2026 The NYU Vision-X and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from math import sqrt
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput, logging
from ...utils.accelerate_utils import apply_forward_hook
from ...utils.import_utils import is_transformers_available
from ...utils.torch_utils import randn_tensor
if is_transformers_available():
from transformers import (
Dinov2WithRegistersConfig,
Dinov2WithRegistersModel,
SiglipVisionConfig,
SiglipVisionModel,
ViTMAEConfig,
ViTMAEModel,
)
from ..activations import get_activation
from ..attention import AttentionMixin
from ..attention_processor import Attention
from ..embeddings import get_2d_sincos_pos_embed
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, EncoderOutput
logger = logging.get_logger(__name__)
# ---------------------------------------------------------------------------
# Per-encoder forward functions
# ---------------------------------------------------------------------------
# Each function takes the raw transformers model + images and returns patch
# tokens of shape (B, N, C), stripping CLS / register tokens as needed.
def _dinov2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
outputs = model(images, output_hidden_states=True)
unused_token_num = 5 # 1 CLS + 4 register tokens
return outputs.last_hidden_state[:, unused_token_num:]
def _siglip2_encoder_forward(model: nn.Module, images: torch.Tensor) -> torch.Tensor:
outputs = model(images, output_hidden_states=True, interpolate_pos_encoding=True)
return outputs.last_hidden_state
def _mae_encoder_forward(model: nn.Module, images: torch.Tensor, patch_size: int) -> torch.Tensor:
h, w = images.shape[2], images.shape[3]
patch_num = int(h * w // patch_size**2)
if patch_num * patch_size**2 != h * w:
raise ValueError("Image size should be divisible by patch size.")
noise = torch.arange(patch_num).unsqueeze(0).expand(images.shape[0], -1).to(images.device).to(images.dtype)
outputs = model(images, noise, interpolate_pos_encoding=True)
return outputs.last_hidden_state[:, 1:] # remove cls token
# ---------------------------------------------------------------------------
# Encoder construction helpers
# ---------------------------------------------------------------------------
def _build_encoder(
    encoder_type: str, hidden_size: int, patch_size: int, num_hidden_layers: int, head_dim: int = 64
) -> nn.Module:
    """Build a frozen encoder from config (no pretrained download)."""
    # All supported encoders use a 64-dim attention head, so the head count
    # follows directly from the hidden size.
    num_attention_heads = hidden_size // head_dim
    shared_kwargs = {
        "hidden_size": hidden_size,
        "patch_size": patch_size,
        "num_attention_heads": num_attention_heads,
        "num_hidden_layers": num_hidden_layers,
    }
    if encoder_type == "dinov2":
        model = Dinov2WithRegistersModel(Dinov2WithRegistersConfig(image_size=518, **shared_kwargs))
        final_norm = model.layernorm
    elif encoder_type == "siglip2":
        model = SiglipVisionModel(SiglipVisionConfig(image_size=256, **shared_kwargs))
        final_norm = model.vision_model.post_layernorm
    elif encoder_type == "mae":
        model = ViTMAEModel(ViTMAEConfig(image_size=224, mask_ratio=0.0, **shared_kwargs))
        final_norm = model.layernorm
    else:
        raise ValueError(f"Unknown encoder_type='{encoder_type}'. Available: dinov2, siglip2, mae")
    # RAE strips the final layernorm affine params (identity LN). Remove them from
    # the architecture so `from_pretrained` doesn't leave them on the meta device.
    final_norm.weight = None
    final_norm.bias = None
    model.requires_grad_(False)
    return model
# Maps `encoder_type` config values to the per-encoder forward function defined
# above; each returns (B, N, C) patch tokens with CLS/register tokens removed.
_ENCODER_FORWARD_FNS = {
    "dinov2": _dinov2_encoder_forward,
    "siglip2": _siglip2_encoder_forward,
    "mae": _mae_encoder_forward,
}
@dataclass
class RAEDecoderOutput(BaseOutput):
    """
    Output of `RAEDecoder`.
    Args:
        logits (`torch.Tensor`):
            Patch reconstruction logits of shape `(batch_size, num_patches, patch_size**2 * num_channels)`. Pass
            through `RAEDecoder.unpatchify` to recover `(batch_size, num_channels, height, width)` images.
    """
    logits: torch.Tensor
class ViTMAEIntermediate(nn.Module):
    """Expansion half of a ViT-MAE MLP block: linear projection then activation."""
    def __init__(self, hidden_size: int, intermediate_size: int, hidden_act: str = "gelu"):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)
        self.intermediate_act_fn = get_activation(hidden_act)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Project up to the intermediate width, then apply the nonlinearity.
        return self.intermediate_act_fn(self.dense(hidden_states))
class ViTMAEOutput(nn.Module):
    """Contraction half of a ViT-MAE MLP block: projection, dropout, residual add."""
    def __init__(self, hidden_size: int, intermediate_size: int, hidden_dropout_prob: float = 0.0):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)
    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        # Project back down, regularize, then add the block's residual input.
        projected = self.dropout(self.dense(hidden_states))
        return projected + input_tensor
class ViTMAELayer(nn.Module):
    """
    A single pre-norm transformer block (self-attention + MLP, each with a residual connection).
    This matches the naming/parameter structure used in RAE-main (ViTMAE decoder block).
    """
    def __init__(
        self,
        *,
        hidden_size: int,
        num_attention_heads: int,
        intermediate_size: int,
        qkv_bias: bool = True,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.0,
        attention_probs_dropout_prob: float = 0.0,
        hidden_act: str = "gelu",
    ):
        super().__init__()
        # Head dim must divide evenly; Attention below splits hidden_size across heads.
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                f"hidden_size={hidden_size} must be divisible by num_attention_heads={num_attention_heads}"
            )
        self.attention = Attention(
            query_dim=hidden_size,
            heads=num_attention_heads,
            dim_head=hidden_size // num_attention_heads,
            dropout=attention_probs_dropout_prob,
            bias=qkv_bias,
        )
        self.intermediate = ViTMAEIntermediate(
            hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_act=hidden_act
        )
        # `output` projects back to hidden_size and adds the residual internally.
        self.output = ViTMAEOutput(
            hidden_size=hidden_size, intermediate_size=intermediate_size, hidden_dropout_prob=hidden_dropout_prob
        )
        self.layernorm_before = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.layernorm_after = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply pre-norm self-attention and pre-norm MLP, each with a residual."""
        # Pre-norm self-attention; residual added outside the attention call.
        attention_output = self.attention(self.layernorm_before(hidden_states))
        hidden_states = attention_output + hidden_states
        # Pre-norm MLP; ViTMAEOutput adds `hidden_states` back as the residual.
        layer_output = self.layernorm_after(hidden_states)
        layer_output = self.intermediate(layer_output)
        layer_output = self.output(layer_output, hidden_states)
        return layer_output
class RAEDecoder(nn.Module):
    """Lightweight RAE decoder.

    A ViT-MAE style transformer decoder: encoder tokens are projected to the decoder width, a trainable CLS token and
    fixed 2D sin-cos position embeddings are added, the sequence is run through `decoder_num_hidden_layers`
    `ViTMAELayer` blocks, and a linear head predicts per-patch pixel logits of shape
    `(batch, num_patches, patch_size**2 * num_channels)`.
    """
    def __init__(
        self,
        hidden_size: int = 768,
        decoder_hidden_size: int = 512,
        decoder_num_hidden_layers: int = 8,
        decoder_num_attention_heads: int = 16,
        decoder_intermediate_size: int = 2048,
        num_patches: int = 256,
        patch_size: int = 16,
        num_channels: int = 3,
        image_size: int = 256,
        qkv_bias: bool = True,
        layer_norm_eps: float = 1e-12,
        hidden_dropout_prob: float = 0.0,
        attention_probs_dropout_prob: float = 0.0,
        hidden_act: str = "gelu",
    ):
        super().__init__()
        self.decoder_hidden_size = decoder_hidden_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.image_size = image_size
        self.num_patches = num_patches
        # Project incoming encoder tokens (hidden_size) into the decoder width.
        self.decoder_embed = nn.Linear(hidden_size, decoder_hidden_size, bias=True)
        # Fixed 2D sin-cos position table with one extra slot for the CLS token;
        # values are filled in by _initialize_weights below.
        self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_hidden_size))
        self.decoder_layers = nn.ModuleList(
            [
                ViTMAELayer(
                    hidden_size=decoder_hidden_size,
                    num_attention_heads=decoder_num_attention_heads,
                    intermediate_size=decoder_intermediate_size,
                    qkv_bias=qkv_bias,
                    layer_norm_eps=layer_norm_eps,
                    hidden_dropout_prob=hidden_dropout_prob,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    hidden_act=hidden_act,
                )
                for _ in range(decoder_num_hidden_layers)
            ]
        )
        self.decoder_norm = nn.LayerNorm(decoder_hidden_size, eps=layer_norm_eps)
        # Head predicting flattened per-patch pixels: patch_size**2 * num_channels per token.
        self.decoder_pred = nn.Linear(decoder_hidden_size, patch_size**2 * num_channels, bias=True)
        self.gradient_checkpointing = False
        self._initialize_weights(num_patches)
        self.trainable_cls_token = nn.Parameter(torch.zeros(1, 1, decoder_hidden_size))
    def _initialize_weights(self, num_patches: int):
        """Fill the positional-embedding buffer with a fixed 2D sin-cos table."""
        # Skip initialization when parameters are on meta device (e.g. during
        # accelerate.init_empty_weights() used by low_cpu_mem_usage loading).
        # The weights are initialized.
        if self.decoder_pos_embed.device.type == "meta":
            return
        grid_size = int(num_patches**0.5)
        pos_embed = get_2d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1],
            grid_size,
            cls_token=True,
            extra_tokens=1,
            output_type="pt",
            device=self.decoder_pos_embed.device,
        )
        self.decoder_pos_embed.data.copy_(pos_embed.unsqueeze(0).to(dtype=self.decoder_pos_embed.dtype))
    def interpolate_pos_encoding(self, embeddings: torch.Tensor) -> torch.Tensor:
        """Resample the patch position embeddings to match the token count of `embeddings`.

        The CLS slot is kept as-is; the patch portion is treated as a 1xN "image"
        and bicubically stretched along the token axis to the new length.
        """
        embeddings_positions = embeddings.shape[1] - 1
        num_positions = self.decoder_pos_embed.shape[1] - 1
        class_pos_embed = self.decoder_pos_embed[:, 0, :]
        patch_pos_embed = self.decoder_pos_embed[:, 1:, :]
        dim = self.decoder_pos_embed.shape[-1]
        # (1, N, dim) -> (1, dim, 1, N) so F.interpolate can stretch the token axis.
        patch_pos_embed = patch_pos_embed.reshape(1, 1, -1, dim).permute(0, 3, 1, 2)
        patch_pos_embed = F.interpolate(
            patch_pos_embed,
            scale_factor=(1, embeddings_positions / num_positions),
            mode="bicubic",
            align_corners=False,
        )
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
    def interpolate_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Bilinearly resize a (B, L, C) token grid to the decoder's `num_patches` length.

        Assumes `L` is a perfect square (tokens form a square grid) — inputs with
        non-square token counts would be reshaped incorrectly here.
        """
        b, l, c = x.shape
        if l == self.num_patches:
            return x
        h = w = int(l**0.5)
        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
        target_size = (int(self.num_patches**0.5), int(self.num_patches**0.5))
        x = F.interpolate(x, size=target_size, mode="bilinear", align_corners=False)
        x = x.permute(0, 2, 3, 1).contiguous().view(b, self.num_patches, c)
        return x
    def unpatchify(self, patchified_pixel_values: torch.Tensor, original_image_size: tuple[int, int] | None = None):
        """Reassemble (B, N, p*p*C) patch logits into (B, C, H, W) images.

        Raises:
            ValueError: If the patch count does not match the target image size.
        """
        patch_size, num_channels = self.patch_size, self.num_channels
        original_image_size = (
            original_image_size if original_image_size is not None else (self.image_size, self.image_size)
        )
        original_height, original_width = original_image_size
        num_patches_h = original_height // patch_size
        num_patches_w = original_width // patch_size
        if num_patches_h * num_patches_w != patchified_pixel_values.shape[1]:
            raise ValueError(
                f"The number of patches in the patchified pixel values {patchified_pixel_values.shape[1]}, does not match the number of patches on original image {num_patches_h}*{num_patches_w}"
            )
        batch_size = patchified_pixel_values.shape[0]
        patchified_pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_patches_h,
            num_patches_w,
            patch_size,
            patch_size,
            num_channels,
        )
        # (B, h, w, p, q, C) -> (B, C, h, p, w, q): interleave patch grids with pixel offsets.
        patchified_pixel_values = torch.einsum("nhwpqc->nchpwq", patchified_pixel_values)
        pixel_values = patchified_pixel_values.reshape(
            batch_size,
            num_channels,
            num_patches_h * patch_size,
            num_patches_w * patch_size,
        )
        return pixel_values
    def forward(
        self,
        hidden_states: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        drop_cls_token: bool = False,
        return_dict: bool = True,
    ) -> RAEDecoderOutput | tuple[torch.Tensor]:
        """Decode (B, N, hidden_size) encoder tokens into per-patch pixel logits.

        Args:
            hidden_states: Encoder tokens; if `drop_cls_token`, position 0 is a CLS token to discard.
            interpolate_pos_encoding: Resample position embeddings to the input length (requires `drop_cls_token`).
            drop_cls_token: Drop the incoming CLS token before decoding (a trainable CLS is always prepended).
            return_dict: Return a `RAEDecoderOutput` instead of a tuple.
        """
        x = self.decoder_embed(hidden_states)
        if drop_cls_token:
            x_ = x[:, 1:, :]
            x_ = self.interpolate_latent(x_)
        else:
            x_ = self.interpolate_latent(x)
        # The decoder always uses its own trainable CLS token, regardless of input.
        cls_token = self.trainable_cls_token.expand(x_.shape[0], -1, -1)
        x = torch.cat([cls_token, x_], dim=1)
        if interpolate_pos_encoding:
            if not drop_cls_token:
                raise ValueError("interpolate_pos_encoding only supports drop_cls_token=True")
            decoder_pos_embed = self.interpolate_pos_encoding(x)
        else:
            decoder_pos_embed = self.decoder_pos_embed
        hidden_states = x + decoder_pos_embed.to(device=x.device, dtype=x.dtype)
        for layer_module in self.decoder_layers:
            hidden_states = layer_module(hidden_states)
        hidden_states = self.decoder_norm(hidden_states)
        logits = self.decoder_pred(hidden_states)
        # Strip the CLS position: only patch tokens carry pixel predictions.
        logits = logits[:, 1:, :]
        if not return_dict:
            return (logits,)
        return RAEDecoderOutput(logits=logits)
class AutoencoderRAE(ModelMixin, AttentionMixin, AutoencoderMixin, ConfigMixin):
    r"""
    Representation Autoencoder (RAE) model for encoding images to latents and decoding latents to images.
    This model uses a frozen pretrained encoder (DINOv2, SigLIP2, or MAE) with a trainable ViT decoder to reconstruct
    images from learned representations.
    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).
    Args:
        encoder_type (`str`, *optional*, defaults to `"dinov2"`):
            Type of frozen encoder to use. One of `"dinov2"`, `"siglip2"`, or `"mae"`.
        encoder_hidden_size (`int`, *optional*, defaults to `768`):
            Hidden size of the encoder model.
        encoder_patch_size (`int`, *optional*, defaults to `14`):
            Patch size of the encoder model.
        encoder_num_hidden_layers (`int`, *optional*, defaults to `12`):
            Number of hidden layers in the encoder model.
        patch_size (`int`, *optional*, defaults to `16`):
            Decoder patch size (used for unpatchify and decoder head).
        encoder_input_size (`int`, *optional*, defaults to `224`):
            Input size expected by the encoder.
        image_size (`int`, *optional*):
            Decoder output image size. If `None`, it is derived from encoder token count and `patch_size` like
            RAE-main: `image_size = patch_size * sqrt(num_patches)`, where `num_patches = (encoder_input_size //
            encoder_patch_size) ** 2`.
        num_channels (`int`, *optional*, defaults to `3`):
            Number of input/output channels.
        encoder_norm_mean (`list`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
            Channel-wise mean for encoder input normalization (ImageNet defaults).
        encoder_norm_std (`list`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
            Channel-wise std for encoder input normalization (ImageNet defaults).
        latents_mean (`list` or `tuple`, *optional*):
            Optional mean for latent normalization. Tensor inputs are accepted and converted to config-serializable
            lists.
        latents_std (`list` or `tuple`, *optional*):
            Optional standard deviation for latent normalization. Tensor inputs are accepted and converted to
            config-serializable lists.
        noise_tau (`float`, *optional*, defaults to `0.0`):
            Noise level for training (adds noise to latents during training).
        reshape_to_2d (`bool`, *optional*, defaults to `True`):
            Whether to reshape latents to 2D (B, C, H, W) format.
        use_encoder_loss (`bool`, *optional*, defaults to `False`):
            Whether to use encoder hidden states in the loss (for advanced training).
        scaling_factor (`float`, *optional*, defaults to `1.0`):
            Multiplier applied to normalized latents at encode time and divided out at decode time, following the
            diffusers VAE convention.
    """
    # NOTE: gradient checkpointing is not wired up for this model yet.
    _supports_gradient_checkpointing = False
    _no_split_modules = ["ViTMAELayer"]
    @register_to_config
    def __init__(
        self,
        encoder_type: str = "dinov2",
        encoder_hidden_size: int = 768,
        encoder_patch_size: int = 14,
        encoder_num_hidden_layers: int = 12,
        decoder_hidden_size: int = 512,
        decoder_num_hidden_layers: int = 8,
        decoder_num_attention_heads: int = 16,
        decoder_intermediate_size: int = 2048,
        patch_size: int = 16,
        encoder_input_size: int = 224,
        image_size: int | None = None,
        num_channels: int = 3,
        encoder_norm_mean: list | None = None,
        encoder_norm_std: list | None = None,
        latents_mean: list | tuple | torch.Tensor | None = None,
        latents_std: list | tuple | torch.Tensor | None = None,
        noise_tau: float = 0.0,
        reshape_to_2d: bool = True,
        use_encoder_loss: bool = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()
        if encoder_type not in _ENCODER_FORWARD_FNS:
            raise ValueError(
                f"Unknown encoder_type='{encoder_type}'. Available: {sorted(_ENCODER_FORWARD_FNS.keys())}"
            )
        if encoder_input_size % encoder_patch_size != 0:
            raise ValueError(
                f"encoder_input_size={encoder_input_size} must be divisible by encoder_patch_size={encoder_patch_size}."
            )
        decoder_patch_size = patch_size
        if decoder_patch_size <= 0:
            raise ValueError("patch_size must be a positive integer (this is decoder_patch_size).")
        # Token grid produced by the frozen encoder; must form a square so it can
        # be reshaped to 2D latents and decoded back to a square image.
        num_patches = (encoder_input_size // encoder_patch_size) ** 2
        grid = int(sqrt(num_patches))
        if grid * grid != num_patches:
            raise ValueError(f"Computed num_patches={num_patches} must be a perfect square.")
        derived_image_size = decoder_patch_size * grid
        if image_size is None:
            image_size = derived_image_size
        else:
            image_size = int(image_size)
            if image_size != derived_image_size:
                raise ValueError(
                    f"image_size={image_size} must equal decoder_patch_size*sqrt(num_patches)={derived_image_size} "
                    f"for patch_size={decoder_patch_size} and computed num_patches={num_patches}."
                )
        # Convert tensors/tuples to plain lists so they can be JSON-serialized in the config.
        def _to_config_compatible(value: Any) -> Any:
            if isinstance(value, torch.Tensor):
                return value.detach().cpu().tolist()
            if isinstance(value, tuple):
                return [_to_config_compatible(v) for v in value]
            if isinstance(value, list):
                return [_to_config_compatible(v) for v in value]
            return value
        # Normalize list/tuple/tensor inputs to a float32 tensor (or None).
        def _as_optional_tensor(value: torch.Tensor | list | tuple | None) -> torch.Tensor | None:
            if value is None:
                return None
            if isinstance(value, torch.Tensor):
                return value.detach().clone()
            return torch.tensor(value, dtype=torch.float32)
        latents_std_tensor = _as_optional_tensor(latents_std)
        # Ensure config values are JSON-serializable (list/None), even if caller passes torch.Tensors.
        self.register_to_config(
            latents_mean=_to_config_compatible(latents_mean),
            latents_std=_to_config_compatible(latents_std),
        )
        # Frozen representation encoder (built from config, no downloads)
        self.encoder: nn.Module = _build_encoder(
            encoder_type=encoder_type,
            hidden_size=encoder_hidden_size,
            patch_size=encoder_patch_size,
            num_hidden_layers=encoder_num_hidden_layers,
        )
        self._encoder_forward_fn = _ENCODER_FORWARD_FNS[encoder_type]
        # Recomputed; same value as the `num_patches` derived above.
        num_patches = (encoder_input_size // encoder_patch_size) ** 2
        # Encoder input normalization stats (ImageNet defaults)
        if encoder_norm_mean is None:
            encoder_norm_mean = [0.485, 0.456, 0.406]
        if encoder_norm_std is None:
            encoder_norm_std = [0.229, 0.224, 0.225]
        encoder_mean_tensor = torch.tensor(encoder_norm_mean, dtype=torch.float32).view(1, 3, 1, 1)
        encoder_std_tensor = torch.tensor(encoder_norm_std, dtype=torch.float32).view(1, 3, 1, 1)
        self.register_buffer("encoder_mean", encoder_mean_tensor, persistent=True)
        self.register_buffer("encoder_std", encoder_std_tensor, persistent=True)
        # Latent normalization buffers (defaults are no-ops; actual values come from checkpoint)
        latents_mean_tensor = _as_optional_tensor(latents_mean)
        if latents_mean_tensor is None:
            latents_mean_tensor = torch.zeros(1)
        self.register_buffer("_latents_mean", latents_mean_tensor, persistent=True)
        if latents_std_tensor is None:
            latents_std_tensor = torch.ones(1)
        self.register_buffer("_latents_std", latents_std_tensor, persistent=True)
        # ViT-MAE style decoder
        self.decoder = RAEDecoder(
            hidden_size=int(encoder_hidden_size),
            decoder_hidden_size=int(decoder_hidden_size),
            decoder_num_hidden_layers=int(decoder_num_hidden_layers),
            decoder_num_attention_heads=int(decoder_num_attention_heads),
            decoder_intermediate_size=int(decoder_intermediate_size),
            num_patches=int(num_patches),
            patch_size=int(decoder_patch_size),
            num_channels=int(num_channels),
            image_size=int(image_size),
        )
        self.num_patches = int(num_patches)
        self.decoder_patch_size = int(decoder_patch_size)
        self.decoder_image_size = int(image_size)
        # Slicing support (batch dimension) similar to other diffusers autoencoders
        self.use_slicing = False
    def _noising(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
        """Add Gaussian noise with a per-sample random sigma in [0, noise_tau] (training augmentation)."""
        # Per-sample random sigma in [0, noise_tau]
        noise_sigma = self.config.noise_tau * torch.rand(
            (x.size(0),) + (1,) * (x.ndim - 1), device=x.device, dtype=x.dtype, generator=generator
        )
        return x + noise_sigma * randn_tensor(x.shape, generator=generator, device=x.device, dtype=x.dtype)
    def _resize_and_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Bicubically resize `x` to the encoder's input size and apply channel-wise normalization."""
        _, _, h, w = x.shape
        if h != self.config.encoder_input_size or w != self.config.encoder_input_size:
            x = F.interpolate(
                x,
                size=(self.config.encoder_input_size, self.config.encoder_input_size),
                mode="bicubic",
                align_corners=False,
            )
        mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
        std = self.encoder_std.to(device=x.device, dtype=x.dtype)
        return (x - mean) / std
    def _denormalize_image(self, x: torch.Tensor) -> torch.Tensor:
        """Invert the channel-wise normalization applied in `_resize_and_normalize` (no resizing)."""
        mean = self.encoder_mean.to(device=x.device, dtype=x.dtype)
        std = self.encoder_std.to(device=x.device, dtype=x.dtype)
        return x * std + mean
    def _normalize_latents(self, z: torch.Tensor) -> torch.Tensor:
        """Standardize latents using the checkpoint-provided mean/std (epsilon guards zero std)."""
        latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
        latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
        return (z - latents_mean) / (latents_std + 1e-5)
    def _denormalize_latents(self, z: torch.Tensor) -> torch.Tensor:
        """Invert `_normalize_latents` (uses the same epsilon so the round trip is exact)."""
        latents_mean = self._latents_mean.to(device=z.device, dtype=z.dtype)
        latents_std = self._latents_std.to(device=z.device, dtype=z.dtype)
        return z * (latents_std + 1e-5) + latents_mean
    def _encode(self, x: torch.Tensor, generator: torch.Generator | None = None) -> torch.Tensor:
        """Encode an image batch into (optionally 2D-reshaped, normalized, scaled) latents."""
        x = self._resize_and_normalize(x)
        # MAE additionally needs the patch size to construct its ordered "noise" tensor.
        if self.config.encoder_type == "mae":
            tokens = self._encoder_forward_fn(self.encoder, x, self.config.encoder_patch_size)
        else:
            tokens = self._encoder_forward_fn(self.encoder, x)  # (B, N, C)
        # Noise augmentation only during training and only if noise_tau > 0.
        if self.training and self.config.noise_tau > 0:
            tokens = self._noising(tokens, generator=generator)
        if self.config.reshape_to_2d:
            b, n, c = tokens.shape
            side = int(sqrt(n))
            if side * side != n:
                raise ValueError(f"Token length n={n} is not a perfect square; cannot reshape to 2D.")
            z = tokens.transpose(1, 2).contiguous().view(b, c, side, side)  # (B, C, h, w)
        else:
            z = tokens
        z = self._normalize_latents(z)
        # Follow diffusers convention: optionally scale latents for diffusion
        if self.config.scaling_factor != 1.0:
            z = z * self.config.scaling_factor
        return z
    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
    ) -> EncoderOutput | tuple[torch.Tensor]:
        """Encode images to latents, optionally one batch element at a time when slicing is enabled."""
        if self.use_slicing and x.shape[0] > 1:
            latents = torch.cat([self._encode(x_slice, generator=generator) for x_slice in x.split(1)], dim=0)
        else:
            latents = self._encode(x, generator=generator)
        if not return_dict:
            return (latents,)
        return EncoderOutput(latent=latents)
    def _decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latents back to a (B, C, H, W) image batch, undoing scaling and normalization."""
        # Undo scaling factor if applied at encode time
        if self.config.scaling_factor != 1.0:
            z = z / self.config.scaling_factor
        z = self._denormalize_latents(z)
        if self.config.reshape_to_2d:
            b, c, h, w = z.shape
            tokens = z.view(b, c, h * w).transpose(1, 2).contiguous()  # (B, N, C)
        else:
            tokens = z
        logits = self.decoder(tokens, return_dict=True).logits
        x_rec = self.decoder.unpatchify(logits)
        x_rec = self._denormalize_image(x_rec)
        return x_rec
    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
        """Decode latents to images, optionally one batch element at a time when slicing is enabled."""
        if self.use_slicing and z.shape[0] > 1:
            decoded = torch.cat([self._decode(z_slice) for z_slice in z.split(1)], dim=0)
        else:
            decoded = self._decode(z)
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)
    def forward(
        self, sample: torch.Tensor, return_dict: bool = True, generator: torch.Generator | None = None
    ) -> DecoderOutput | tuple[torch.Tensor]:
        """Full round trip: encode `sample` to latents and decode back to a reconstruction."""
        latents = self.encode(sample, return_dict=False, generator=generator)[0]
        decoded = self.decode(latents, return_dict=False)[0]
        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

View File

@@ -28,6 +28,7 @@ if is_torch_available():
from .transformer_flux import FluxTransformer2DModel
from .transformer_flux2 import Flux2Transformer2DModel
from .transformer_glm_image import GlmImageTransformer2DModel
from .transformer_helios import HeliosTransformer3DModel
from .transformer_hidream_image import HiDreamImageTransformer2DModel
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
from .transformer_hunyuan_video15 import HunyuanVideo15Transformer3DModel

View File

@@ -0,0 +1,814 @@
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from functools import lru_cache
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, logging
from ...utils.torch_utils import maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, AttentionModuleMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def pad_for_3d_conv(x, kernel_size):
    """Right-pad the T/H/W dims of a (B, C, T, H, W) tensor (replicating edge
    values) so each becomes a multiple of the corresponding kernel size."""
    _, _, t, h, w = x.shape
    kt, kh, kw = kernel_size
    # (-size) % k gives the amount needed to reach the next multiple of k.
    pad_t, pad_h, pad_w = ((-size) % k for size, k in zip((t, h, w), (kt, kh, kw)))
    return torch.nn.functional.pad(x, (0, pad_w, 0, pad_h, 0, pad_t), mode="replicate")
def center_down_sample_3d(x, kernel_size):
    """Downsample a (B, C, T, H, W) tensor via non-overlapping 3D average pooling
    (stride equals the kernel, so each output cell averages one disjoint block)."""
    pooled = torch.nn.functional.avg_pool3d(x, kernel_size, stride=kernel_size)
    return pooled
def apply_rotary_emb_transposed(
    hidden_states: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Apply rotary position embeddings to `hidden_states`.

    `freqs_cis` packs cos and sin halves along its last dim (each entry
    duplicated for the even/odd channel pair); a head axis is inserted via
    `unsqueeze(-2)` so the frequencies broadcast over attention heads.
    """
    even, odd = hidden_states.unflatten(-1, (-1, 2)).unbind(-1)
    cos, sin = freqs_cis.unsqueeze(-2).chunk(2, dim=-1)
    cos_half, sin_half = cos[..., 0::2], sin[..., 1::2]
    rotated = torch.empty_like(hidden_states)
    rotated[..., 0::2] = even * cos_half - odd * sin_half
    rotated[..., 1::2] = even * sin_half + odd * cos_half
    return rotated.type_as(hidden_states)
def _get_qkv_projections(attn: "HeliosAttention", hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor):
# encoder_hidden_states is only passed for cross-attention
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
if attn.fused_projections:
if not attn.is_cross_attention:
# In self-attention layers, we can fuse the entire QKV projection into a single linear
query, key, value = attn.to_qkv(hidden_states).chunk(3, dim=-1)
else:
# In cross-attention layers, we can only fuse the KV projections into a single linear
query = attn.to_q(hidden_states)
key, value = attn.to_kv(encoder_hidden_states).chunk(2, dim=-1)
else:
query = attn.to_q(hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
return query, key, value
class HeliosOutputNorm(nn.Module):
    """AdaLN-style output norm: an fp32 LayerNorm modulated by a learned
    scale/shift table combined with a per-token conditioning embedding, applied
    only to the trailing ``original_context_length`` positions of the sequence.
    """
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
        super().__init__()
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
        self.norm = FP32LayerNorm(dim, eps, elementwise_affine=False)
    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, original_context_length: int):
        # Only the current (non-history) positions are conditioned and normalized.
        temb = temb[:, -original_context_length:, :]
        modulation = self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)
        shift, scale = modulation.chunk(2, dim=2)
        shift = shift.squeeze(2).to(hidden_states.device)
        scale = scale.squeeze(2).to(hidden_states.device)
        current = hidden_states[:, -original_context_length:, :]
        normed = self.norm(current.float()) * (1 + scale) + shift
        return normed.type_as(current)
class HeliosAttnProcessor:
    """Default processor for `HeliosAttention`: QK-norm + rotary embeddings +
    optional history-key amplification, dispatched through `dispatch_attention_fn`.
    """
    # Optional overrides for the attention backend / context-parallel config.
    _attention_backend = None
    _parallel_config = None
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "HeliosAttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to version 2.0 or higher."
            )
    def __call__(
        self,
        attn: "HeliosAttention",
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        original_context_length: int | None = None,
    ) -> torch.Tensor:
        query, key, value = _get_qkv_projections(attn, hidden_states, encoder_hidden_states)
        # QK RMSNorm is applied over the full projected dim, before the head split.
        query = attn.norm_q(query)
        key = attn.norm_k(key)
        # (B, S, heads*dim_head) -> (B, S, heads, dim_head)
        query = query.unflatten(2, (attn.heads, -1))
        key = key.unflatten(2, (attn.heads, -1))
        value = value.unflatten(2, (attn.heads, -1))
        if rotary_emb is not None:
            query = apply_rotary_emb_transposed(query, rotary_emb)
            key = apply_rotary_emb_transposed(key, rotary_emb)
        # History-key amplification: the sequence is laid out [history | current],
        # and keys of the history tokens are scaled by a learned factor in
        # (1.0, max_scale), derived via sigmoid from `history_key_scale`.
        if not attn.is_cross_attention and attn.is_amplify_history:
            history_seq_len = hidden_states.shape[1] - original_context_length
            if history_seq_len > 0:
                scale_key = 1.0 + torch.sigmoid(attn.history_key_scale) * (attn.max_scale - 1.0)
                if attn.history_scale_mode == "per_head":
                    # Broadcast one scale per attention head over (B, S, H, D).
                    scale_key = scale_key.view(1, 1, -1, 1)
                key = torch.cat([key[:, :history_seq_len] * scale_key, key[:, history_seq_len:]], dim=1)
        hidden_states = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
            # Reference: https://github.com/huggingface/diffusers/pull/12909
            parallel_config=(self._parallel_config if encoder_hidden_states is None else None),
        )
        # (B, S, heads, dim_head) -> (B, S, heads*dim_head)
        hidden_states = hidden_states.flatten(2, 3)
        hidden_states = hidden_states.type_as(query)
        # Output projection followed by dropout.
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states
class HeliosAttention(torch.nn.Module, AttentionModuleMixin):
    """Attention module for the Helios transformer.

    Supports self- and cross-attention, QKV projection fusing, and an optional
    learned amplification of "history" keys (see `HeliosAttnProcessor`).
    """
    _default_processor_cls = HeliosAttnProcessor
    _available_processors = [HeliosAttnProcessor]
    def __init__(
        self,
        dim: int,
        heads: int = 8,
        dim_head: int = 64,
        eps: float = 1e-5,
        dropout: float = 0.0,
        added_kv_proj_dim: int | None = None,
        cross_attention_dim_head: int | None = None,
        processor=None,
        is_cross_attention=None,
        is_amplify_history=False,
        history_scale_mode="per_head",  # [scalar, per_head]
    ):
        super().__init__()
        self.inner_dim = dim_head * heads
        self.heads = heads
        self.added_kv_proj_dim = added_kv_proj_dim
        self.cross_attention_dim_head = cross_attention_dim_head
        # K/V width may differ from Q width when a cross-attention head dim is given.
        self.kv_inner_dim = self.inner_dim if cross_attention_dim_head is None else cross_attention_dim_head * heads
        self.to_q = torch.nn.Linear(dim, self.inner_dim, bias=True)
        self.to_k = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_v = torch.nn.Linear(dim, self.kv_inner_dim, bias=True)
        self.to_out = torch.nn.ModuleList(
            [
                torch.nn.Linear(self.inner_dim, dim, bias=True),
                torch.nn.Dropout(dropout),
            ]
        )
        # QK RMSNorm over the full projected width (not per-head).
        self.norm_q = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.norm_k = torch.nn.RMSNorm(dim_head * heads, eps=eps, elementwise_affine=True)
        self.add_k_proj = self.add_v_proj = None
        if added_kv_proj_dim is not None:
            self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=True)
            self.norm_added_k = torch.nn.RMSNorm(dim_head * heads, eps=eps)
        # Default: infer cross-attention from the presence of a cross-attn head dim.
        if is_cross_attention is not None:
            self.is_cross_attention = is_cross_attention
        else:
            self.is_cross_attention = cross_attention_dim_head is not None
        self.set_processor(processor)
        self.is_amplify_history = is_amplify_history
        if is_amplify_history:
            # Learned logit for the history-key scale; one value per head or a scalar.
            if history_scale_mode == "scalar":
                self.history_key_scale = nn.Parameter(torch.ones(1))
            elif history_scale_mode == "per_head":
                self.history_key_scale = nn.Parameter(torch.ones(heads))
            else:
                raise ValueError(f"Unknown history_scale_mode: {history_scale_mode}")
            self.history_scale_mode = history_scale_mode
            # Upper bound for the sigmoid-derived amplification factor.
            self.max_scale = 10.0
    def fuse_projections(self):
        """Concatenate Q/K/V (or K/V) weights into a single fused linear for faster inference."""
        if getattr(self, "fused_projections", False):
            return
        if not self.is_cross_attention:
            concatenated_weights = torch.cat([self.to_q.weight.data, self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_q.bias.data, self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            # Create the module on meta and `assign=True` the concatenated tensors
            # so no extra allocation happens.
            with torch.device("meta"):
                self.to_qkv = nn.Linear(in_features, out_features, bias=True)
            self.to_qkv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )
        else:
            concatenated_weights = torch.cat([self.to_k.weight.data, self.to_v.weight.data])
            concatenated_bias = torch.cat([self.to_k.bias.data, self.to_v.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )
        if self.added_kv_proj_dim is not None:
            concatenated_weights = torch.cat([self.add_k_proj.weight.data, self.add_v_proj.weight.data])
            concatenated_bias = torch.cat([self.add_k_proj.bias.data, self.add_v_proj.bias.data])
            out_features, in_features = concatenated_weights.shape
            with torch.device("meta"):
                self.to_added_kv = nn.Linear(in_features, out_features, bias=True)
            self.to_added_kv.load_state_dict(
                {"weight": concatenated_weights, "bias": concatenated_bias}, strict=True, assign=True
            )
        self.fused_projections = True
    @torch.no_grad()
    def unfuse_projections(self):
        """Drop fused projection modules; the original to_q/to_k/to_v are still intact."""
        if not getattr(self, "fused_projections", False):
            return
        if hasattr(self, "to_qkv"):
            delattr(self, "to_qkv")
        if hasattr(self, "to_kv"):
            delattr(self, "to_kv")
        if hasattr(self, "to_added_kv"):
            delattr(self, "to_added_kv")
        self.fused_projections = False
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        rotary_emb: tuple[torch.Tensor, torch.Tensor] | None = None,
        original_context_length: int | None = None,
        **kwargs,
    ) -> torch.Tensor:
        # Delegate to the configured processor (HeliosAttnProcessor by default).
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states,
            attention_mask,
            rotary_emb,
            original_context_length,
            **kwargs,
        )
class HeliosTimeTextEmbedding(nn.Module):
    """Combined timestep + text conditioning embedder for the Helios transformer.

    Produces the timestep embedding `temb`, its projected variant
    `timestep_proj`, and (optionally) text embeddings projected to the model dim.
    """
    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
    ):
        super().__init__()
        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor | None = None,
        is_return_encoder_hidden_states: bool = True,
    ):
        """Return `(temb, timestep_proj, encoder_hidden_states)`.

        NOTE(review): `type_as(encoder_hidden_states)` below assumes
        `encoder_hidden_states` is not None — confirm callers always pass it.
        """
        timestep = self.timesteps_proj(timestep)
        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        # Cast sinusoidal features to the embedder's dtype, but skip when the
        # embedder weights are int8 — presumably a quantization guard; TODO confirm.
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))
        # Text projection is optional and only applied when requested.
        if encoder_hidden_states is not None and is_return_encoder_hidden_states:
            encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        return temb, timestep_proj, encoder_hidden_states
class HeliosRotaryPosEmbed(nn.Module):
    """3D (time/height/width) rotary position embedding for Helios video tokens.

    Args:
        rope_dim: Per-axis rotary dimensions ``(DT, DY, DX)``; each axis gets
            ``dim // 2`` base frequencies that are duplicated pairwise.
        theta: Base of the inverse-frequency schedule.

    ``forward`` returns a ``(batch, channels, frames, height, width)`` tensor
    holding the concatenated ``[cos_t, cos_y, cos_x, sin_t, sin_y, sin_x]`` terms.
    """

    def __init__(self, rope_dim, theta):
        super().__init__()
        self.DT, self.DY, self.DX = rope_dim
        self.theta = theta
        # Non-persistent buffers: derived from config, so not part of the state dict.
        self.register_buffer("freqs_base_t", self._get_freqs_base(self.DT), persistent=False)
        self.register_buffer("freqs_base_y", self._get_freqs_base(self.DY), persistent=False)
        self.register_buffer("freqs_base_x", self._get_freqs_base(self.DX), persistent=False)

    def _get_freqs_base(self, dim):
        # Standard RoPE inverse-frequency schedule: 1 / theta^(2i / dim).
        return 1.0 / (self.theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: (dim // 2)] / dim))

    @torch.no_grad()
    def get_frequency_batched(self, freqs_base, pos):
        # (d,) x (b, t, h, w) -> (d, b, t, h, w), then duplicate each frequency
        # pairwise along the channel axis for the rotate-half formulation.
        freqs = torch.einsum("d,bthw->dbthw", freqs_base, pos)
        freqs = freqs.repeat_interleave(2, dim=0)
        return freqs.cos(), freqs.sin()

    # Cached as a staticmethod instead of `@lru_cache` on the bound method:
    # `lru_cache` on an instance method keys the cache on `self` and keeps every
    # module instance alive for the cache's lifetime (memory leak, ruff B019).
    # The meshgrid depends only on (height, width, device), so no instance state
    # is needed and the cache can be shared across instances.
    @staticmethod
    @lru_cache(maxsize=32)
    @torch.no_grad()
    def _get_spatial_meshgrid(height, width, device_str):
        device = torch.device(device_str)
        grid_y_coords = torch.arange(height, device=device, dtype=torch.float32)
        grid_x_coords = torch.arange(width, device=device, dtype=torch.float32)
        grid_y, grid_x = torch.meshgrid(grid_y_coords, grid_x_coords, indexing="ij")
        return grid_y, grid_x

    @torch.no_grad()
    def forward(self, frame_indices, height, width, device):
        batch_size = frame_indices.shape[0]
        num_frames = frame_indices.shape[1]
        frame_indices = frame_indices.to(device=device, dtype=torch.float32)
        # device is passed as a string so the cache key stays hashable.
        grid_y, grid_x = self._get_spatial_meshgrid(height, width, str(device))

        # Broadcast the per-frame temporal index and the shared spatial grids to
        # full (b, t, h, w) position grids.
        grid_t = frame_indices[:, :, None, None].expand(batch_size, num_frames, height, width)
        grid_y_batch = grid_y[None, None, :, :].expand(batch_size, num_frames, -1, -1)
        grid_x_batch = grid_x[None, None, :, :].expand(batch_size, num_frames, -1, -1)

        freqs_cos_t, freqs_sin_t = self.get_frequency_batched(self.freqs_base_t, grid_t)
        freqs_cos_y, freqs_sin_y = self.get_frequency_batched(self.freqs_base_y, grid_y_batch)
        freqs_cos_x, freqs_sin_x = self.get_frequency_batched(self.freqs_base_x, grid_x_batch)

        result = torch.cat([freqs_cos_t, freqs_cos_y, freqs_cos_x, freqs_sin_t, freqs_sin_y, freqs_sin_x], dim=0)
        # (channels, b, t, h, w) -> (b, channels, t, h, w)
        return result.permute(1, 0, 2, 3, 4)
@maybe_allow_in_graph
class HeliosTransformerBlock(nn.Module):
    """A Helios transformer block: AdaLN-modulated self-attention, cross-attention
    to the text context, and a gated feed-forward network.

    Modulation parameters (shift/scale/gate for attention and FFN) come from
    ``scale_shift_table + temb``; ``temb`` may be per-sample (3D) or per-token (4D).
    """

    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm_across_heads",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: int | None = None,
        guidance_cross_attn: bool = False,
        is_amplify_history: bool = False,
        history_scale_mode: str = "per_head",  # [scalar, per_head]
    ):
        super().__init__()

        # NOTE(review): `qk_norm` is accepted but not forwarded to either attention
        # layer below — presumably the processors/attention default to it; confirm.
        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = HeliosAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            cross_attention_dim_head=None,
            processor=HeliosAttnProcessor(),
            is_amplify_history=is_amplify_history,
            history_scale_mode=history_scale_mode,
        )

        # 2. Cross-attention
        self.attn2 = HeliosAttention(
            dim=dim,
            heads=num_heads,
            dim_head=dim // num_heads,
            eps=eps,
            added_kv_proj_dim=added_kv_proj_dim,
            cross_attention_dim_head=dim // num_heads,
            processor=HeliosAttnProcessor(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        # 6 modulation vectors: (shift, scale, gate) for self-attn and for the FFN.
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

        # 4. Guidance cross-attention: when enabled, history tokens are excluded
        # from cross-attention (only the current context attends to the text).
        self.guidance_cross_attn = guidance_cross_attn

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        original_context_length: int | None = None,
    ) -> torch.Tensor:
        """Run one block. `temb` is either (batch, 6, dim) for per-sample modulation
        or (batch, seq, 6, dim) for per-token modulation (history vs. current)."""
        if temb.ndim == 4:
            # Per-token modulation: broadcast the table over (batch, seq).
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table.unsqueeze(0) + temb.float()
            ).chunk(6, dim=2)
            # batch_size, seq_len, 1, inner_dim
            shift_msa = shift_msa.squeeze(2)
            scale_msa = scale_msa.squeeze(2)
            gate_msa = gate_msa.squeeze(2)
            c_shift_msa = c_shift_msa.squeeze(2)
            c_scale_msa = c_scale_msa.squeeze(2)
            c_gate_msa = c_gate_msa.squeeze(2)
        else:
            # Per-sample modulation.
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)

        # 1. Self-attention (norm and modulation in fp32, then cast back).
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(
            norm_hidden_states,
            None,
            None,
            rotary_emb,
            original_context_length,
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        if self.guidance_cross_attn:
            # Split off the prepended history tokens so only the current-context
            # tokens cross-attend to the text embeddings.
            history_seq_len = hidden_states.shape[1] - original_context_length
            history_hidden_states, hidden_states = torch.split(
                hidden_states, [history_seq_len, original_context_length], dim=1
            )
            norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states,
                None,
                None,
                original_context_length,
            )
            hidden_states = hidden_states + attn_output
            # Re-attach the (unchanged) history tokens in front.
            hidden_states = torch.cat([history_hidden_states, hidden_states], dim=1)
        else:
            norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
            attn_output = self.attn2(
                norm_hidden_states,
                encoder_hidden_states,
                None,
                None,
                original_context_length,
            )
            hidden_states = hidden_states + attn_output

        # 3. Feed-forward (modulated and gated, computed in fp32).
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states
class HeliosTransformer3DModel(
    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin, AttentionMixin
):
    r"""
    A Transformer model for video-like data used in the Helios model.

    Args:
        patch_size (`tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `40`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `13824`):
            Intermediate dimension in feed-forward network.
        num_layers (`int`, defaults to `40`):
            The number of layers of transformer blocks to use.
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm_across_heads"`):
            The type of query/key normalization to use.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
            The number of channels to use for the added key and value projections. If `None`, no projection is used.
        rope_dim (`tuple[int]`, defaults to `(44, 42, 42)`):
            Per-axis (time, height, width) rotary embedding dimensions.
        rope_theta (`float`, defaults to `10000.0`):
            Base of the rotary inverse-frequency schedule.
        guidance_cross_attn (`bool`, defaults to `True`):
            Whether blocks exclude history tokens from cross-attention.
        zero_history_timestep (`bool`, defaults to `True`):
            Whether history tokens use a zero-timestep conditioning embedding.
        has_multi_term_memory_patch (`bool`, defaults to `True`):
            Whether to create the short/mid/long history patch embeddings.
        is_amplify_history (`bool`, defaults to `False`):
            Whether self-attention rescales history keys.
        history_scale_mode (`str`, defaults to `"per_head"`):
            History key scaling mode, one of `"scalar"` or `"per_head"`.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = [
        "patch_embedding",
        "patch_short",
        "patch_mid",
        "patch_long",
        "condition_embedder",
        "norm",
    ]
    _no_split_modules = ["HeliosTransformerBlock", "HeliosOutputNorm"]
    _keep_in_fp32_modules = [
        "time_embedder",
        "scale_shift_table",
        "norm1",
        "norm2",
        "norm3",
        "history_key_scale",
    ]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
    _repeated_blocks = ["HeliosTransformerBlock"]
    # NOTE(review): "blocks.39" hard-codes the default num_layers=40; the plan is
    # presumably only valid for the default depth — confirm for custom configs.
    _cp_plan = {
        "blocks.0": {
            "hidden_states": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.*": {
            "temb": ContextParallelInput(split_dim=1, expected_dims=4, split_output=False),
            "rotary_emb": ContextParallelInput(split_dim=1, expected_dims=3, split_output=False),
        },
        "blocks.39": ContextParallelOutput(gather_dim=1, expected_dims=3),
    }

    @register_to_config
    def __init__(
        self,
        patch_size: tuple[int, ...] = (1, 2, 2),
        num_attention_heads: int = 40,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 13824,
        num_layers: int = 40,
        cross_attn_norm: bool = True,
        qk_norm: str | None = "rms_norm_across_heads",
        eps: float = 1e-6,
        added_kv_proj_dim: int | None = None,
        rope_dim: tuple[int, ...] = (44, 42, 42),
        rope_theta: float = 10000.0,
        guidance_cross_attn: bool = True,
        zero_history_timestep: bool = True,
        has_multi_term_memory_patch: bool = True,
        is_amplify_history: bool = False,
        history_scale_mode: str = "per_head",  # [scalar, per_head]
    ) -> None:
        super().__init__()
        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels

        # 1. Patch & position embedding
        self.rope = HeliosRotaryPosEmbed(rope_dim=rope_dim, theta=rope_theta)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Initial Multi Term Memory Patch: history latents are patchified at
        # 1x (short), 2x (mid) and 4x (long) the base patch size, so older
        # history is represented by progressively coarser tokens.
        self.zero_history_timestep = zero_history_timestep
        if has_multi_term_memory_patch:
            self.patch_short = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
            self.patch_mid = nn.Conv3d(
                in_channels,
                inner_dim,
                kernel_size=tuple(2 * p for p in patch_size),
                stride=tuple(2 * p for p in patch_size),
            )
            self.patch_long = nn.Conv3d(
                in_channels,
                inner_dim,
                kernel_size=tuple(4 * p for p in patch_size),
                stride=tuple(4 * p for p in patch_size),
            )

        # 3. Condition embeddings
        self.condition_embedder = HeliosTimeTextEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
        )

        # 4. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                HeliosTransformerBlock(
                    inner_dim,
                    ffn_dim,
                    num_attention_heads,
                    qk_norm,
                    cross_attn_norm,
                    eps,
                    added_kv_proj_dim,
                    guidance_cross_attn=guidance_cross_attn,
                    is_amplify_history=is_amplify_history,
                    history_scale_mode=history_scale_mode,
                )
                for _ in range(num_layers)
            ]
        )

        # 5. Output norm & projection
        self.norm_out = HeliosOutputNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))

        self.gradient_checkpointing = False

    @apply_lora_scale("attention_kwargs")
    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        # ------------ Stage 1 ------------
        indices_hidden_states: torch.Tensor | None = None,
        indices_latents_history_short: torch.Tensor | None = None,
        indices_latents_history_mid: torch.Tensor | None = None,
        indices_latents_history_long: torch.Tensor | None = None,
        latents_history_short: torch.Tensor | None = None,
        latents_history_mid: torch.Tensor | None = None,
        latents_history_long: torch.Tensor | None = None,
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
    ) -> torch.Tensor | dict[str, torch.Tensor]:
        """Denoise video latents, optionally conditioned on short/mid/long history
        latents that are prepended as extra tokens with their own rotary positions.

        Returns a `Transformer2DModelOutput` (or a tuple when `return_dict=False`)
        with the unpatchified prediction for the current (non-history) latents.
        """
        # 1. Input
        batch_size = hidden_states.shape[0]
        p_t, p_h, p_w = self.config.patch_size

        # 2. Process noisy latents
        hidden_states = self.patch_embedding(hidden_states)
        _, _, post_patch_num_frames, post_patch_height, post_patch_width = hidden_states.shape
        if indices_hidden_states is None:
            # Default temporal positions: 0..T-1, shared across the batch.
            indices_hidden_states = torch.arange(0, post_patch_num_frames).unsqueeze(0).expand(batch_size, -1)
        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        rotary_emb = self.rope(
            frame_indices=indices_hidden_states,
            height=post_patch_height,
            width=post_patch_width,
            device=hidden_states.device,
        )
        rotary_emb = rotary_emb.flatten(2).transpose(1, 2)

        # Number of current-context tokens; everything prepended below is history.
        original_context_length = hidden_states.shape[1]

        # 3. Process short history latents
        if latents_history_short is not None and indices_latents_history_short is not None:
            latents_history_short = self.patch_short(latents_history_short)
            _, _, _, H1, W1 = latents_history_short.shape
            latents_history_short = latents_history_short.flatten(2).transpose(1, 2)
            rotary_emb_history_short = self.rope(
                frame_indices=indices_latents_history_short,
                height=H1,
                width=W1,
                device=latents_history_short.device,
            )
            rotary_emb_history_short = rotary_emb_history_short.flatten(2).transpose(1, 2)
            hidden_states = torch.cat([latents_history_short, hidden_states], dim=1)
            rotary_emb = torch.cat([rotary_emb_history_short, rotary_emb], dim=1)

        # 4. Process mid history latents
        # NOTE(review): H1/W1 are only defined inside the short-history branch above;
        # passing mid/long history without short history would raise NameError —
        # presumably callers always supply short history first. TODO confirm.
        if latents_history_mid is not None and indices_latents_history_mid is not None:
            latents_history_mid = pad_for_3d_conv(latents_history_mid, (2, 4, 4))
            latents_history_mid = self.patch_mid(latents_history_mid)
            latents_history_mid = latents_history_mid.flatten(2).transpose(1, 2)
            # RoPE is computed at the short-history resolution, then downsampled
            # 2x to match the coarser mid-history tokens.
            rotary_emb_history_mid = self.rope(
                frame_indices=indices_latents_history_mid,
                height=H1,
                width=W1,
                device=latents_history_mid.device,
            )
            rotary_emb_history_mid = pad_for_3d_conv(rotary_emb_history_mid, (2, 2, 2))
            rotary_emb_history_mid = center_down_sample_3d(rotary_emb_history_mid, (2, 2, 2))
            rotary_emb_history_mid = rotary_emb_history_mid.flatten(2).transpose(1, 2)
            hidden_states = torch.cat([latents_history_mid, hidden_states], dim=1)
            rotary_emb = torch.cat([rotary_emb_history_mid, rotary_emb], dim=1)

        # 5. Process long history latents (4x coarser; same scheme as mid).
        if latents_history_long is not None and indices_latents_history_long is not None:
            latents_history_long = pad_for_3d_conv(latents_history_long, (4, 8, 8))
            latents_history_long = self.patch_long(latents_history_long)
            latents_history_long = latents_history_long.flatten(2).transpose(1, 2)
            rotary_emb_history_long = self.rope(
                frame_indices=indices_latents_history_long,
                height=H1,
                width=W1,
                device=latents_history_long.device,
            )
            rotary_emb_history_long = pad_for_3d_conv(rotary_emb_history_long, (4, 4, 4))
            rotary_emb_history_long = center_down_sample_3d(rotary_emb_history_long, (4, 4, 4))
            rotary_emb_history_long = rotary_emb_history_long.flatten(2).transpose(1, 2)
            hidden_states = torch.cat([latents_history_long, hidden_states], dim=1)
            rotary_emb = torch.cat([rotary_emb_history_long, rotary_emb], dim=1)

        history_context_length = hidden_states.shape[1] - original_context_length

        # History tokens are (optionally) conditioned on timestep 0 — they are
        # clean latents, not noisy ones.
        if indices_hidden_states is not None and self.zero_history_timestep:
            timestep_t0 = torch.zeros((1), dtype=timestep.dtype, device=timestep.device)
            temb_t0, timestep_proj_t0, _ = self.condition_embedder(
                timestep_t0, encoder_hidden_states, is_return_encoder_hidden_states=False
            )
            temb_t0 = temb_t0.unsqueeze(1).expand(batch_size, history_context_length, -1)
            timestep_proj_t0 = (
                timestep_proj_t0.unflatten(-1, (6, -1))
                .view(1, 6, 1, -1)
                .expand(batch_size, -1, history_context_length, -1)
            )

        temb, timestep_proj, encoder_hidden_states = self.condition_embedder(timestep, encoder_hidden_states)
        timestep_proj = timestep_proj.unflatten(-1, (6, -1))

        # Without zero-timestep history, the same timestep embedding covers all
        # tokens (history included); otherwise only the current context.
        if indices_hidden_states is not None and not self.zero_history_timestep:
            main_repeat_size = hidden_states.shape[1]
        else:
            main_repeat_size = original_context_length

        temb = temb.view(batch_size, 1, -1).expand(batch_size, main_repeat_size, -1)
        timestep_proj = timestep_proj.view(batch_size, 6, 1, -1).expand(batch_size, 6, main_repeat_size, -1)

        if indices_hidden_states is not None and self.zero_history_timestep:
            # Prepend the zero-timestep conditioning for the history tokens.
            temb = torch.cat([temb_t0, temb], dim=1)
            timestep_proj = torch.cat([timestep_proj_t0, timestep_proj], dim=2)

        if timestep_proj.ndim == 4:
            # (b, 6, seq, dim) -> (b, seq, 6, dim): per-token modulation layout
            # expected by HeliosTransformerBlock.
            timestep_proj = timestep_proj.permute(0, 2, 1, 3)

        # 6. Transformer blocks
        hidden_states = hidden_states.contiguous()
        encoder_hidden_states = encoder_hidden_states.contiguous()
        rotary_emb = rotary_emb.contiguous()
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    original_context_length,
                )
        else:
            for block in self.blocks:
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    timestep_proj,
                    rotary_emb,
                    original_context_length,
                )

        # 7. Normalization (norm_out also discards the history tokens — the output
        # below is reshaped using only the current-context dimensions)
        hidden_states = self.norm_out(hidden_states, temb, original_context_length)
        hidden_states = self.proj_out(hidden_states)

        # 8. Unpatchify: fold the per-token patch channels back into (C, T, H, W).
        hidden_states = hidden_states.reshape(
            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
        )
        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)

View File

@@ -14,6 +14,7 @@
import importlib
import inspect
import os
import sys
import traceback
import warnings
from collections import OrderedDict
@@ -28,10 +29,16 @@ from tqdm.auto import tqdm
from typing_extensions import Self
from ..configuration_utils import ConfigMixin, FrozenDict
from ..pipelines.pipeline_loading_utils import _fetch_class_library_tuple, simple_get_class_obj
from ..pipelines.pipeline_loading_utils import (
LOADABLE_CLASSES,
_fetch_class_library_tuple,
_unwrap_model,
simple_get_class_obj,
)
from ..utils import PushToHubMixin, is_accelerate_available, logging
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
from ..utils.torch_utils import is_compiled_module
from .components_manager import ComponentsManager
from .modular_pipeline_utils import (
MODULAR_MODEL_CARD_TEMPLATE,
@@ -40,6 +47,7 @@ from .modular_pipeline_utils import (
InputParam,
InsertableDict,
OutputParam,
_validate_requirements,
combine_inputs,
combine_outputs,
format_components,
@@ -290,6 +298,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
config_name = "modular_config.json"
model_name = None
_requirements: dict[str, str] | None = None
_workflow_map = None
@classmethod
@@ -404,6 +413,9 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
"Selected model repository does not happear to have any custom code or does not have a valid `config.json` file."
)
if "requirements" in config and config["requirements"] is not None:
_ = _validate_requirements(config["requirements"])
class_ref = config["auto_map"][cls.__name__]
module_file, class_name = class_ref.split(".")
module_file = module_file + ".py"
@@ -428,8 +440,13 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
module = full_mod.rsplit(".", 1)[-1].replace("__dynamic__", "")
parent_module = self.save_pretrained.__func__.__qualname__.split(".", 1)[0]
auto_map = {f"{parent_module}": f"{module}.{cls_name}"}
self.register_to_config(auto_map=auto_map)
# resolve requirements
requirements = _validate_requirements(getattr(self, "_requirements", None))
if requirements:
self.register_to_config(requirements=requirements)
self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
config = dict(self.config)
self._internal_dict = FrozenDict(config)
@@ -651,6 +668,15 @@ class ConditionalPipelineBlocks(ModularPipelineBlocks):
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
# used for `__repr__`
def _get_trigger_inputs(self) -> set:
"""
@@ -1240,6 +1266,14 @@ class SequentialPipelineBlocks(ModularPipelineBlocks):
expected_configs=self.expected_configs,
)
@property
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
"""
@@ -1378,6 +1412,15 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks):
def outputs(self) -> list[str]:
return next(reversed(self.sub_blocks.values())).intermediate_outputs
@property
# Copied from diffusers.modular_pipelines.modular_pipeline.SequentialPipelineBlocks._requirements
def _requirements(self) -> dict[str, str]:
requirements = {}
for block_name, block in self.sub_blocks.items():
if getattr(block, "_requirements", None):
requirements[block_name] = block._requirements
return requirements
def __init__(self):
sub_blocks = InsertableDict()
for block_name, block in zip(self.block_names, self.block_classes):
@@ -1700,6 +1743,8 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
_blocks_class_name=self._blocks.__class__.__name__ if self._blocks is not None else None
)
self._pretrained_model_name_or_path = pretrained_model_name_or_path
@property
def default_call_parameters(self) -> dict[str, Any]:
"""
@@ -1826,44 +1871,136 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
)
return pipeline
def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs):
def save_pretrained(
self,
save_directory: str | os.PathLike,
safe_serialization: bool = True,
variant: str | None = None,
max_shard_size: int | str | None = None,
push_to_hub: bool = False,
**kwargs,
):
"""
Save the pipeline to a directory. It does not save components, you need to save them separately.
Save the pipeline and all its components to a directory, so that it can be re-loaded using the
[`~ModularPipeline.from_pretrained`] class method.
Args:
save_directory (`str` or `os.PathLike`):
Path to the directory where the pipeline will be saved.
push_to_hub (`bool`, optional):
Whether to push the pipeline to the huggingface hub.
**kwargs: Additional arguments passed to `save_config()` method
Directory to save the pipeline to. Will be created if it doesn't exist.
safe_serialization (`bool`, *optional*, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
variant (`str`, *optional*):
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
max_shard_size (`int` or `str`, defaults to `None`):
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
If expressed as an integer, the unit is bytes.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether to push the pipeline to the Hugging Face model hub after saving it.
**kwargs: Additional keyword arguments:
- `overwrite_modular_index` (`bool`, *optional*, defaults to `False`):
When saving a Modular Pipeline, its components in `modular_model_index.json` may reference repos
different from the destination repo. Setting this to `True` updates all component references in
`modular_model_index.json` so they point to the repo specified by `repo_id`.
- `repo_id` (`str`, *optional*):
The repository ID to push the pipeline to. Defaults to the last component of `save_directory`.
- `commit_message` (`str`, *optional*):
Commit message for the push to hub operation.
- `private` (`bool`, *optional*):
Whether the repository should be private.
- `create_pr` (`bool`, *optional*, defaults to `False`):
Whether to create a pull request instead of pushing directly.
- `token` (`str`, *optional*):
The Hugging Face token to use for authentication.
"""
overwrite_modular_index = kwargs.pop("overwrite_modular_index", False)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
if push_to_hub:
commit_message = kwargs.pop("commit_message", None)
private = kwargs.pop("private", None)
create_pr = kwargs.pop("create_pr", False)
token = kwargs.pop("token", None)
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
update_model_card = kwargs.pop("update_model_card", False)
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
# Generate modular pipeline card content
card_content = generate_modular_model_card_content(self.blocks)
for component_name, component_spec in self._component_specs.items():
if component_spec.default_creation_method != "from_pretrained":
continue
# Create a new empty model card and eventually tag it
component = getattr(self, component_name, None)
if component is None:
continue
model_cls = component.__class__
if is_compiled_module(component):
component = _unwrap_model(component)
model_cls = component.__class__
save_method_name = None
for library_name, library_classes in LOADABLE_CLASSES.items():
if library_name in sys.modules:
library = importlib.import_module(library_name)
else:
logger.info(
f"{library_name} is not installed. Cannot save {component_name} as {library_classes} from {library_name}"
)
continue
for base_class, save_load_methods in library_classes.items():
class_candidate = getattr(library, base_class, None)
if class_candidate is not None and issubclass(model_cls, class_candidate):
save_method_name = save_load_methods[0]
break
if save_method_name is not None:
break
if save_method_name is None:
logger.warning(f"self.{component_name}={component} of type {type(component)} cannot be saved.")
continue
save_method = getattr(component, save_method_name)
save_method_signature = inspect.signature(save_method)
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
save_method_accept_variant = "variant" in save_method_signature.parameters
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
save_kwargs = {}
if save_method_accept_safe:
save_kwargs["safe_serialization"] = safe_serialization
if save_method_accept_variant:
save_kwargs["variant"] = variant
if save_method_accept_max_shard_size and max_shard_size is not None:
save_kwargs["max_shard_size"] = max_shard_size
component_save_path = os.path.join(save_directory, component_name)
save_method(component_save_path, **save_kwargs)
if component_name not in self.config:
continue
has_no_load_id = not hasattr(component, "_diffusers_load_id") or component._diffusers_load_id == "null"
if overwrite_modular_index or has_no_load_id:
library, class_name, component_spec_dict = self.config[component_name]
component_spec_dict["pretrained_model_name_or_path"] = repo_id if push_to_hub else save_directory
component_spec_dict["subfolder"] = component_name
self.register_to_config(**{component_name: (library, class_name, component_spec_dict)})
self.save_config(save_directory=save_directory)
if push_to_hub:
card_content = generate_modular_model_card_content(self.blocks)
model_card = load_or_create_model_card(
repo_id,
token=token,
is_pipeline=True,
model_description=MODULAR_MODEL_CARD_TEMPLATE.format(**card_content),
is_modular=True,
update_model_card=update_model_card,
)
model_card = populate_model_card(model_card, tags=card_content["tags"])
model_card.save(os.path.join(save_directory, "README.md"))
# YiYi TODO: maybe order the json file to make it more readable: configs first, then components
self.save_config(save_directory=save_directory)
if push_to_hub:
self._upload_folder(
save_directory,
repo_id,
@@ -2131,8 +2268,9 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
```
Notes:
- Components with trained weights should be loaded with `AutoModel.from_pretrained()` or
`ComponentSpec.load()` so that loading specs are preserved for serialization.
- Components loaded with `AutoModel.from_pretrained()` or `ComponentSpec.load()` will have
loading specs preserved for serialization. Custom or locally loaded components without Hub references will
have their `modular_model_index.json` entries updated automatically during `save_pretrained()`.
- ConfigMixin objects without weights (e.g., schedulers, guiders) can be passed directly.
"""
@@ -2154,13 +2292,10 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
new_component_spec = current_component_spec
if hasattr(self, name) and getattr(self, name) is not None:
logger.warning(f"ModularPipeline.update_components: setting {name} to None (spec unchanged)")
elif current_component_spec.default_creation_method == "from_pretrained" and not (
hasattr(component, "_diffusers_load_id") and component._diffusers_load_id is not None
elif (
current_component_spec.default_creation_method == "from_pretrained"
and getattr(component, "_diffusers_load_id", None) is None
):
logger.warning(
f"ModularPipeline.update_components: {name} has no valid _diffusers_load_id. "
f"This will result in empty loading spec, use ComponentSpec.load() for proper specs"
)
new_component_spec = ComponentSpec(name=name, type_hint=type(component))
else:
new_component_spec = ComponentSpec.from_component(name, component)
@@ -2233,17 +2368,49 @@ class ModularPipeline(ConfigMixin, PushToHubMixin):
elif "default" in value:
# check if the default is specified
component_load_kwargs[key] = value["default"]
# Only pass trust_remote_code to components from the same repo as the pipeline.
# When a user passes trust_remote_code=True, they intend to trust code from the
# pipeline's repo, not from external repos referenced in modular_model_index.json.
trust_remote_code_stripped = False
if (
"trust_remote_code" in component_load_kwargs
and self._pretrained_model_name_or_path is not None
and spec.pretrained_model_name_or_path != self._pretrained_model_name_or_path
):
component_load_kwargs.pop("trust_remote_code")
trust_remote_code_stripped = True
if not spec.pretrained_model_name_or_path:
logger.info(f"Skipping component `{name}`: no pretrained model path specified.")
continue
try:
components_to_register[name] = spec.load(**component_load_kwargs)
except Exception:
logger.warning(
f"\nFailed to create component {name}:\n"
f"- Component spec: {spec}\n"
f"- load() called with kwargs: {component_load_kwargs}\n"
"If this component is not required for your workflow you can safely ignore this message.\n\n"
"Traceback:\n"
f"{traceback.format_exc()}"
)
tb = traceback.format_exc()
if trust_remote_code_stripped and "trust_remote_code" in tb:
warning_msg = (
f"Failed to load component `{name}` from external repository "
f"`{spec.pretrained_model_name_or_path}`.\n\n"
f"`trust_remote_code=True` was not forwarded to `{name}` because it comes from "
f"a different repository than the pipeline (`{self._pretrained_model_name_or_path}`). "
f"For safety, `trust_remote_code` is only forwarded to components from the same "
f"repository as the pipeline.\n\n"
f"You need to load this component manually with `trust_remote_code=True` and pass it "
f"to the pipeline via `pipe.update_components()`. For example, if it is a custom model:\n\n"
f' {name} = AutoModel.from_pretrained("{spec.pretrained_model_name_or_path}", trust_remote_code=True)\n'
f" pipe.update_components({name}={name})\n"
)
else:
warning_msg = (
f"Failed to create component {name}:\n"
f"- Component spec: {spec}\n"
f"- load() called with kwargs: {component_load_kwargs}\n"
"If this component is not required for your workflow you can safely ignore this message.\n\n"
"Traceback:\n"
f"{tb}"
)
logger.warning(warning_msg)
# Register all components at once
self.register_components(**components_to_register)

View File

@@ -22,10 +22,12 @@ from typing import Any, Literal, Type, Union, get_args, get_origin
import PIL.Image
import torch
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from ..configuration_utils import ConfigMixin, FrozenDict
from ..loaders.single_file_utils import _is_single_file_path_or_url
from ..utils import DIFFUSERS_LOAD_ID_FIELDS, is_torch_available, logging
from ..utils.import_utils import _is_package_available
if is_torch_available():
@@ -50,11 +52,7 @@ This modular pipeline is composed of the following blocks:
{components_description} {configs_section}
## Input/Output Specification
### Inputs {inputs_description}
### Outputs {outputs_description}
{io_specification_section}
"""
@@ -311,6 +309,12 @@ class ComponentSpec:
f"`type_hint` is required when loading a single file model but is missing for component: {self.name}"
)
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if self.type_hint is not None and not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)
if self.type_hint is None:
try:
from diffusers import AutoModel
@@ -328,6 +332,12 @@ class ComponentSpec:
else getattr(self.type_hint, "from_pretrained")
)
# `torch_dtype` is not an accepted parameter for tokenizers and processors.
# As a result, it gets stored in `init_kwargs`, which are written to the config
# during save. This causes JSON serialization to fail when saving the component.
if not issubclass(self.type_hint, torch.nn.Module):
kwargs.pop("torch_dtype", None)
try:
component = load_method(pretrained_model_name_or_path, **load_kwargs, **kwargs)
except Exception as e:
@@ -799,6 +809,46 @@ def format_output_params(output_params, indent_level=4, max_line_length=115):
return format_params(output_params, "Outputs", indent_level, max_line_length)
def format_params_markdown(params, header="Inputs"):
"""Format a list of InputParam or OutputParam objects as a markdown bullet-point list.
Suitable for model cards rendered on Hugging Face Hub.
Args:
params: list of InputParam or OutputParam objects to format
header: Header text (e.g. "Inputs" or "Outputs")
Returns:
A formatted markdown string, or empty string if params is empty.
"""
if not params:
return ""
def get_type_str(type_hint):
if isinstance(type_hint, UnionType) or get_origin(type_hint) is Union:
type_strs = [t.__name__ if hasattr(t, "__name__") else str(t) for t in get_args(type_hint)]
return " | ".join(type_strs)
return type_hint.__name__ if hasattr(type_hint, "__name__") else str(type_hint)
lines = [f"**{header}:**\n"] if header else []
for param in params:
type_str = get_type_str(param.type_hint) if param.type_hint != Any else ""
name = f"**{param.kwargs_type}" if param.name is None and param.kwargs_type is not None else param.name
param_str = f"- `{name}` (`{type_str}`"
if hasattr(param, "required") and not param.required:
param_str += ", *optional*"
if param.default is not None:
param_str += f", defaults to `{param.default}`"
param_str += ")"
desc = param.description if param.description else "No description provided"
param_str += f": {desc}"
lines.append(param_str)
return "\n".join(lines)
def format_components(components, indent_level=4, max_line_length=115, add_empty_lines=True):
"""Format a list of ComponentSpec objects into a readable string representation.
@@ -972,6 +1022,89 @@ def make_doc_string(
return output
def _validate_requirements(reqs):
    """Validate a requirements mapping against the current environment.

    Normalizes ``reqs`` (package name -> version specifier) via
    ``_normalize_requirements``, then checks each entry against the installed
    environment. Missing packages and unsatisfied version constraints only emit
    warnings (never errors), so an incomplete environment does not block loading.

    Args:
        reqs: dict mapping package names to version specifiers, or None.

    Returns:
        dict of normalized package -> specifier entries ("" when no version pinned).

    Raises:
        ValueError: if ``reqs`` is not a dict, or a specifier string is invalid.
    """
    if reqs is None:
        normalized_reqs = {}
    else:
        if not isinstance(reqs, dict):
            raise ValueError(
                "Requirements must be provided as a dictionary mapping package names to version specifiers."
            )
        normalized_reqs = _normalize_requirements(reqs)
    if not normalized_reqs:
        return {}

    final: dict[str, str] = {}
    for req, specified_ver in normalized_reqs.items():
        req_available, req_actual_ver = _is_package_available(req)
        if not req_available:
            # Warn only: the component may still load without this optional package.
            logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.")
        if specified_ver:
            try:
                specifier = SpecifierSet(specified_ver)
            except InvalidSpecifier as err:
                raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err
            # "N/A" means _is_package_available could not determine a version to compare.
            if req_actual_ver == "N/A":
                logger.warning(
                    f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected."
                )
            elif not specifier.contains(req_actual_ver, prereleases=True):
                logger.warning(
                    f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected."
                )
        # Entries are kept even when the package is missing or mismatched — validation is advisory.
        final[req] = specified_ver
    return final
def _normalize_requirements(reqs):
    """Flatten and normalize a (possibly nested) requirements mapping.

    Nested dicts — requirements merged from composed blocks — are flattened
    recursively. Bare versions like ``"1.2.3"`` are rewritten to ``"==1.2.3"``;
    duplicate packages have their specifiers combined when compatible.

    Args:
        reqs: mapping of package name -> specifier string (or None, or a nested dict
            of further requirements).

    Returns:
        OrderedDict mapping package name -> normalized specifier string ("" if none).

    Raises:
        ValueError: if a package name is empty after stripping.
    """
    if not reqs:
        return {}
    normalized: "OrderedDict[str, str]" = OrderedDict()

    def _accumulate(mapping: dict[str, Any]):
        for pkg, spec in mapping.items():
            if isinstance(spec, dict):
                # This is recursive because blocks are composable. This way, we can merge requirements
                # from multiple blocks.
                _accumulate(spec)
                continue
            pkg_name = str(pkg).strip()
            if not pkg_name:
                raise ValueError("Requirement package name cannot be empty.")
            spec_str = "" if spec is None else str(spec).strip()
            # A bare version ("1.2.3") is shorthand for an exact pin ("==1.2.3");
            # strings already starting with an operator are kept as-is.
            if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")):
                spec_str = f"=={spec_str}"
            existing_spec = normalized.get(pkg_name)
            if existing_spec is not None:
                # Duplicate package: prefer a concrete specifier over "", and try to
                # merge two different specifiers; on merge failure keep the first seen.
                if not existing_spec and spec_str:
                    normalized[pkg_name] = spec_str
                elif existing_spec and spec_str and existing_spec != spec_str:
                    try:
                        combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str])))
                    except InvalidSpecifier:
                        logger.warning(
                            f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'."
                        )
                    else:
                        normalized[pkg_name] = str(combined_spec)
                continue
            normalized[pkg_name] = spec_str

    _accumulate(reqs)
    return normalized
def combine_inputs(*named_input_lists: list[tuple[str, list[InputParam]]]) -> list[InputParam]:
"""
Combines multiple lists of InputParam objects from different blocks. For duplicate inputs, updates only if current
@@ -1055,8 +1188,7 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
- blocks_description: Detailed architecture of blocks
- components_description: List of required components
- configs_section: Configuration parameters section
- inputs_description: Input parameters specification
- outputs_description: Output parameters specification
- io_specification_section: Input/Output specification (per-workflow or unified)
- trigger_inputs_section: Conditional execution information
- tags: List of relevant tags for the model card
"""
@@ -1075,15 +1207,6 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if block_desc:
blocks_desc_parts.append(f" - {block_desc}")
# add sub-blocks if any
if hasattr(block, "sub_blocks") and block.sub_blocks:
for sub_name, sub_block in block.sub_blocks.items():
sub_class = sub_block.__class__.__name__
sub_desc = sub_block.description.split("\n")[0] if getattr(sub_block, "description", "") else ""
blocks_desc_parts.append(f" - *{sub_name}*: `{sub_class}`")
if sub_desc:
blocks_desc_parts.append(f" - {sub_desc}")
blocks_description = "\n".join(blocks_desc_parts) if blocks_desc_parts else "No blocks defined."
components = getattr(blocks, "expected_components", [])
@@ -1109,63 +1232,76 @@ def generate_modular_model_card_content(blocks) -> dict[str, Any]:
if configs_description:
configs_section = f"\n\n## Configuration Parameters\n\n{configs_description}"
inputs = blocks.inputs
outputs = blocks.outputs
# Branch on whether workflows are defined
has_workflows = getattr(blocks, "_workflow_map", None) is not None
# format inputs as markdown list
inputs_parts = []
required_inputs = [inp for inp in inputs if inp.required]
optional_inputs = [inp for inp in inputs if not inp.required]
if has_workflows:
workflow_map = blocks._workflow_map
parts = []
if required_inputs:
inputs_parts.append("**Required:**\n")
for inp in required_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
inputs_parts.append(f"- `{inp.name}` (`{type_str}`): {desc}")
# If blocks overrides outputs (e.g. to return just "images" instead of all intermediates),
# use that as the shared output for all workflows
blocks_outputs = blocks.outputs
blocks_intermediate = getattr(blocks, "intermediate_outputs", None)
shared_outputs = (
blocks_outputs if blocks_intermediate is not None and blocks_outputs != blocks_intermediate else None
)
if optional_inputs:
if required_inputs:
inputs_parts.append("")
inputs_parts.append("**Optional:**\n")
for inp in optional_inputs:
if hasattr(inp.type_hint, "__name__"):
type_str = inp.type_hint.__name__
elif inp.type_hint is not None:
type_str = str(inp.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = inp.description or "No description provided"
default_str = f", default: `{inp.default}`" if inp.default is not None else ""
inputs_parts.append(f"- `{inp.name}` (`{type_str}`){default_str}: {desc}")
parts.append("## Workflow Input Specification\n")
inputs_description = "\n".join(inputs_parts) if inputs_parts else "No specific inputs defined."
# Per-workflow details: show trigger inputs with full param descriptions
for wf_name, trigger_inputs in workflow_map.items():
trigger_input_names = set(trigger_inputs.keys())
try:
workflow_blocks = blocks.get_workflow(wf_name)
except Exception:
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
parts.append("*Could not resolve workflow blocks.*\n")
parts.append("</details>\n")
continue
# format outputs as markdown list
outputs_parts = []
for out in outputs:
if hasattr(out.type_hint, "__name__"):
type_str = out.type_hint.__name__
elif out.type_hint is not None:
type_str = str(out.type_hint).replace("typing.", "")
else:
type_str = "Any"
desc = out.description or "No description provided"
outputs_parts.append(f"- `{out.name}` (`{type_str}`): {desc}")
wf_inputs = workflow_blocks.inputs
# Show only trigger inputs with full parameter descriptions
trigger_params = [p for p in wf_inputs if p.name in trigger_input_names]
outputs_description = "\n".join(outputs_parts) if outputs_parts else "Standard pipeline outputs."
parts.append(f"<details>\n<summary><strong>{wf_name}</strong></summary>\n")
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
inputs_str = format_params_markdown(trigger_params, header=None)
parts.append(inputs_str if inputs_str else "No additional inputs required.")
parts.append("")
parts.append("</details>\n")
# Common Inputs & Outputs section (like non-workflow pipelines)
all_inputs = blocks.inputs
all_outputs = shared_outputs if shared_outputs is not None else blocks.outputs
inputs_str = format_params_markdown(all_inputs, "Inputs")
outputs_str = format_params_markdown(all_outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
parts.append(f"\n## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}")
io_specification_section = "\n".join(parts)
# Suppress trigger_inputs_section when workflows are shown (it's redundant)
trigger_inputs_section = ""
else:
# Unified I/O section (original behavior)
inputs = blocks.inputs
outputs = blocks.outputs
inputs_str = format_params_markdown(inputs, "Inputs")
outputs_str = format_params_markdown(outputs, "Outputs")
inputs_description = inputs_str if inputs_str else "No specific inputs defined."
outputs_description = outputs_str if outputs_str else "Standard pipeline outputs."
io_specification_section = f"## Input/Output Specification\n\n{inputs_description}\n\n{outputs_description}"
trigger_inputs_section = ""
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
trigger_inputs_list = sorted([t for t in blocks.trigger_inputs if t is not None])
if trigger_inputs_list:
trigger_inputs_str = ", ".join(f"`{t}`" for t in trigger_inputs_list)
trigger_inputs_section = f"""
### Conditional Execution
This pipeline contains blocks that are selected at runtime based on inputs:
@@ -1178,7 +1314,18 @@ This pipeline contains blocks that are selected at runtime based on inputs:
if hasattr(blocks, "model_name") and blocks.model_name:
tags.append(blocks.model_name)
if hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
if has_workflows:
# Derive tags from workflow names
workflow_names = set(blocks._workflow_map.keys())
if any("inpainting" in wf for wf in workflow_names):
tags.append("inpainting")
if any("image2image" in wf for wf in workflow_names):
tags.append("image-to-image")
if any("controlnet" in wf for wf in workflow_names):
tags.append("controlnet")
if any("text2image" in wf for wf in workflow_names):
tags.append("text-to-image")
elif hasattr(blocks, "trigger_inputs") and blocks.trigger_inputs:
triggers = blocks.trigger_inputs
if any(t in triggers for t in ["mask", "mask_image"]):
tags.append("inpainting")
@@ -1206,8 +1353,7 @@ This pipeline uses a {block_count}-block architecture that can be customized and
"blocks_description": blocks_description,
"components_description": components_description,
"configs_section": configs_section,
"inputs_description": inputs_description,
"outputs_description": outputs_description,
"io_specification_section": io_specification_section,
"trigger_inputs_section": trigger_inputs_section,
"tags": tags,
}

View File

@@ -237,6 +237,7 @@ else:
"EasyAnimateInpaintPipeline",
"EasyAnimateControlPipeline",
]
_import_structure["helios"] = ["HeliosPipeline", "HeliosPyramidPipeline"]
_import_structure["hidream_image"] = ["HiDreamImagePipeline"]
_import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
_import_structure["hunyuan_video"] = [
@@ -667,6 +668,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hidream_image import HiDreamImagePipeline
from .hunyuan_image import HunyuanImagePipeline, HunyuanImageRefinerPipeline
from .hunyuan_video import (

View File

@@ -54,6 +54,7 @@ from .flux import (
)
from .flux2 import Flux2KleinPipeline, Flux2Pipeline
from .glm_image import GlmImagePipeline
from .helios import HeliosPipeline, HeliosPyramidPipeline
from .hunyuandit import HunyuanDiTPipeline
from .kandinsky import (
KandinskyCombinedPipeline,
@@ -174,6 +175,8 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
("cogview3", CogView3PlusPipeline),
("cogview4", CogView4Pipeline),
("glm_image", GlmImagePipeline),
("helios", HeliosPipeline),
("helios-pyramid", HeliosPyramidPipeline),
("cogview4-control", CogView4ControlPipeline),
("qwenimage", QwenImagePipeline),
("qwenimage-controlnet", QwenImageControlNetPipeline),

View File

@@ -0,0 +1,48 @@
# Lazy-import machinery for the `diffusers.pipelines.helios` subpackage: pipelines
# are only imported on first attribute access unless type-checking or slow imports
# are requested, and dummy placeholders are installed when optional deps are missing.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)

# Placeholder objects registered when torch/transformers are unavailable.
_dummy_objects = {}
# Maps submodule name -> list of public names, consumed by _LazyModule below.
_import_structure = {}

try:
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_helios"] = ["HeliosPipeline"]
    _import_structure["pipeline_helios_pyramid"] = ["HeliosPyramidPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: real imports for type checkers or when slow imports are forced.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *
    else:
        from .pipeline_helios import HeliosPipeline
        from .pipeline_helios_pyramid import HeliosPyramidPipeline

else:
    import sys

    # Replace this module object with a lazy proxy that imports submodules on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,916 @@
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
from typing import Any, Callable
import numpy as np
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...loaders import HeliosLoraLoaderMixin
from ...models import AutoencoderKLWan, HeliosTransformer3DModel
from ...schedulers import HeliosScheduler
from ...utils import is_ftfy_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HeliosPipelineOutput
if is_torch_xla_available():
import torch_xla.core.xla_model as xm
XLA_AVAILABLE = True
else:
XLA_AVAILABLE = False
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
if is_ftfy_available():
import ftfy
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from diffusers.utils import export_to_video
>>> from diffusers import AutoencoderKLWan, HeliosPipeline
>>> # Available models: BestWishYsh/Helios-Base, BestWishYsh/Helios-Mid, BestWishYsh/Helios-Distilled
>>> model_id = "BestWishYsh/Helios-Base"
>>> vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
>>> pipe = HeliosPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
>>> pipe.to("cuda")
>>> prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
>>> negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
>>> output = pipe(
... prompt=prompt,
... negative_prompt=negative_prompt,
... height=384,
... width=640,
... num_frames=132,
... guidance_scale=5.0,
... ).frames[0]
>>> export_to_video(output, "output.mp4", fps=24)
```
"""
def basic_clean(text):
    """Fix mojibake with ftfy, undo (possibly double) HTML entity escaping, and strip."""
    fixed = ftfy.fix_text(text)
    # Unescape twice to also resolve doubly-escaped entities like "&amp;amp;".
    return html.unescape(html.unescape(fixed)).strip()
def whitespace_clean(text):
    """Collapse every run of whitespace to a single space and trim the ends."""
    return re.sub(r"\s+", " ", text).strip()
def prompt_clean(text):
    """Run the full prompt-cleaning chain: ftfy/HTML fixes, then whitespace normalization."""
    return whitespace_clean(basic_clean(text))
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the schedule-shift value `mu` from the token sequence length.

    Returns `base_shift` at `base_seq_len` and `max_shift` at `max_seq_len`,
    extrapolating linearly outside that range.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
class HeliosPipeline(DiffusionPipeline, HeliosLoraLoaderMixin):
r"""
Pipeline for text-to-video / image-to-video / video-to-video generation using Helios.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
implemented for all pipelines (downloading, saving, running on a particular device, etc.).
Args:
tokenizer ([`T5Tokenizer`]):
Tokenizer from [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5Tokenizer),
specifically the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
text_encoder ([`T5EncoderModel`]):
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
the [google/umt5-xxl](https://huggingface.co/google/umt5-xxl) variant.
transformer ([`HeliosTransformer3DModel`]):
Conditional Transformer to denoise the input latents.
scheduler ([`HeliosScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
vae ([`AutoencoderKLWan`]):
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
_optional_components = ["transformer"]
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: UMT5EncoderModel,
        vae: AutoencoderKLWan,
        scheduler: HeliosScheduler,
        transformer: HeliosTransformer3DModel,
    ):
        """Register pipeline components and derive VAE-dependent scale factors."""
        super().__init__()

        # register_modules must run first: it sets `self.vae` etc., which the
        # scale-factor lookups below read.
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
        )
        # Fall back to 4x temporal / 8x spatial compression when no VAE was
        # registered (e.g. partial pipelines; `transformer` is an optional component).
        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
    def _get_t5_prompt_embeds(
        self,
        prompt: str | list[str] = None,
        num_videos_per_prompt: int = 1,
        max_sequence_length: int = 226,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        """Tokenize and encode `prompt` with the UMT5 text encoder.

        Embeddings past each prompt's true length are replaced with zeros (rather
        than left as padding-token embeddings), then every prompt is padded back to
        `max_sequence_length`. The batch is repeated `num_videos_per_prompt` times.

        Returns:
            Tuple of (prompt_embeds of shape
            (batch * num_videos_per_prompt, max_sequence_length, hidden_dim),
            boolean attention mask from the tokenizer).
        """
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(u) for u in prompt]
        batch_size = len(prompt)

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        # Actual (unpadded) token count of each prompt in the batch.
        seq_lens = mask.gt(0).sum(dim=1).long()

        prompt_embeds = self.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        # Trim each sequence to its true length, then zero-pad back to max_sequence_length.
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
        )

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds, text_inputs.attention_mask.bool()
    def encode_prompt(
        self,
        prompt: str | list[str],
        negative_prompt: str | list[str] | None = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        max_sequence_length: int = 226,
        device: torch.device | None = None,
        dtype: torch.dtype | None = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                prompt to be encoded
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
                Whether to use classifier free guidance or not.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos that should be generated per prompt. torch device to place the resulting embeddings on
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            device: (`torch.device`, *optional*):
                torch device
            dtype: (`torch.dtype`, *optional*):
                torch dtype
        """
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        # Batch size comes from the prompt list, or from precomputed embeddings when
        # no prompt text was given.
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds, _ = self._get_t5_prompt_embeds(
                prompt=prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        # Negative embeddings are only needed (and computed) when CFG is enabled.
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            # Broadcast a single negative prompt string across the whole batch.
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds, _ = self._get_t5_prompt_embeds(
                prompt=negative_prompt,
                num_videos_per_prompt=num_videos_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
                dtype=dtype,
            )

        return prompt_embeds, negative_prompt_embeds
def check_inputs(
self,
prompt,
negative_prompt,
height,
width,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
image=None,
video=None,
):
if height % 16 != 0 or width % 16 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`: {negative_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif negative_prompt is not None and (
not isinstance(negative_prompt, str) and not isinstance(negative_prompt, list)
):
raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")
if image is not None and video is not None:
raise ValueError("image and video cannot be provided simultaneously")
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int = 16,
height: int = 384,
width: int = 640,
num_frames: int = 33,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
) -> torch.Tensor:
if latents is not None:
return latents.to(device=device, dtype=dtype)
num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
shape = (
batch_size,
num_channels_latents,
num_latent_frames,
int(height) // self.vae_scale_factor_spatial,
int(width) // self.vae_scale_factor_spatial,
)
if isinstance(generator, list) and len(generator) != batch_size:
raise ValueError(
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
return latents
def prepare_image_latents(
self,
image: torch.Tensor,
latents_mean: torch.Tensor,
latents_std: torch.Tensor,
num_latent_frames_per_chunk: int,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
fake_latents: torch.Tensor | None = None,
) -> torch.Tensor:
device = device or self._execution_device
if latents is None:
image = image.unsqueeze(2).to(device=device, dtype=self.vae.dtype)
latents = self.vae.encode(image).latent_dist.sample(generator=generator)
latents = (latents - latents_mean) * latents_std
if fake_latents is None:
min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
fake_video = image.repeat(1, 1, min_frames, 1, 1).to(device=device, dtype=self.vae.dtype)
fake_latents_full = self.vae.encode(fake_video).latent_dist.sample(generator=generator)
fake_latents_full = (fake_latents_full - latents_mean) * latents_std
fake_latents = fake_latents_full[:, :, -1:, :, :]
return latents.to(device=device, dtype=dtype), fake_latents.to(device=device, dtype=dtype)
def prepare_video_latents(
self,
video: torch.Tensor,
latents_mean: torch.Tensor,
latents_std: torch.Tensor,
num_latent_frames_per_chunk: int,
dtype: torch.dtype | None = None,
device: torch.device | None = None,
generator: torch.Generator | list[torch.Generator] | None = None,
latents: torch.Tensor | None = None,
) -> torch.Tensor:
device = device or self._execution_device
video = video.to(device=device, dtype=self.vae.dtype)
if latents is None:
num_frames = video.shape[2]
min_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
num_chunks = num_frames // min_frames
if num_chunks == 0:
raise ValueError(
f"Video must have at least {min_frames} frames "
f"(got {num_frames} frames). "
f"Required: (num_latent_frames_per_chunk - 1) * {self.vae_scale_factor_temporal} + 1 = ({num_latent_frames_per_chunk} - 1) * {self.vae_scale_factor_temporal} + 1 = {min_frames}"
)
total_valid_frames = num_chunks * min_frames
start_frame = num_frames - total_valid_frames
first_frame = video[:, :, 0:1, :, :]
first_frame_latent = self.vae.encode(first_frame).latent_dist.sample(generator=generator)
first_frame_latent = (first_frame_latent - latents_mean) * latents_std
latents_chunks = []
for i in range(num_chunks):
chunk_start = start_frame + i * min_frames
chunk_end = chunk_start + min_frames
video_chunk = video[:, :, chunk_start:chunk_end, :, :]
chunk_latents = self.vae.encode(video_chunk).latent_dist.sample(generator=generator)
chunk_latents = (chunk_latents - latents_mean) * latents_std
latents_chunks.append(chunk_latents)
latents = torch.cat(latents_chunks, dim=2)
return first_frame_latent.to(device=device, dtype=dtype), latents.to(device=device, dtype=dtype)
    @property
    def guidance_scale(self):
        # Classifier-free guidance scale captured at call time.
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        # CFG is active only for scales strictly greater than 1.
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        # Number of denoising steps for the current call.
        return self._num_timesteps

    @property
    def current_timestep(self):
        # Timestep currently being processed in the denoising loop.
        return self._current_timestep

    @property
    def interrupt(self):
        # Flag that callbacks can use to stop the denoising loop early.
        return self._interrupt

    @property
    def attention_kwargs(self):
        # Extra kwargs forwarded to the attention processors.
        return self._attention_kwargs
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: str | list[str] = None,
        negative_prompt: str | list[str] = None,
        height: int = 384,
        width: int = 640,
        num_frames: int = 132,
        num_inference_steps: int = 50,
        sigmas: list[float] = None,
        guidance_scale: float = 5.0,
        num_videos_per_prompt: int | None = 1,
        generator: torch.Generator | list[torch.Generator] | None = None,
        latents: torch.Tensor | None = None,
        prompt_embeds: torch.Tensor | None = None,
        negative_prompt_embeds: torch.Tensor | None = None,
        output_type: str | None = "np",
        return_dict: bool = True,
        attention_kwargs: dict[str, Any] | None = None,
        callback_on_step_end: Callable[[int, int], None] | PipelineCallback | MultiPipelineCallbacks | None = None,
        callback_on_step_end_tensor_inputs: list[str] = ["latents"],
        max_sequence_length: int = 512,
        # ------------ I2V ------------
        image: PipelineImageInput | None = None,
        image_latents: torch.Tensor | None = None,
        fake_image_latents: torch.Tensor | None = None,
        add_noise_to_image_latents: bool = True,
        image_noise_sigma_min: float = 0.111,
        image_noise_sigma_max: float = 0.135,
        # ------------ V2V ------------
        video: PipelineImageInput | None = None,
        video_latents: torch.Tensor | None = None,
        add_noise_to_video_latents: bool = True,
        video_noise_sigma_min: float = 0.111,
        video_noise_sigma_max: float = 0.135,
        # ------------ Stage 1 ------------
        history_sizes: list = [16, 2, 1],
        num_latent_frames_per_chunk: int = 9,
        keep_first_frame: bool = True,
        is_skip_first_chunk: bool = False,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, pass `prompt_embeds` instead.
            negative_prompt (`str` or `list[str]`, *optional*):
                The prompt or prompts to avoid during image generation. If not defined, pass `negative_prompt_embeds`
                instead. Ignored when not using guidance (`guidance_scale` < `1`).
            height (`int`, defaults to `384`):
                The height in pixels of the generated image.
            width (`int`, defaults to `640`):
                The width in pixels of the generated image.
            num_frames (`int`, defaults to `132`):
                The number of frames in the generated video.
            num_inference_steps (`int`, defaults to `50`):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            sigmas (`list[float]`, *optional*):
                Custom sigma schedule; when `None`, a linearly spaced schedule from 0.999 to 0 is used.
            guidance_scale (`float`, defaults to `5.0`):
                Guidance scale as defined in [Classifier-Free Diffusion
                Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
                of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                the text `prompt`, usually at the expense of lower image quality.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            generator (`torch.Generator` or `list[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
                provided, text embeddings are generated from the `prompt` input argument.
            output_type (`str`, *optional*, defaults to `"np"`):
                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`HeliosPipelineOutput`] instead of a plain tuple.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
                A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
                each denoising step during the inference. with the following arguments: `callback_on_step_end(self:
                DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
                list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`list`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            max_sequence_length (`int`, defaults to `512`):
                The maximum sequence length of the text encoder. If the prompt is longer than this, it will be
                truncated. If the prompt is shorter, it will be padded to this length.
            image (`PipelineImageInput`, *optional*):
                Conditioning image for image-to-video generation.
            image_latents / fake_image_latents (`torch.Tensor`, *optional*):
                Pre-encoded latents for the conditioning image; computed from `image` when not supplied.
            add_noise_to_image_latents (`bool`, defaults to `True`):
                Whether to perturb the image latents with noise drawn from
                [`image_noise_sigma_min`, `image_noise_sigma_max`].
            video (`PipelineImageInput`, *optional*):
                Conditioning video for video-to-video / continuation generation.
            video_latents (`torch.Tensor`, *optional*):
                Pre-encoded latents for the conditioning video; computed from `video` when not supplied.
            add_noise_to_video_latents (`bool`, defaults to `True`):
                Whether to perturb the video latents chunk-wise with per-frame noise drawn from
                [`video_noise_sigma_min`, `video_noise_sigma_max`].
            history_sizes (`list`, defaults to `[16, 2, 1]`):
                Sizes (in latent frames) of the long/mid/short history windows fed to the transformer; sorted
                descending internally.
            num_latent_frames_per_chunk (`int`, defaults to `9`):
                Number of latent frames generated per chunk of the rolling generation loop.
            keep_first_frame (`bool`, defaults to `True`):
                Whether the first frame latent is kept as a fixed prefix for every chunk.
            is_skip_first_chunk (`bool`, defaults to `False`):
                When `True`, the anchor frame is re-captured from the second chunk instead of the first.

        Examples:

        Returns:
            [`~HeliosPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`HeliosPipelineOutput`] is returned, otherwise a `tuple` is returned
                where the first element is the generated video (frames).
        """
        history_sizes = sorted(history_sizes, reverse=True)  # From big to small
        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            negative_prompt,
            height,
            width,
            prompt_embeds,
            negative_prompt_embeds,
            callback_on_step_end_tensor_inputs,
            image,
            video,
        )
        num_frames = max(num_frames, 1)
        self._guidance_scale = guidance_scale
        self._attention_kwargs = attention_kwargs
        self._current_timestep = None
        self._interrupt = False
        device = self._execution_device
        vae_dtype = self.vae.dtype
        # Normalization statistics for the VAE latent space; note `latents_std` is the
        # reciprocal, so normalization is `(x - mean) * latents_std` and denormalization
        # is `x / latents_std + mean`.
        latents_mean = (
            torch.tensor(self.vae.config.latents_mean)
            .view(1, self.vae.config.z_dim, 1, 1, 1)
            .to(device, self.vae.dtype)
        )
        latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
            device, self.vae.dtype
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            max_sequence_length=max_sequence_length,
            device=device,
        )
        transformer_dtype = self.transformer.dtype
        prompt_embeds = prompt_embeds.to(transformer_dtype)
        if negative_prompt_embeds is not None:
            negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

        # 4. Prepare image or video
        if image is not None:
            image = self.video_processor.preprocess(image, height=height, width=width)
            image_latents, fake_image_latents = self.prepare_image_latents(
                image,
                latents_mean=latents_mean,
                latents_std=latents_std,
                num_latent_frames_per_chunk=num_latent_frames_per_chunk,
                dtype=torch.float32,
                device=device,
                generator=generator,
                latents=image_latents,
                fake_latents=fake_image_latents,
            )
            if image_latents is not None and add_noise_to_image_latents:
                # Lightly renoise the image latents so they match the distribution the
                # transformer saw during training.
                image_noise_sigma = (
                    torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
                    + image_noise_sigma_min
                )
                image_latents = (
                    image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
                    + (1 - image_noise_sigma) * image_latents
                )
                # NOTE(review): the fake image latents are renoised with the *video* sigma
                # range (video_noise_sigma_min/max), not the image one — confirm intentional.
                fake_image_noise_sigma = (
                    torch.rand(1, device=device, generator=generator) * (video_noise_sigma_max - video_noise_sigma_min)
                    + video_noise_sigma_min
                )
                fake_image_latents = (
                    fake_image_noise_sigma * randn_tensor(fake_image_latents.shape, generator=generator, device=device)
                    + (1 - fake_image_noise_sigma) * fake_image_latents
                )
        if video is not None:
            video = self.video_processor.preprocess_video(video, height=height, width=width)
            image_latents, video_latents = self.prepare_video_latents(
                video,
                latents_mean=latents_mean,
                latents_std=latents_std,
                num_latent_frames_per_chunk=num_latent_frames_per_chunk,
                dtype=torch.float32,
                device=device,
                generator=generator,
                latents=video_latents,
            )
            if video_latents is not None and add_noise_to_video_latents:
                image_noise_sigma = (
                    torch.rand(1, device=device, generator=generator) * (image_noise_sigma_max - image_noise_sigma_min)
                    + image_noise_sigma_min
                )
                image_latents = (
                    image_noise_sigma * randn_tensor(image_latents.shape, generator=generator, device=device)
                    + (1 - image_noise_sigma) * image_latents
                )
                # Renoise the conditioning video chunk by chunk with independent per-frame sigmas.
                noisy_latents_chunks = []
                num_latent_chunks = video_latents.shape[2] // num_latent_frames_per_chunk
                for i in range(num_latent_chunks):
                    chunk_start = i * num_latent_frames_per_chunk
                    chunk_end = chunk_start + num_latent_frames_per_chunk
                    latent_chunk = video_latents[:, :, chunk_start:chunk_end, :, :]
                    chunk_frames = latent_chunk.shape[2]
                    frame_sigmas = (
                        torch.rand(chunk_frames, device=device, generator=generator)
                        * (video_noise_sigma_max - video_noise_sigma_min)
                        + video_noise_sigma_min
                    )
                    frame_sigmas = frame_sigmas.view(1, 1, chunk_frames, 1, 1)
                    noisy_chunk = (
                        frame_sigmas * randn_tensor(latent_chunk.shape, generator=generator, device=device)
                        + (1 - frame_sigmas) * latent_chunk
                    )
                    noisy_latents_chunks.append(noisy_chunk)
                video_latents = torch.cat(noisy_latents_chunks, dim=2)

        # 5. Prepare latent variables
        num_channels_latents = self.transformer.config.in_channels
        window_num_frames = (num_latent_frames_per_chunk - 1) * self.vae_scale_factor_temporal + 1
        num_latent_chunk = max(1, (num_frames + window_num_frames - 1) // window_num_frames)
        num_history_latent_frames = sum(history_sizes)
        history_video = None
        total_generated_latent_frames = 0
        if not keep_first_frame:
            # NOTE(review): `num_history_latent_frames` is computed *before* this in-place
            # increment, so the history buffer below is one frame shorter than
            # sum(history_sizes) used by the splits in the chunk loop — verify this path.
            history_sizes[-1] = history_sizes[-1] + 1
        # Zero-initialized rolling history buffer of latent frames.
        history_latents = torch.zeros(
            batch_size,
            num_channels_latents,
            num_history_latent_frames,
            height // self.vae_scale_factor_spatial,
            width // self.vae_scale_factor_spatial,
            device=device,
            dtype=torch.float32,
        )
        if fake_image_latents is not None:
            history_latents = torch.cat([history_latents[:, :, :-1, :, :], fake_image_latents], dim=2)
            total_generated_latent_frames += 1
        if video_latents is not None:
            history_frames = history_latents.shape[2]
            video_frames = video_latents.shape[2]
            if video_frames < history_frames:
                keep_frames = history_frames - video_frames
                history_latents = torch.cat([history_latents[:, :, :keep_frames, :, :], video_latents], dim=2)
            else:
                history_latents = video_latents
            total_generated_latent_frames += video_latents.shape[2]
        # Positional index buckets for the transformer: [prefix?] + long + mid + short + current chunk.
        if keep_first_frame:
            indices = torch.arange(0, sum([1, *history_sizes, num_latent_frames_per_chunk]))
            (
                indices_prefix,
                indices_latents_history_long,
                indices_latents_history_mid,
                indices_latents_history_1x,
                indices_hidden_states,
            ) = indices.split([1, *history_sizes, num_latent_frames_per_chunk], dim=0)
            indices_latents_history_short = torch.cat([indices_prefix, indices_latents_history_1x], dim=0)
        else:
            indices = torch.arange(0, sum([*history_sizes, num_latent_frames_per_chunk]))
            (
                indices_latents_history_long,
                indices_latents_history_mid,
                indices_latents_history_short,
                indices_hidden_states,
            ) = indices.split([*history_sizes, num_latent_frames_per_chunk], dim=0)
        indices_hidden_states = indices_hidden_states.unsqueeze(0)
        indices_latents_history_short = indices_latents_history_short.unsqueeze(0)
        indices_latents_history_mid = indices_latents_history_mid.unsqueeze(0)
        indices_latents_history_long = indices_latents_history_long.unsqueeze(0)

        # 6. Denoising loop
        patch_size = self.transformer.config.patch_size
        image_seq_len = (
            num_latent_frames_per_chunk
            * (height // self.vae_scale_factor_spatial)
            * (width // self.vae_scale_factor_spatial)
            // (patch_size[0] * patch_size[1] * patch_size[2])
        )
        sigmas = np.linspace(0.999, 0.0, num_inference_steps + 1)[:-1] if sigmas is None else sigmas
        mu = calculate_shift(
            image_seq_len,
            self.scheduler.config.get("base_image_seq_len", 256),
            self.scheduler.config.get("max_image_seq_len", 4096),
            self.scheduler.config.get("base_shift", 0.5),
            self.scheduler.config.get("max_shift", 1.15),
        )
        # Generate the video chunk by chunk, conditioning each chunk on the rolling history.
        for k in range(num_latent_chunk):
            is_first_chunk = k == 0
            is_second_chunk = k == 1
            if keep_first_frame:
                latents_history_long, latents_history_mid, latents_history_1x = history_latents[
                    :, :, -num_history_latent_frames:
                ].split(history_sizes, dim=2)
                if image_latents is None and is_first_chunk:
                    # No anchor frame yet (pure T2V first chunk): use a zero prefix.
                    latents_prefix = torch.zeros(
                        (
                            batch_size,
                            num_channels_latents,
                            1,
                            latents_history_1x.shape[-2],
                            latents_history_1x.shape[-1],
                        ),
                        device=device,
                        dtype=latents_history_1x.dtype,
                    )
                else:
                    latents_prefix = image_latents
                latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2)
            else:
                latents_history_long, latents_history_mid, latents_history_short = history_latents[
                    :, :, -num_history_latent_frames:
                ].split(history_sizes, dim=2)
            # Fresh noise for the current chunk.
            latents = self.prepare_latents(
                batch_size,
                num_channels_latents,
                height,
                width,
                window_num_frames,
                dtype=torch.float32,
                device=device,
                generator=generator,
                latents=None,
            )
            self.scheduler.set_timesteps(num_inference_steps, device=device, sigmas=sigmas, mu=mu)
            timesteps = self.scheduler.timesteps
            num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
            self._num_timesteps = len(timesteps)
            with self.progress_bar(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    if self.interrupt:
                        continue
                    self._current_timestep = t
                    timestep = t.expand(latents.shape[0])
                    latent_model_input = latents.to(transformer_dtype)
                    latents_history_short = latents_history_short.to(transformer_dtype)
                    latents_history_mid = latents_history_mid.to(transformer_dtype)
                    latents_history_long = latents_history_long.to(transformer_dtype)
                    with self.transformer.cache_context("cond"):
                        noise_pred = self.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            encoder_hidden_states=prompt_embeds,
                            indices_hidden_states=indices_hidden_states,
                            indices_latents_history_short=indices_latents_history_short,
                            indices_latents_history_mid=indices_latents_history_mid,
                            indices_latents_history_long=indices_latents_history_long,
                            latents_history_short=latents_history_short,
                            latents_history_mid=latents_history_mid,
                            latents_history_long=latents_history_long,
                            attention_kwargs=attention_kwargs,
                            return_dict=False,
                        )[0]
                    if self.do_classifier_free_guidance:
                        with self.transformer.cache_context("uncond"):
                            noise_uncond = self.transformer(
                                hidden_states=latent_model_input,
                                timestep=timestep,
                                encoder_hidden_states=negative_prompt_embeds,
                                indices_hidden_states=indices_hidden_states,
                                indices_latents_history_short=indices_latents_history_short,
                                indices_latents_history_mid=indices_latents_history_mid,
                                indices_latents_history_long=indices_latents_history_long,
                                latents_history_short=latents_history_short,
                                latents_history_mid=latents_history_mid,
                                latents_history_long=latents_history_long,
                                attention_kwargs=attention_kwargs,
                                return_dict=False,
                            )[0]
                        noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
                    latents = self.scheduler.step(
                        noise_pred,
                        t,
                        latents,
                        generator=generator,
                        return_dict=False,
                    )[0]
                    if callback_on_step_end is not None:
                        callback_kwargs = {}
                        # NOTE(review): this loop reuses the outer chunk variable `k`; harmless
                        # today because `k` is reassigned at the top of the chunk loop, but a
                        # distinct name would be safer.
                        for k in callback_on_step_end_tensor_inputs:
                            callback_kwargs[k] = locals()[k]
                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
                        latents = callback_outputs.pop("latents", latents)
                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                        progress_bar.update()
                    if XLA_AVAILABLE:
                        xm.mark_step()
            # Capture the anchor frame once it has been generated (T2V, or when skipping chunk 0).
            if keep_first_frame and (
                (is_first_chunk and image_latents is None) or (is_skip_first_chunk and is_second_chunk)
            ):
                image_latents = latents[:, :, 0:1, :, :]
            total_generated_latent_frames += latents.shape[2]
            history_latents = torch.cat([history_latents, latents], dim=2)
            real_history_latents = history_latents[:, :, -total_generated_latent_frames:]
            # Denormalize the freshly generated chunk and decode it to pixel space.
            current_latents = (
                real_history_latents[:, :, -num_latent_frames_per_chunk:].to(vae_dtype) / latents_std
                + latents_mean
            )
            current_video = self.vae.decode(current_latents, return_dict=False)[0]
            if history_video is None:
                history_video = current_video
            else:
                history_video = torch.cat([history_video, current_video], dim=2)
        self._current_timestep = None
        if output_type != "latent":
            # Trim to a frame count consistent with the temporal VAE stride.
            generated_frames = history_video.size(2)
            generated_frames = (
                generated_frames - 1
            ) // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
            history_video = history_video[:, :, :generated_frames]
            video = self.video_processor.postprocess_video(history_video, output_type=output_type)
        else:
            video = real_history_latents
        # Offload all models
        self.maybe_free_model_hooks()
        if not return_dict:
            return (video,)
        return HeliosPipelineOutput(frames=video)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
from dataclasses import dataclass
import torch
from diffusers.utils import BaseOutput
@dataclass
class HeliosPipelineOutput(BaseOutput):
    r"""
    Output class for Helios pipelines.

    Args:
        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
            List of video outputs - It can be a nested list of length `batch_size`, with each sub-list containing
            denoised PIL image sequences of length `num_frames`. It can also be a NumPy array or Torch tensor of shape
            `(batch_size, num_frames, channels, height, width)`.
    """

    # Generated video frames; concrete type depends on the pipeline's `output_type`.
    frames: torch.Tensor

View File

@@ -19,7 +19,7 @@ import torch
from transformers import AutoTokenizer, PreTrainedModel
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnets import ZImageControlNetModel
from ...models.transformers import ZImageTransformer2DModel
@@ -185,7 +185,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
class ZImageControlNetPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
@@ -365,7 +365,7 @@ class ZImageControlNetPipeline(DiffusionPipeline, FromSingleFileMixin):
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 0
@property
def joint_attention_kwargs(self):

View File

@@ -20,7 +20,7 @@ import torch.nn.functional as F
from transformers import AutoTokenizer, PreTrainedModel
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin
from ...loaders import FromSingleFileMixin, ZImageLoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnets import ZImageControlNetModel
from ...models.transformers import ZImageTransformer2DModel
@@ -185,7 +185,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
class ZImageControlNetInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFileMixin):
model_cpu_offload_seq = "text_encoder->transformer->vae"
_optional_components = []
_callback_tensor_inputs = ["latents", "prompt_embeds"]
@@ -372,7 +372,7 @@ class ZImageControlNetInpaintPipeline(DiffusionPipeline, FromSingleFileMixin):
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 0
@property
def joint_attention_kwargs(self):

View File

@@ -347,7 +347,7 @@ class ZImageImg2ImgPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingle
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 0
@property
def joint_attention_kwargs(self):

View File

@@ -462,7 +462,7 @@ class ZImageInpaintPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingle
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 0
@property
def joint_attention_kwargs(self):

View File

@@ -339,7 +339,7 @@ class ZImageOmniPipeline(DiffusionPipeline, ZImageLoraLoaderMixin, FromSingleFil
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1
return self._guidance_scale > 0
@property
def joint_attention_kwargs(self):

View File

@@ -61,6 +61,8 @@ else:
_import_structure["scheduling_flow_match_euler_discrete"] = ["FlowMatchEulerDiscreteScheduler"]
_import_structure["scheduling_flow_match_heun_discrete"] = ["FlowMatchHeunDiscreteScheduler"]
_import_structure["scheduling_flow_match_lcm"] = ["FlowMatchLCMScheduler"]
_import_structure["scheduling_helios"] = ["HeliosScheduler"]
_import_structure["scheduling_helios_dmd"] = ["HeliosDMDScheduler"]
_import_structure["scheduling_heun_discrete"] = ["HeunDiscreteScheduler"]
_import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
_import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
@@ -164,6 +166,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
from .scheduling_flow_match_heun_discrete import FlowMatchHeunDiscreteScheduler
from .scheduling_flow_match_lcm import FlowMatchLCMScheduler
from .scheduling_helios import HeliosScheduler
from .scheduling_helios_dmd import HeliosDMDScheduler
from .scheduling_heun_discrete import HeunDiscreteScheduler
from .scheduling_ipndm import IPNDMScheduler
from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler

View File

@@ -0,0 +1,867 @@
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Literal
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..schedulers.scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, deprecate
@dataclass
class HeliosSchedulerOutput(BaseOutput):
    """
    Output class for the `HeliosScheduler` step functions.

    Args:
        prev_sample (`torch.FloatTensor`):
            Computed sample `(x_{t-1})` of the previous timestep; used as the next model input.
        model_outputs (`torch.FloatTensor`, *optional*):
            History of converted model outputs kept by the multistep (UniPC) solver.
        last_sample (`torch.FloatTensor`, *optional*):
            Sample from the previous step, kept for the UniPC corrector.
        this_order (`int`, *optional*):
            Effective solver order used for this step.
    """

    prev_sample: torch.FloatTensor
    model_outputs: torch.FloatTensor | None = None
    last_sample: torch.FloatTensor | None = None
    this_order: int | None = None
class HeliosScheduler(SchedulerMixin, ConfigMixin):
_compatibles = []
order = 1
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,  # Following Stable diffusion 3,
        stages: int = 3,
        # NOTE(review): mutable list defaults below are recorded verbatim by
        # `@register_to_config` and are never mutated here, so they are left as-is.
        stage_range: list = [0, 1 / 3, 2 / 3, 1],
        gamma: float = 1 / 3,
        # For UniPC
        thresholding: bool = False,
        prediction_type: str = "flow_prediction",
        solver_order: int = 2,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: list[int] = [],
        solver_p: SchedulerMixin = None,
        use_flow_sigmas: bool = True,
        scheduler_type: str = "unipc",  # ["euler", "unipc"]
        use_dynamic_shifting: bool = False,
        time_shift_type: Literal["exponential", "linear"] = "exponential",
    ):
        """
        Multi-stage flow-matching scheduler supporting both Euler and UniPC updates.

        Per-stage bookkeeping is precomputed once at construction time via
        `init_sigmas_for_each_stage`.
        """
        self.timestep_ratios = {}  # The timestep ratio for each stage
        self.timesteps_per_stage = {}  # The detailed timesteps per stage (fix max and min per stage)
        self.sigmas_per_stage = {}  # always uniform [1000, 0]
        self.start_sigmas = {}  # for start point / upsample renoise
        self.end_sigmas = {}  # for end point
        self.ori_start_sigmas = {}  # uncorrected start sigma per stage
        # self.init_sigmas()  # superseded: init_sigmas_for_each_stage() calls it internally
        self.init_sigmas_for_each_stage()
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.gamma = gamma
        # Silently coerce known-but-unsupported solver types to "bh2"; reject unknown ones.
        if solver_type not in ["bh1", "bh2"]:
            if solver_type in ["midpoint", "heun", "logrho"]:
                self.register_to_config(solver_type="bh2")
            else:
                raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}")
        self.predict_x0 = predict_x0
        # UniPC multistep state buffers (length = solver_order).
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector
        self.solver_p = solver_p
        self.last_sample = None
        self._step_index = None
        self._begin_index = None
def init_sigmas(self):
"""
initialize the global timesteps and sigmas
"""
num_train_timesteps = self.config.num_train_timesteps
shift = self.config.shift
alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
sigmas = 1.0 - alphas
sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
sigmas = torch.from_numpy(sigmas)
timesteps = (sigmas * num_train_timesteps).clone()
self._step_index = None
self._begin_index = None
self.timesteps = timesteps
self.sigmas = sigmas.to("cpu") # to avoid too much CPU/GPU communication
    def init_sigmas_for_each_stage(self):
        """
        Init the timesteps for each stage.

        Splits the global schedule built by `init_sigmas` into `config.stages` segments,
        applies a gamma-based correction to each stage's start sigma (except stage 0),
        and precomputes per-stage timestep/sigma tables.
        """
        self.init_sigmas()
        stage_distance = []
        stages = self.config.stages
        training_steps = self.config.num_train_timesteps
        stage_range = self.config.stage_range
        # Init the start and end point of each stage
        for i_s in range(stages):
            # To decide the start and ends point
            start_indice = int(stage_range[i_s] * training_steps)
            start_indice = max(start_indice, 0)
            end_indice = int(stage_range[i_s + 1] * training_steps)
            end_indice = min(end_indice, training_steps)
            start_sigma = self.sigmas[start_indice].item()
            end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
            self.ori_start_sigmas[i_s] = start_sigma
            if i_s != 0:
                # Renoise correction for non-initial stages, controlled by gamma.
                ori_sigma = 1 - start_sigma
                gamma = self.config.gamma
                corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
                # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
                start_sigma = 1 - corrected_sigma
            stage_distance.append(start_sigma - end_sigma)
            self.start_sigmas[i_s] = start_sigma
            self.end_sigmas[i_s] = end_sigma
        # Determine the ratio of each stage according to flow length
        tot_distance = sum(stage_distance)
        for i_s in range(stages):
            if i_s == 0:
                start_ratio = 0.0
            else:
                start_ratio = sum(stage_distance[:i_s]) / tot_distance
            if i_s == stages - 1:
                # Just below 1.0 so the final index stays inside the schedule.
                end_ratio = 0.9999999999999999
            else:
                end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance
            self.timestep_ratios[i_s] = (start_ratio, end_ratio)
        # Determine the timesteps and sigmas for each stage
        for i_s in range(stages):
            timestep_ratio = self.timestep_ratios[i_s]
            # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
            timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
            timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
            timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
            self.timesteps_per_stage[i_s] = (
                timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
            )
            # Per-stage sigmas are always a uniform ramp, independent of the timesteps above.
            stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
            self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])
    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index
    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index
def _sigma_to_t(self, sigma):
return sigma * self.config.num_train_timesteps
    def set_timesteps(
        self,
        num_inference_steps: int,
        stage_index: int | None = None,
        # NOTE: annotations for `sigmas` and `mu` were previously `bool | None`, which was
        # incorrect — the pipeline passes an ndarray and a float respectively.
        device: str | torch.device = None,
        sigmas: np.ndarray | None = None,
        mu: float | None = None,
        is_amplify_first_chunk: bool = False,
    ):
        """
        Setting the timesteps and sigmas for each stage.

        Args:
            num_inference_steps (`int`):
                Number of denoising steps (adjusted internally for the "dmd" scheduler type).
            stage_index (`int`, *optional*):
                Which precomputed stage schedule to use when `config.stages > 1`.
            device (`str` or `torch.device`, *optional*):
                Device to move the resulting timesteps/sigmas to.
            sigmas (`np.ndarray`, *optional*):
                Custom sigma schedule; only honored in the single-stage configuration.
            mu (`float`, *optional*):
                Shift parameter used when `config.use_dynamic_shifting` is enabled.
            is_amplify_first_chunk (`bool`, defaults to `False`):
                "dmd" only: doubles the step count for the first chunk.
        """
        if self.config.scheduler_type == "dmd":
            # The extra step is trimmed again at the end of this method.
            if is_amplify_first_chunk:
                num_inference_steps = num_inference_steps * 2 + 1
            else:
                num_inference_steps = num_inference_steps + 1
        self.num_inference_steps = num_inference_steps
        self.init_sigmas()
        if self.config.stages == 1:
            if sigmas is None:
                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
                    np.float32
                )
            if self.config.shift != 1.0:
                assert not self.config.use_dynamic_shifting
                sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigmas = torch.from_numpy(sigmas)
        else:
            # Multi-stage: resample the precomputed per-stage tables to the requested step count.
            stage_timesteps = self.timesteps_per_stage[stage_index]
            timesteps = np.linspace(
                stage_timesteps[0].item(),
                stage_timesteps[-1].item(),
                num_inference_steps,
            )
            stage_sigmas = self.sigmas_per_stage[stage_index]
            ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
            sigmas = torch.from_numpy(ratios)
        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)
        self._step_index = None
        self.reset_scheduler_history()
        if self.config.scheduler_type == "dmd":
            # Drop the extra step added above, keeping the terminal zero sigma.
            self.timesteps = self.timesteps[:-1]
            self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])
        if self.config.use_dynamic_shifting:
            assert self.config.shift == 1.0
            self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
            if self.config.stages == 1:
                self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
            else:
                # Rescale the shifted sigmas back into this stage's timestep range.
                self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
                    self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
                )
    # NOTE: body below must stay byte-identical to the upstream function referenced by the
    # `Copied from` marker (enforced by `make fix-copies`).
    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """
        Apply time shifting to the sigmas.

        Args:
            mu (`float`):
                The mu parameter for the time shift.
            sigma (`float`):
                The sigma parameter for the time shift.
            t (`torch.Tensor`):
                The input timesteps.

        Returns:
            `torch.Tensor`:
                The time-shifted timesteps.
        """
        if self.config.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        elif self.config.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)
    # Exponential shift: t -> e^mu / (e^mu + (1/t - 1)^sigma).
    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
    def _time_shift_exponential(self, mu, sigma, t):
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
    # Linear shift: t -> mu / (mu + (1/t - 1)^sigma).
    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
    def _time_shift_linear(self, mu, sigma, t):
        return mu / (mu + (1 / t - 1) ** sigma)
# ---------------------------------- Euler ----------------------------------
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
def step_euler(
self,
model_output: torch.FloatTensor,
timestep: float | torch.FloatTensor = None,
sample: torch.FloatTensor = None,
generator: torch.Generator | None = None,
sigma: torch.FloatTensor | None = None,
sigma_next: torch.FloatTensor | None = None,
return_dict: bool = True,
) -> HeliosSchedulerOutput | tuple:
assert (sigma is None) == (sigma_next is None), "sigma and sigma_next must both be None or both be not None"
if sigma is None and sigma_next is None:
if (
isinstance(timestep, int)
or isinstance(timestep, torch.IntTensor)
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if self.step_index is None:
self._step_index = 0
# Upcast to avoid precision issues when computing prev_sample
sample = sample.to(torch.float32)
if sigma is None and sigma_next is None:
sigma = self.sigmas[self.step_index]
sigma_next = self.sigmas[self.step_index + 1]
prev_sample = sample + (sigma_next - sigma) * model_output
# Cast sample back to model compatible dtype
prev_sample = prev_sample.to(model_output.dtype)
# upon completion increase step index by one
self._step_index += 1
if not return_dict:
return (prev_sample,)
return HeliosSchedulerOutput(prev_sample=prev_sample)
# ---------------------------------- UniPC ----------------------------------
def _sigma_to_alpha_sigma_t(self, sigma):
if self.config.use_flow_sigmas:
alpha_t = 1 - sigma
sigma_t = torch.clamp(sigma, min=1e-8)
else:
alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
sigma_t = sigma * alpha_t
return alpha_t, sigma_t
def convert_model_output(
self,
model_output: torch.Tensor,
*args,
sample: torch.Tensor = None,
sigma: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
r"""
Convert the model output to the corresponding type the UniPC algorithm needs.
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
Returns:
`torch.Tensor`:
The converted model output.
"""
timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
if sample is None:
if len(args) > 1:
sample = args[1]
else:
raise ValueError("missing `sample` as a required keyword argument")
if timestep is not None:
deprecate(
"timesteps",
"1.0.0",
"Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
)
flag = False
if sigma is None:
flag = True
sigma = self.sigmas[self.step_index]
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
if self.predict_x0:
if self.config.prediction_type == "epsilon":
x0_pred = (sample - sigma_t * model_output) / alpha_t
elif self.config.prediction_type == "sample":
x0_pred = model_output
elif self.config.prediction_type == "v_prediction":
x0_pred = alpha_t * sample - sigma_t * model_output
elif self.config.prediction_type == "flow_prediction":
if flag:
sigma_t = self.sigmas[self.step_index]
else:
sigma_t = sigma
x0_pred = sample - sigma_t * model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, "
"`v_prediction`, or `flow_prediction` for the UniPCMultistepScheduler."
)
if self.config.thresholding:
x0_pred = self._threshold_sample(x0_pred)
return x0_pred
else:
if self.config.prediction_type == "epsilon":
return model_output
elif self.config.prediction_type == "sample":
epsilon = (sample - alpha_t * model_output) / sigma_t
return epsilon
elif self.config.prediction_type == "v_prediction":
epsilon = alpha_t * model_output + sigma_t * sample
return epsilon
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
" `v_prediction` for the UniPCMultistepScheduler."
)
    def multistep_uni_p_bh_update(
        self,
        model_output: torch.Tensor,
        *args,
        sample: torch.Tensor = None,
        order: int = None,
        sigma: torch.Tensor = None,
        sigma_next: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniP (B(h) version). Alternatively, `self.solver_p` is used if is specified.
        Args:
            model_output (`torch.Tensor`):
                The direct output from the learned diffusion model at the current timestep.
            prev_timestep (`int`):
                The previous discrete timestep in the diffusion chain.
            sample (`torch.Tensor`):
                A current instance of a sample created by the diffusion process.
            order (`int`):
                The order of UniP at this timestep (corresponds to the *p* in UniPC-p).
            sigma (`torch.Tensor`, *optional*):
                Current noise level; read from `self.sigmas[self.step_index]` when not given.
            sigma_next (`torch.Tensor`, *optional*):
                Next noise level; read from `self.sigmas[self.step_index + 1]` when not given.
        Returns:
            `torch.Tensor`:
                The sample tensor at the previous timestep.
        """
        # Legacy positional calling convention: (prev_timestep, sample, order).
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if prev_timestep is not None:
            deprecate(
                "prev_timestep",
                "1.0.0",
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs
        s0 = self.timestep_list[-1]
        m0 = model_output_list[-1]
        x = sample
        # If an external predictor is configured, delegate the whole step to it.
        if self.solver_p:
            x_t = self.solver_p.step(model_output, s0, x).prev_sample
            return x_t
        if sigma_next is None and sigma is None:
            sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        else:
            sigma_t, sigma_s0 = sigma_next, sigma
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        # Half-log-SNR values; h is the step size in lambda-space.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        h = lambda_t - lambda_s0
        device = sample.device
        rks = []
        D1s = []
        # Build the normalized step ratios r_k and divided differences D1 from the stored history.
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)
        R = []
        b = []
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1
        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()
        # Assemble the Vandermonde-like system R * rhos = b for the predictor coefficients.
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        R = torch.stack(R)
        b = torch.tensor(b, device=device)
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1]).to(device).to(x.dtype)
        else:
            D1s = None
        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                # Higher-order correction: weighted sum of the divided differences.
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * B_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, D1s)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * B_h * pred_res
        x_t = x_t.to(x.dtype)
        return x_t
    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.Tensor,
        *args,
        last_sample: torch.Tensor = None,
        this_sample: torch.Tensor = None,
        order: int = None,
        sigma_before: torch.Tensor = None,
        sigma: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        One step for the UniC (B(h) version).
        Args:
            this_model_output (`torch.Tensor`):
                The model outputs at `x_t`.
            this_timestep (`int`):
                The current timestep `t`.
            last_sample (`torch.Tensor`):
                The generated sample before the last predictor `x_{t-1}`.
            this_sample (`torch.Tensor`):
                The generated sample after the last predictor `x_{t}`.
            order (`int`):
                The `p` of UniC-p at this step. The effective order of accuracy should be `order + 1`.
            sigma_before (`torch.Tensor`, *optional*):
                Noise level of the previous step; read from `self.sigmas` when not given.
            sigma (`torch.Tensor`, *optional*):
                Noise level of the current step; read from `self.sigmas` when not given.
        Returns:
            `torch.Tensor`:
                The corrected sample tensor at the current timestep.
        """
        # Legacy positional calling convention: (this_timestep, last_sample, this_sample, order).
        this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError("missing `last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError("missing `this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError("missing `order` as a required keyword argument")
        if this_timestep is not None:
            deprecate(
                "this_timestep",
                "1.0.0",
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs
        m0 = model_output_list[-1]
        x = last_sample
        x_t = this_sample
        model_t = this_model_output
        if sigma_before is None and sigma is None:
            sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
        else:
            sigma_t, sigma_s0 = sigma, sigma_before
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
        # Half-log-SNR values; h is the step size in lambda-space.
        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
        h = lambda_t - lambda_s0
        device = this_sample.device
        rks = []
        D1s = []
        # Build normalized step ratios r_k and divided differences D1 from the stored history.
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            D1s.append((mi - m0) / rk)
        rks.append(1.0)
        rks = torch.tensor(rks, device=device)
        R = []
        b = []
        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1
        factorial_i = 1
        if self.config.solver_type == "bh1":
            B_h = hh
        elif self.config.solver_type == "bh2":
            B_h = torch.expm1(hh)
        else:
            raise NotImplementedError()
        # Assemble the Vandermonde-like system R * rhos = b for the corrector coefficients.
        for i in range(1, order + 1):
            R.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / B_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i
        R = torch.stack(R)
        b = torch.tensor(b, device=device)
        if len(D1s) > 0:
            D1s = torch.stack(D1s, dim=1)
        else:
            D1s = None
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            rhos_c = torch.linalg.solve(R, b).to(device).to(x.dtype)
        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            # The last coefficient weights the *current* model output — this is the corrector term.
            D1_t = model_t - m0
            x_t = x_t_ - alpha_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if D1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], D1s)
            else:
                corr_res = 0
            D1_t = model_t - m0
            x_t = x_t_ - sigma_t * B_h * (corr_res + rhos_c[-1] * D1_t)
        x_t = x_t.to(x.dtype)
        return x_t
    def step_unipc(
        self,
        model_output: torch.Tensor,
        timestep: int | torch.Tensor = None,
        sample: torch.Tensor = None,
        return_dict: bool = True,
        model_outputs: list = None,
        timestep_list: list = None,
        sigma_before: torch.Tensor = None,
        sigma: torch.Tensor = None,
        sigma_next: torch.Tensor = None,
        cus_step_index: int = None,
        cus_lower_order_num: int = None,
        cus_this_order: int = None,
        cus_last_sample: torch.Tensor = None,
    ) -> HeliosSchedulerOutput | tuple:
        """
        One UniPC step: an optional corrector pass on the current sample, followed by the predictor.

        The `cus_*` arguments let the caller manage the solver state externally (step index,
        warmup counter, effective order and last sample) instead of relying on the internal
        counters; `model_outputs`/`timestep_list` likewise allow an externally held history.

        Raises:
            ValueError: If `set_timesteps` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )
        # Apply externally managed state when provided; otherwise fall back to internal counters.
        if cus_step_index is None:
            if self.step_index is None:
                self._step_index = 0
        else:
            self._step_index = cus_step_index
        if cus_lower_order_num is not None:
            self.lower_order_nums = cus_lower_order_num
        if cus_this_order is not None:
            self.this_order = cus_this_order
        if cus_last_sample is not None:
            self.last_sample = cus_last_sample
        # The corrector needs a previous predictor result and must not be disabled for that step.
        use_corrector = (
            self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
        )
        # Convert model output using the proper conversion method
        model_output_convert = self.convert_model_output(model_output, sample=sample, sigma=sigma)
        if model_outputs is not None and timestep_list is not None:
            self.model_outputs = model_outputs[:-1]
            self.timestep_list = timestep_list[:-1]
        if use_corrector:
            # Correct the sample produced by the previous predictor step before predicting again.
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
                sigma_before=sigma_before,
                sigma=sigma,
            )
        # Roll the model-output / timestep history forward by one slot.
        if model_outputs is not None and timestep_list is not None:
            model_outputs[-1] = model_output_convert
            self.model_outputs = model_outputs[1:]
            self.timestep_list = timestep_list[1:]
        else:
            for i in range(self.config.solver_order - 1):
                self.model_outputs[i] = self.model_outputs[i + 1]
                self.timestep_list[i] = self.timestep_list[i + 1]
            self.model_outputs[-1] = model_output_convert
            self.timestep_list[-1] = timestep
        if self.config.lower_order_final:
            # Cap the order near the end of the schedule so the history never runs past it.
            this_order = min(self.config.solver_order, len(self.timesteps) - self.step_index)
        else:
            this_order = self.config.solver_order
        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0
        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            sample=sample,
            order=self.this_order,
            sigma=sigma,
            sigma_next=sigma_next,
        )
        if cus_lower_order_num is None:
            if self.lower_order_nums < self.config.solver_order:
                self.lower_order_nums += 1
        # upon completion increase step index by one
        if cus_step_index is None:
            self._step_index += 1
        if not return_dict:
            return (prev_sample, model_outputs, self.last_sample, self.this_order)
        return HeliosSchedulerOutput(
            prev_sample=prev_sample,
            model_outputs=model_outputs,
            last_sample=self.last_sample,
            this_order=self.this_order,
        )
# ---------------------------------- Merge ----------------------------------
def step(
self,
model_output: torch.FloatTensor,
timestep: float | torch.FloatTensor = None,
sample: torch.FloatTensor = None,
generator: torch.Generator | None = None,
return_dict: bool = True,
) -> HeliosSchedulerOutput | tuple:
if self.config.scheduler_type == "euler":
return self.step_euler(
model_output=model_output,
timestep=timestep,
sample=sample,
generator=generator,
return_dict=return_dict,
)
elif self.config.scheduler_type == "unipc":
return self.step_unipc(
model_output=model_output,
timestep=timestep,
sample=sample,
return_dict=return_dict,
)
else:
raise NotImplementedError
def reset_scheduler_history(self):
self.model_outputs = [None] * self.config.solver_order
self.timestep_list = [None] * self.config.solver_order
self.lower_order_nums = 0
self.disable_corrector = self.config.disable_corrector
self.solver_p = self.config.solver_p
self.last_sample = None
self._step_index = None
self._begin_index = None
    def __len__(self):
        """The scheduler's length is the number of training timesteps."""
        return self.config.num_train_timesteps

View File

@@ -0,0 +1,331 @@
# Copyright 2025 The Helios Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Literal
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..schedulers.scheduling_utils import SchedulerMixin
from ..utils import BaseOutput
@dataclass
class HeliosDMDSchedulerOutput(BaseOutput):
    """Output class for `HeliosDMDScheduler.step`."""

    # Previous-timestep sample; feed this back as `sample` on the next step.
    prev_sample: torch.FloatTensor
    # The fields below appear to mirror the non-DMD scheduler output for API parity; the DMD
    # scheduler's `step` only ever populates `prev_sample` — TODO confirm callers rely on them.
    model_outputs: torch.FloatTensor | None = None
    last_sample: torch.FloatTensor | None = None
    this_order: int | None = None
class HeliosDMDScheduler(SchedulerMixin, ConfigMixin):
    """
    Multi-stage flow-matching scheduler used for DMD (distribution matching distillation) sampling
    in the Helios video pipelines.

    Rather than integrating an ODE step by step, `step` converts the flow prediction into a
    clean-sample estimate (`convert_flow_pred_to_x0`) and then re-noises it to the next timestep
    in `all_timesteps` (`add_noise`).
    """

    # No interchangeable schedulers registered for this class.
    _compatibles = []
    order = 1

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,  # Following Stable diffusion 3,
        stages: int = 3,
        # NOTE(review): mutable default argument; `register_to_config` stores it in the config,
        # but a tuple default would be safer — TODO confirm before changing serialization.
        stage_range: list = [0, 1 / 3, 2 / 3, 1],
        gamma: float = 1 / 3,
        prediction_type: str = "flow_prediction",
        use_flow_sigmas: bool = True,
        use_dynamic_shifting: bool = False,
        time_shift_type: Literal["exponential", "linear"] = "linear",
    ):
        self.timestep_ratios = {}  # The timestep ratio for each stage
        self.timesteps_per_stage = {}  # The detailed timesteps per stage (fix max and min per stage)
        self.sigmas_per_stage = {}  # always uniform [1000, 0]
        self.start_sigmas = {}  # for start point / upsample renoise
        self.end_sigmas = {}  # for end point
        self.ori_start_sigmas = {}  # uncorrected start sigma per stage
        # self.init_sigmas()
        self.init_sigmas_for_each_stage()
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()
        self.gamma = gamma
        self.last_sample = None
        self._step_index = None
        self._begin_index = None

    def init_sigmas(self):
        """
        initialize the global timesteps and sigmas
        """
        num_train_timesteps = self.config.num_train_timesteps
        shift = self.config.shift
        alphas = np.linspace(1, 1 / num_train_timesteps, num_train_timesteps + 1)
        sigmas = 1.0 - alphas
        # Static shift (SD3-style), then flip so sigmas run from high noise to low noise.
        sigmas = np.flip(shift * sigmas / (1 + (shift - 1) * sigmas))[:-1].copy()
        sigmas = torch.from_numpy(sigmas)
        timesteps = (sigmas * num_train_timesteps).clone()
        self._step_index = None
        self._begin_index = None
        self.timesteps = timesteps
        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication

    def init_sigmas_for_each_stage(self):
        """
        Init the timesteps for each stage
        """
        self.init_sigmas()
        stage_distance = []
        stages = self.config.stages
        training_steps = self.config.num_train_timesteps
        stage_range = self.config.stage_range
        # Init the start and end point of each stage
        for i_s in range(stages):
            # To decide the start and ends point
            start_indice = int(stage_range[i_s] * training_steps)
            start_indice = max(start_indice, 0)
            end_indice = int(stage_range[i_s + 1] * training_steps)
            end_indice = min(end_indice, training_steps)
            start_sigma = self.sigmas[start_indice].item()
            end_sigma = self.sigmas[end_indice].item() if end_indice < training_steps else 0.0
            self.ori_start_sigmas[i_s] = start_sigma
            if i_s != 0:
                # Later stages start from a re-noised sample, so the raw start sigma is corrected
                # by the gamma factor before measuring the stage's flow distance.
                ori_sigma = 1 - start_sigma
                gamma = self.config.gamma
                corrected_sigma = (1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)) * ori_sigma
                # corrected_sigma = 1 / (2 - ori_sigma) * ori_sigma
                start_sigma = 1 - corrected_sigma
            stage_distance.append(start_sigma - end_sigma)
            self.start_sigmas[i_s] = start_sigma
            self.end_sigmas[i_s] = end_sigma
        # Determine the ratio of each stage according to flow length
        tot_distance = sum(stage_distance)
        for i_s in range(stages):
            if i_s == 0:
                start_ratio = 0.0
            else:
                start_ratio = sum(stage_distance[:i_s]) / tot_distance
            if i_s == stages - 1:
                # Keep strictly below 1.0 so the final index never goes out of range.
                end_ratio = 0.9999999999999999
            else:
                end_ratio = sum(stage_distance[: i_s + 1]) / tot_distance
            self.timestep_ratios[i_s] = (start_ratio, end_ratio)
        # Determine the timesteps and sigmas for each stage
        for i_s in range(stages):
            timestep_ratio = self.timestep_ratios[i_s]
            # timestep_max = self.timesteps[int(timestep_ratio[0] * training_steps)]
            timestep_max = min(self.timesteps[int(timestep_ratio[0] * training_steps)], 999)
            timestep_min = self.timesteps[min(int(timestep_ratio[1] * training_steps), training_steps - 1)]
            timesteps = np.linspace(timestep_max, timestep_min, training_steps + 1)
            self.timesteps_per_stage[i_s] = (
                timesteps[:-1] if isinstance(timesteps, torch.Tensor) else torch.from_numpy(timesteps[:-1])
            )
            stage_sigmas = np.linspace(0.999, 0, training_steps + 1)
            self.sigmas_per_stage[i_s] = torch.from_numpy(stage_sigmas[:-1])

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def _sigma_to_t(self, sigma):
        # Linear mapping: sigma in [0, 1] → timestep in [0, num_train_timesteps].
        return sigma * self.config.num_train_timesteps

    def set_timesteps(
        self,
        num_inference_steps: int,
        stage_index: int | None = None,
        device: str | torch.device = None,
        sigmas: np.ndarray | None = None,
        mu: float | None = None,
        is_amplify_first_chunk: bool = False,
    ):
        """
        Setting the timesteps and sigmas for each stage

        Args:
            num_inference_steps (`int`): Number of denoising steps (one extra step is added internally).
            stage_index (`int`, *optional*): Stage whose schedule to use when `config.stages > 1`.
            device (`str` or `torch.device`, *optional*): Device for the resulting timesteps/sigmas.
            sigmas (`np.ndarray`, *optional*): Custom sigmas; only honored when `config.stages == 1`.
            mu (`float`, *optional*): Shift parameter; required when `config.use_dynamic_shifting`.
            is_amplify_first_chunk (`bool`): Doubles the step count (+1) for the first video chunk.
        """
        if is_amplify_first_chunk:
            num_inference_steps = num_inference_steps * 2 + 1
        else:
            num_inference_steps = num_inference_steps + 1
        self.num_inference_steps = num_inference_steps
        self.init_sigmas()
        if self.config.stages == 1:
            if sigmas is None:
                sigmas = np.linspace(1, 1 / self.config.num_train_timesteps, num_inference_steps + 1)[:-1].astype(
                    np.float32
                )
            if self.config.shift != 1.0:
                # Static and dynamic shifting are mutually exclusive.
                assert not self.config.use_dynamic_shifting
                sigmas = self.time_shift(self.config.shift, 1.0, sigmas)
            timesteps = (sigmas * self.config.num_train_timesteps).copy()
            sigmas = torch.from_numpy(sigmas)
        else:
            # Interpolate the precomputed per-stage schedule down to num_inference_steps points.
            stage_timesteps = self.timesteps_per_stage[stage_index]
            timesteps = np.linspace(
                stage_timesteps[0].item(),
                stage_timesteps[-1].item(),
                num_inference_steps,
            )
            stage_sigmas = self.sigmas_per_stage[stage_index]
            ratios = np.linspace(stage_sigmas[0].item(), stage_sigmas[-1].item(), num_inference_steps)
            sigmas = torch.from_numpy(ratios)
        self.timesteps = torch.from_numpy(timesteps).to(device=device)
        self.sigmas = torch.cat([sigmas, torch.zeros(1)]).to(device=device)
        self._step_index = None
        self.reset_scheduler_history()
        # Drop the extra step added above: trim the last timestep and splice out the
        # second-to-last sigma while keeping the terminal zero.
        self.timesteps = self.timesteps[:-1]
        self.sigmas = torch.cat([self.sigmas[:-2], self.sigmas[-1:]])
        if self.config.use_dynamic_shifting:
            assert self.config.shift == 1.0
            self.sigmas = self.time_shift(mu, 1.0, self.sigmas)
            if self.config.stages == 1:
                self.timesteps = self.sigmas[:-1] * self.config.num_train_timesteps
            else:
                # Rescale the shifted sigmas back into this stage's timestep range.
                self.timesteps = self.timesteps_per_stage[stage_index].min() + self.sigmas[:-1] * (
                    self.timesteps_per_stage[stage_index].max() - self.timesteps_per_stage[stage_index].min()
                )

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler.time_shift
    def time_shift(self, mu: float, sigma: float, t: torch.Tensor):
        """
        Apply time shifting to the sigmas.
        Args:
            mu (`float`):
                The mu parameter for the time shift.
            sigma (`float`):
                The sigma parameter for the time shift.
            t (`torch.Tensor`):
                The input timesteps.
        Returns:
            `torch.Tensor`:
                The time-shifted timesteps.
        """
        if self.config.time_shift_type == "exponential":
            return self._time_shift_exponential(mu, sigma, t)
        elif self.config.time_shift_type == "linear":
            return self._time_shift_linear(mu, sigma, t)
        # NOTE(review): any other `time_shift_type` silently returns None — TODO confirm this is unreachable.

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_exponential
    def _time_shift_exponential(self, mu, sigma, t):
        # Assumes t in (0, 1]; t == 0 would divide by zero.
        return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)

    # Copied from diffusers.schedulers.scheduling_flow_match_euler_discrete.FlowMatchEulerDiscreteScheduler._time_shift_linear
    def _time_shift_linear(self, mu, sigma, t):
        # Assumes t in (0, 1]; t == 0 would divide by zero.
        return mu / (mu + (1 / t - 1) ** sigma)

    # ---------------------------------- For DMD ----------------------------------
    def add_noise(self, original_samples, noise, timestep, sigmas, timesteps):
        """
        Re-noise `original_samples` to the noise level of `timestep`.

        `timestep` is matched to the nearest entry of `timesteps` (per batch element) and the
        corresponding sigma interpolates: `(1 - sigma) * x0 + sigma * noise`. The reshape to
        (-1, 1, 1, 1, 1) assumes 5D video latents (B, C, F, H, W) — TODO confirm.
        """
        sigmas = sigmas.to(noise.device)
        timesteps = timesteps.to(noise.device)
        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
        sample = (1 - sigma) * original_samples + sigma * noise
        return sample.type_as(noise)

    def convert_flow_pred_to_x0(self, flow_pred, xt, timestep, sigmas, timesteps):
        """
        Convert a flow prediction into a clean-sample estimate: `x0 = xt - sigma_t * flow_pred`,
        where `sigma_t` is looked up by matching `timestep` to the nearest entry of `timesteps`.
        """
        # use higher precision for calculations
        original_dtype = flow_pred.dtype
        device = flow_pred.device
        flow_pred, xt, sigmas, timesteps = (x.double().to(device) for x in (flow_pred, xt, sigmas, timesteps))
        timestep_id = torch.argmin((timesteps.unsqueeze(0) - timestep.unsqueeze(1)).abs(), dim=1)
        sigma_t = sigmas[timestep_id].reshape(-1, 1, 1, 1, 1)
        x0_pred = xt - sigma_t * flow_pred
        return x0_pred.to(original_dtype)

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: float | torch.FloatTensor = None,
        sample: torch.FloatTensor = None,
        generator: torch.Generator | None = None,
        return_dict: bool = True,
        cur_sampling_step: int = 0,
        dmd_noisy_tensor: torch.FloatTensor | None = None,
        dmd_sigmas: torch.FloatTensor | None = None,
        dmd_timesteps: torch.FloatTensor | None = None,
        all_timesteps: torch.FloatTensor | None = None,
    ) -> HeliosDMDSchedulerOutput | tuple:
        """
        One DMD sampling step: estimate `x0` from the flow prediction, then re-noise it with
        `dmd_noisy_tensor` to the next timestep in `all_timesteps`. The last step returns the
        clean estimate directly. Note `generator` is accepted for API parity but unused here.
        """
        pred_image_or_video = self.convert_flow_pred_to_x0(
            flow_pred=model_output,
            xt=sample,
            timestep=torch.full((model_output.shape[0],), timestep, dtype=torch.long, device=model_output.device),
            sigmas=dmd_sigmas,
            timesteps=dmd_timesteps,
        )
        if cur_sampling_step < len(all_timesteps) - 1:
            # Not the final step: re-noise the clean estimate to the next timestep's noise level.
            prev_sample = self.add_noise(
                pred_image_or_video,
                dmd_noisy_tensor,
                torch.full(
                    (model_output.shape[0],),
                    all_timesteps[cur_sampling_step + 1],
                    dtype=torch.long,
                    device=model_output.device,
                ),
                sigmas=dmd_sigmas,
                timesteps=dmd_timesteps,
            )
        else:
            prev_sample = pred_image_or_video
        if not return_dict:
            return (prev_sample,)
        return HeliosDMDSchedulerOutput(prev_sample=prev_sample)

    def reset_scheduler_history(self):
        """Reset the step and begin indices; the DMD scheduler keeps no multistep history."""
        self._step_index = None
        self._begin_index = None

    def __len__(self):
        """The scheduler's length is the number of training timesteps."""
        return self.config.num_train_timesteps

View File

@@ -31,14 +31,18 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
Args:
num_train_timesteps (`int`, defaults to 1000):
The number of diffusion steps to train the model.
trained_betas (`np.ndarray`, *optional*):
trained_betas (`np.ndarray` or `List[float]`, *optional*):
Pass an array of betas directly to the constructor to bypass `beta_start` and `beta_end`.
"""
order = 1
@register_to_config
def __init__(self, num_train_timesteps: int = 1000, trained_betas: np.ndarray | list[float] | None = None):
def __init__(
self,
num_train_timesteps: int = 1000,
trained_betas: np.ndarray | list[float] | None = None,
):
# set `betas`, `alphas`, `timesteps`
self.set_timesteps(num_train_timesteps)
@@ -56,21 +60,29 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
self._begin_index = None
@property
def step_index(self):
def step_index(self) -> int | None:
"""
The index counter for current timestep. It will increase 1 after each scheduler step.
Returns:
`int` or `None`:
The index counter for current timestep.
"""
return self._step_index
@property
def begin_index(self):
def begin_index(self) -> int | None:
"""
The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
Returns:
`int` or `None`:
The index for the first timestep.
"""
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
@@ -169,7 +181,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
Args:
model_output (`torch.Tensor`):
The direct output from learned diffusion model.
timestep (`int`):
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
@@ -228,7 +240,30 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
"""
return sample
def _get_prev_sample(self, sample, timestep_index, prev_timestep_index, ets):
def _get_prev_sample(
self,
sample: torch.Tensor,
timestep_index: int,
prev_timestep_index: int,
ets: torch.Tensor,
) -> torch.Tensor:
"""
Predicts the previous sample based on the current sample, timestep indices, and running model outputs.
Args:
sample (`torch.Tensor`):
The current sample.
timestep_index (`int`):
Index of the current timestep in the schedule.
prev_timestep_index (`int`):
Index of the previous timestep in the schedule.
ets (`torch.Tensor`):
The running sequence of model outputs.
Returns:
`torch.Tensor`:
The predicted previous sample.
"""
alpha = self.alphas[timestep_index]
sigma = self.betas[timestep_index]
@@ -240,5 +275,5 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
return prev_sample
def __len__(self):
def __len__(self) -> int:
return self.config.num_train_timesteps

View File

@@ -656,6 +656,21 @@ class AutoencoderOobleck(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderRAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderTiny(metaclass=DummyObject):
_backends = ["torch"]
@@ -1031,6 +1046,21 @@ class GlmImageTransformer2DModel(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HeliosTransformer3DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HiDreamImageTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -2743,6 +2773,36 @@ class FlowMatchLCMScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class HeliosDMDScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HeliosScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class HeunDiscreteScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1352,6 +1352,36 @@ class GlmImagePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class HeliosPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HeliosPyramidPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class HiDreamImagePipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -299,7 +299,10 @@ def get_cached_module_file(
# Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
if subfolder is not None:
module_file_or_url = os.path.join(pretrained_model_name_or_path, subfolder, module_file)
else:
module_file_or_url = os.path.join(pretrained_model_name_or_path, module_file)
if os.path.isfile(module_file_or_url):
resolved_module_file = module_file_or_url
@@ -384,7 +387,11 @@ def get_cached_module_file(
if not os.path.exists(submodule_path / module_folder):
os.makedirs(submodule_path / module_folder)
module_needed = f"{module_needed}.py"
shutil.copyfile(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
if subfolder is not None:
source_path = os.path.join(pretrained_model_name_or_path, subfolder, module_needed)
else:
source_path = os.path.join(pretrained_model_name_or_path, module_needed)
shutil.copyfile(source_path, submodule_path / module_needed)
else:
# Get the commit hash
# TODO: we will get this info in the etag soon, so retrieve it from there and not here.

View File

@@ -107,6 +107,7 @@ def load_or_create_model_card(
widget: list[dict] | None = None,
inference: bool | None = None,
is_modular: bool = False,
update_model_card: bool = False,
) -> ModelCard:
"""
Loads or creates a model card.
@@ -133,6 +134,9 @@ def load_or_create_model_card(
`load_or_create_model_card` from a training script.
is_modular: (`bool`, optional): Boolean flag to denote if the model card is for a modular pipeline.
When True, uses model_description as-is without additional template formatting.
update_model_card: (`bool`, optional): When True, regenerates the model card content even if one
already exists on the remote repo. Existing card metadata (tags, license, etc.) is preserved. Only
supported for modular pipelines (i.e., `is_modular=True`).
"""
if not is_jinja_available():
raise ValueError(
@@ -141,9 +145,17 @@ def load_or_create_model_card(
" To install it, please run `pip install Jinja2`."
)
if update_model_card and not is_modular:
raise ValueError("`update_model_card=True` is only supported for modular pipelines (`is_modular=True`).")
try:
# Check if the model card is present on the remote repo
model_card = ModelCard.load(repo_id_or_path, token=token)
# For modular pipelines, regenerate card content when requested (preserve existing metadata)
if update_model_card and is_modular and model_description is not None:
existing_data = model_card.data
model_card = ModelCard(model_description)
model_card.data = existing_data
except (EntryNotFoundError, RepositoryNotFoundError):
# Otherwise create a model card from template
if from_training:

View File

@@ -566,3 +566,127 @@ class GroupOffloadTests(unittest.TestCase):
"layers_per_block": 1,
}
return init_dict
# Model with conditionally-executed modules, simulating Helios patch_short/patch_mid/patch_long behavior.
# These modules are only called when optional inputs are provided, which means the lazy prefetch
# execution order tracer may not see them on the first forward pass. This can cause a device mismatch
# on subsequent calls when the modules ARE invoked but their weights were never onloaded.
# See: https://github.com/huggingface/diffusers/pull/13211
class DummyModelWithConditionalModules(ModelMixin):
    def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None:
        super().__init__()
        # Unconditional trunk: projection -> ReLU -> stacked blocks -> output projection.
        # Attribute names are part of the state_dict contract and must not change.
        self.linear_1 = torch.nn.Linear(in_features, hidden_features)
        self.activation = torch.nn.ReLU()
        self.blocks = torch.nn.ModuleList(
            [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)]
        )
        self.linear_2 = torch.nn.Linear(hidden_features, out_features)
        # These modules are only invoked when optional_input is not None.
        # Output dimension matches hidden_features so they can be added after linear_1.
        self.optional_proj_1 = torch.nn.Linear(in_features, hidden_features)
        self.optional_proj_2 = torch.nn.Linear(in_features, hidden_features)

    def forward(self, x: torch.Tensor, optional_input: torch.Tensor | None = None) -> torch.Tensor:
        hidden = self.activation(self.linear_1(x))
        if optional_input is not None:
            # Both projections map in_features -> hidden_features, so the residual
            # additions after linear_1 are shape-compatible.
            hidden = hidden + self.optional_proj_1(optional_input)
            hidden = hidden + self.optional_proj_2(optional_input)
        for block in self.blocks:
            hidden = block(hidden)
        return self.linear_2(hidden)
class ConditionalModuleGroupOffloadTests(GroupOffloadTests):
    """Tests for conditionally-executed modules under group offloading with streams.

    Regression tests for the case where a module is not executed during the first forward pass
    (when the lazy prefetch execution order is traced), but IS executed on subsequent passes.
    Without the fix, the weights of such modules remain on CPU while the input is on GPU,
    causing a RuntimeError about tensor device mismatch.
    """

    def get_model(self):
        # Seed before construction so `model` and `model_ref` start from identical weights.
        torch.manual_seed(0)
        return DummyModelWithConditionalModules(
            in_features=self.in_features,
            hidden_features=self.hidden_features,
            out_features=self.out_features,
            num_layers=self.num_layers,
        )

    @parameterized.expand([("leaf_level",), ("block_level",)])
    @unittest.skipIf(
        torch.device(torch_device).type not in ["cuda", "xpu"],
        "Test requires a CUDA or XPU device.",
    )
    def test_conditional_modules_with_stream(self, offload_type: str):
        """Regression test: conditionally-executed modules must not cause device mismatch when using streams.

        The model contains two optional Linear layers (optional_proj_1, optional_proj_2) that are only
        executed when `optional_input` is provided. This simulates modules like patch_short/patch_mid/
        patch_long in HeliosTransformer3DModel, which are only called when history latents are present.

        When using streams, `LazyPrefetchGroupOffloadingHook` traces the execution order on the first
        forward pass and sets up a prefetch chain so each module pre-loads the next one's weights.
        Modules not executed during this tracing pass are excluded from the prefetch chain.

        The bug: if a module was absent from the first (tracing) pass, its `onload_self` flag gets set
        to False (meaning "someone else will onload me"). But since it's not in the prefetch chain,
        nobody ever does — so its weights remain on CPU. When the module is eventually called in a
        subsequent pass, the input is on GPU but the weights are on CPU, causing a RuntimeError.

        We therefore must invoke the model multiple times:
        1. First pass WITHOUT optional_input: triggers the lazy prefetch tracing. optional_proj_1/2
           are absent, so they are excluded from the prefetch chain.
        2. Second pass WITH optional_input: the regression case. Without the fix, this raises a
           RuntimeError because optional_proj_1/2 weights are still on CPU.
        3. Third pass WITHOUT optional_input: verifies the model remains stable after having seen
           both code paths.
        """
        model = self.get_model()
        # Reference model runs fully on the accelerator (no offloading) for output comparison.
        model_ref = self.get_model()
        model_ref.load_state_dict(model.state_dict(), strict=True)
        model_ref.to(torch_device)
        model.enable_group_offload(
            torch_device,
            offload_type=offload_type,
            num_blocks_per_group=1,
            use_stream=True,
        )
        x = torch.randn(4, self.in_features).to(torch_device)
        optional_input = torch.randn(4, self.in_features).to(torch_device)
        with torch.no_grad():
            # First forward pass WITHOUT optional_input — this is when the lazy prefetch
            # execution order is traced. optional_proj_1/2 are NOT in the traced order.
            out_ref_no_opt = model_ref(x, optional_input=None)
            out_no_opt = model(x, optional_input=None)
            self.assertTrue(
                torch.allclose(out_ref_no_opt, out_no_opt, atol=1e-5),
                f"[{offload_type}] Outputs do not match on first pass (no optional_input).",
            )
            # Second forward pass WITH optional_input — optional_proj_1/2 ARE now called.
            out_ref_with_opt = model_ref(x, optional_input=optional_input)
            out_with_opt = model(x, optional_input=optional_input)
            self.assertTrue(
                torch.allclose(out_ref_with_opt, out_with_opt, atol=1e-5),
                f"[{offload_type}] Outputs do not match on second pass (with optional_input).",
            )
            # Third pass again without optional_input — verify stable behavior.
            out_ref_no_opt2 = model_ref(x, optional_input=None)
            out_no_opt2 = model(x, optional_input=None)
            self.assertTrue(
                torch.allclose(out_ref_no_opt2, out_no_opt2, atol=1e-5),
                f"[{offload_type}] Outputs do not match on third pass (back to no optional_input).",
            )

View File

@@ -0,0 +1,120 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import torch
from transformers import AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, FlowMatchEulerDiscreteScheduler, HeliosPipeline, HeliosTransformer3DModel
from ..testing_utils import floats_tensor, require_peft_backend, skip_mps
sys.path.append(".")
from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
class HeliosLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
    """LoRA load/fuse/unfuse tests for `HeliosPipeline`.

    All test logic lives in `PeftLoraLoaderMixinTests`; this class only supplies the
    tiny component configs used to build a fast dummy pipeline.
    """

    pipeline_class = HeliosPipeline
    scheduler_cls = FlowMatchEulerDiscreteScheduler
    scheduler_kwargs = {}
    # Deliberately tiny transformer so each test runs in well under a second.
    transformer_kwargs = {
        "patch_size": (1, 2, 2),
        "num_attention_heads": 2,
        "attention_head_dim": 12,
        "in_channels": 16,
        "out_channels": 16,
        "text_dim": 32,
        "freq_dim": 256,
        "ffn_dim": 32,
        "num_layers": 2,
        "cross_attn_norm": True,
        "qk_norm": "rms_norm_across_heads",
        "rope_dim": (4, 4, 4),
        "has_multi_term_memory_patch": True,
        "guidance_cross_attn": True,
        "zero_history_timestep": True,
        "is_amplify_history": False,
    }
    transformer_cls = HeliosTransformer3DModel
    vae_kwargs = {
        "base_dim": 3,
        "z_dim": 16,
        "dim_mult": [1, 1, 1, 1],
        "num_res_blocks": 1,
        "temperal_downsample": [False, True, True],
    }
    vae_cls = AutoencoderKLWan
    has_two_text_encoders = True
    tokenizer_cls, tokenizer_id = AutoTokenizer, "hf-internal-testing/tiny-random-t5"
    text_encoder_cls, text_encoder_id = T5EncoderModel, "hf-internal-testing/tiny-random-t5"
    text_encoder_target_modules = ["q", "k", "v", "o"]
    # Text-encoder LoRAs are not exercised for this pipeline — denoiser-only LoRA tests run.
    supports_text_encoder_loras = False

    @property
    def output_shape(self):
        # (batch, frames, height, width, channels) of the decoded video in "np" output mode.
        return (1, 33, 32, 32, 3)

    def get_dummy_inputs(self, with_generator=True):
        """Return (noise latents, token ids, pipeline call kwargs) for the mixin's test harness."""
        batch_size = 1
        sequence_length = 16
        num_channels = 4
        num_frames = 9
        num_latent_frames = 3  # (num_frames - 1) // temporal_compression_ratio + 1
        sizes = (4, 4)
        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_latent_frames, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
        pipeline_inputs = {
            "prompt": "",
            "num_frames": num_frames,
            "num_inference_steps": 1,
            "guidance_scale": 6.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": sequence_length,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})
        return noise, input_ids, pipeline_inputs

    def test_simple_inference_with_text_lora_denoiser_fused_multi(self):
        # Looser atol: fusing/unfusing tiny fp32 layers accumulates small numeric drift.
        super().test_simple_inference_with_text_lora_denoiser_fused_multi(expected_atol=9e-3)

    def test_simple_inference_with_text_denoiser_lora_unfused(self):
        super().test_simple_inference_with_text_denoiser_lora_unfused(expected_atol=9e-3)

    @unittest.skip("Not supported in Helios.")
    def test_simple_inference_with_text_denoiser_block_scale(self):
        pass

    @unittest.skip("Not supported in Helios.")
    def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
        pass

    @unittest.skip("Not supported in Helios.")
    def test_modify_padding_mode(self):
        pass

View File

@@ -0,0 +1,300 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import pytest
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import to_tensor
import diffusers.models.autoencoders.autoencoder_rae as _rae_module
from diffusers.models.autoencoders.autoencoder_rae import (
_ENCODER_FORWARD_FNS,
AutoencoderRAE,
_build_encoder,
)
from diffusers.utils import load_image
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
slow,
torch_all_close,
torch_device,
)
from ..testing_utils import BaseModelTesterConfig, ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
# ---------------------------------------------------------------------------
# Tiny test encoder for fast unit tests (no transformers dependency)
# ---------------------------------------------------------------------------
class _TinyTestEncoderModule(torch.nn.Module):
"""Minimal encoder that mimics the patch-token interface without any HF model."""
def __init__(self, hidden_size: int = 16, patch_size: int = 8, **kwargs):
super().__init__()
self.patch_size = patch_size
self.hidden_size = hidden_size
def forward(self, images: torch.Tensor) -> torch.Tensor:
pooled = F.avg_pool2d(images.mean(dim=1, keepdim=True), kernel_size=self.patch_size, stride=self.patch_size)
tokens = pooled.flatten(2).transpose(1, 2).contiguous()
return tokens.repeat(1, 1, self.hidden_size)
def _tiny_test_encoder_forward(model, images):
return model(images)
def _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    """Construct the tiny stand-in encoder.

    `encoder_type` and `num_hidden_layers` are accepted only to match the real
    `_build_encoder` signature; the tiny module ignores them.
    """
    encoder = _TinyTestEncoderModule(hidden_size=hidden_size, patch_size=patch_size)
    return encoder
# Monkey-patch the dispatch tables so "tiny_test" is recognised by AutoencoderRAE
_ENCODER_FORWARD_FNS["tiny_test"] = _tiny_test_encoder_forward
_original_build_encoder = _build_encoder


def _patched_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers):
    """Route "tiny_test" to the local stand-in builder; defer to the real builder otherwise."""
    if encoder_type != "tiny_test":
        return _original_build_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)
    return _build_tiny_test_encoder(encoder_type, hidden_size, patch_size, num_hidden_layers)


_rae_module._build_encoder = _patched_build_encoder
# ---------------------------------------------------------------------------
# Test config
# ---------------------------------------------------------------------------
class AutoencoderRAETesterConfig(BaseModelTesterConfig):
    """Shared config for AutoencoderRAE test classes: tiny model dims + dummy inputs."""

    @property
    def model_class(self):
        return AutoencoderRAE

    @property
    def output_shape(self):
        # Decoded image shape per sample: (channels, image_size, image_size).
        return (3, 16, 16)

    def get_init_dict(self):
        # Uses the "tiny_test" encoder registered above, so no transformers weights are needed.
        return {
            "encoder_type": "tiny_test",
            "encoder_hidden_size": 16,
            "encoder_patch_size": 8,
            "encoder_input_size": 32,
            "patch_size": 4,
            "image_size": 16,
            "decoder_hidden_size": 32,
            "decoder_num_hidden_layers": 1,
            "decoder_num_attention_heads": 4,
            "decoder_intermediate_size": 64,
            "num_channels": 3,
            "encoder_norm_mean": [0.5, 0.5, 0.5],
            "encoder_norm_std": [0.5, 0.5, 0.5],
            "noise_tau": 0.0,
            "reshape_to_2d": True,
            "scaling_factor": 1.0,
        }

    @property
    def generator(self):
        # Fresh, fixed-seed generator per access so inputs are reproducible.
        return torch.Generator("cpu").manual_seed(0)

    def get_dummy_inputs(self):
        return {"sample": torch.randn(2, 3, 32, 32, generator=self.generator, device="cpu").to(torch_device)}

    # Bridge for AutoencoderTesterMixin which still uses the old interface
    def prepare_init_args_and_inputs_for_common(self):
        return self.get_init_dict(), self.get_dummy_inputs()

    def _make_model(self, **overrides) -> AutoencoderRAE:
        """Build a model on `torch_device` from the default config, with keyword overrides applied."""
        config = self.get_init_dict()
        config.update(overrides)
        return AutoencoderRAE(**config).to(torch_device)
class TestAutoEncoderRAE(AutoencoderRAETesterConfig, ModelTesterMixin):
    """Core model tests for AutoencoderRAE."""

    @pytest.mark.skip(reason="AutoencoderRAE does not support torch dynamo yet")
    def test_from_save_pretrained_dynamo(self): ...

    def test_fast_encode_decode_and_forward_shapes(self):
        """encode/decode/forward produce the expected latent and image shapes with finite values."""
        model = self._make_model().eval()
        x = torch.rand(2, 3, 32, 32, device=torch_device)
        with torch.no_grad():
            z = model.encode(x).latent
            decoded = model.decode(z).sample
            recon = model(x).sample
        # Latent: (batch, encoder_hidden_size, image_size/patch_size, image_size/patch_size).
        assert z.shape == (2, 16, 4, 4)
        assert decoded.shape == (2, 3, 16, 16)
        assert recon.shape == (2, 3, 16, 16)
        assert torch.isfinite(recon).all().item()

    def test_fast_scaling_factor_encode_and_decode_consistency(self):
        """Doubling scaling_factor doubles latents but leaves the decoded output unchanged."""
        # Same seed before each construction so both models share identical weights.
        torch.manual_seed(0)
        model_base = self._make_model(scaling_factor=1.0).eval()
        torch.manual_seed(0)
        model_scaled = self._make_model(scaling_factor=2.0).eval()
        x = torch.rand(2, 3, 32, 32, device=torch_device)
        with torch.no_grad():
            z_base = model_base.encode(x).latent
            z_scaled = model_scaled.encode(x).latent
            recon_base = model_base.decode(z_base).sample
            recon_scaled = model_scaled.decode(z_scaled).sample
        assert torch.allclose(z_scaled, z_base * 2.0, atol=1e-5, rtol=1e-4)
        assert torch.allclose(recon_scaled, recon_base, atol=1e-5, rtol=1e-4)

    def test_fast_latents_normalization_matches_formula(self):
        """Latents are normalized as (z - mean) / (std + 1e-5) when mean/std are configured."""
        latents_mean = torch.full((1, 16, 1, 1), 0.25, dtype=torch.float32)
        latents_std = torch.full((1, 16, 1, 1), 2.0, dtype=torch.float32)
        model_raw = self._make_model().eval()
        model_norm = self._make_model(latents_mean=latents_mean, latents_std=latents_std).eval()
        x = torch.rand(1, 3, 32, 32, device=torch_device)
        with torch.no_grad():
            z_raw = model_raw.encode(x).latent
            z_norm = model_norm.encode(x).latent
        expected = (z_raw - latents_mean.to(z_raw.device, z_raw.dtype)) / (
            latents_std.to(z_raw.device, z_raw.dtype) + 1e-5
        )
        assert torch.allclose(z_norm, expected, atol=1e-5, rtol=1e-4)

    def test_fast_slicing_matches_non_slicing(self):
        """Per-sample sliced encode/decode must match the batched (non-sliced) path."""
        model = self._make_model().eval()
        x = torch.rand(3, 3, 32, 32, device=torch_device)
        with torch.no_grad():
            model.use_slicing = False
            z_no_slice = model.encode(x).latent
            out_no_slice = model.decode(z_no_slice).sample
            model.use_slicing = True
            z_slice = model.encode(x).latent
            out_slice = model.decode(z_slice).sample
        assert torch.allclose(z_slice, z_no_slice, atol=1e-6, rtol=1e-5)
        assert torch.allclose(out_slice, out_no_slice, atol=1e-6, rtol=1e-5)

    def test_fast_noise_tau_applies_only_in_train(self):
        """noise_tau adds randomness to latents in train mode only; eval is deterministic."""
        model = self._make_model(noise_tau=0.5).to(torch_device)
        x = torch.rand(2, 3, 32, 32, device=torch_device)
        model.train()
        torch.manual_seed(0)
        z_train_1 = model.encode(x).latent
        torch.manual_seed(1)
        z_train_2 = model.encode(x).latent
        model.eval()
        torch.manual_seed(0)
        z_eval_1 = model.encode(x).latent
        torch.manual_seed(1)
        z_eval_2 = model.encode(x).latent
        assert z_train_1.shape == z_eval_1.shape
        # Different seeds -> different noise in train mode...
        assert not torch.allclose(z_train_1, z_train_2)
        # ...but identical results in eval mode, regardless of the seed.
        assert torch.allclose(z_eval_1, z_eval_2, atol=1e-6, rtol=1e-5)
class TestAutoEncoderRAESlicingTiling(AutoencoderRAETesterConfig, AutoencoderTesterMixin):
    """Slicing and tiling tests for AutoencoderRAE."""
    # All test methods come from AutoencoderTesterMixin; the config class supplies
    # prepare_init_args_and_inputs_for_common for the mixin's older interface.
@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEEncoderIntegrationTests:
    """Shape checks for the real pretrained encoder backbones (dinov2 / siglip2 / mae)."""

    def teardown_method(self):
        # Free model memory between tests.
        gc.collect()
        backend_empty_cache(torch_device)

    def test_dinov2_encoder_forward_shape(self):
        encoder = _build_encoder("dinov2", hidden_size=768, patch_size=14, num_hidden_layers=12).to(torch_device)
        x = torch.rand(1, 3, 224, 224, device=torch_device)
        y = _ENCODER_FORWARD_FNS["dinov2"](encoder, x)
        assert y.ndim == 3
        assert y.shape[0] == 1
        # (224/14)^2 = 256 patch tokens; the forward fn presumably strips the CLS and
        # register tokens, leaving only patch tokens — TODO confirm against the forward fn.
        assert y.shape[1] == 256
        assert y.shape[2] == 768

    def test_siglip2_encoder_forward_shape(self):
        encoder = _build_encoder("siglip2", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
        x = torch.rand(1, 3, 224, 224, device=torch_device)
        y = _ENCODER_FORWARD_FNS["siglip2"](encoder, x)
        assert y.ndim == 3
        assert y.shape[0] == 1
        assert y.shape[1] == 196  # (224/16)^2
        assert y.shape[2] == 768

    def test_mae_encoder_forward_shape(self):
        encoder = _build_encoder("mae", hidden_size=768, patch_size=16, num_hidden_layers=12).to(torch_device)
        x = torch.rand(1, 3, 224, 224, device=torch_device)
        # MAE's forward dispatch additionally needs the patch size.
        y = _ENCODER_FORWARD_FNS["mae"](encoder, x, patch_size=16)
        assert y.ndim == 3
        assert y.shape[0] == 1
        assert y.shape[1] == 196  # (224/16)^2
        assert y.shape[2] == 768
@slow
@pytest.mark.skip(reason="Not enough model usage to justify slow tests yet.")
class AutoencoderRAEIntegrationTests:
    """End-to-end checks against a published RAE checkpoint."""

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def test_autoencoder_rae_from_pretrained_dinov2(self):
        model = AutoencoderRAE.from_pretrained("nyu-visionx/RAE-dinov2-wReg-base-ViTXL-n08").to(torch_device)
        model.eval()
        image = load_image(
            "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cat.png"
        )
        image = image.convert("RGB").resize((224, 224))
        x = to_tensor(image).unsqueeze(0).to(torch_device)
        with torch.no_grad():
            latents = model.encode(x).latent
            assert latents.shape == (1, 768, 16, 16)
            recon = model.decode(latents).sample
        assert recon.shape == (1, 3, 256, 256)
        assert torch.isfinite(recon).all().item()
        # Reference slices captured from a known-good run of this checkpoint.
        # fmt: off
        expected_latent_slice = torch.tensor([0.7617, 0.8824, -0.4891])
        expected_recon_slice = torch.tensor([0.1263, 0.1355, 0.1435])
        # fmt: on
        assert torch_all_close(latents[0, :3, 0, 0].float().cpu(), expected_latent_slice, atol=1e-3)
        assert torch_all_close(recon[0, 0, 0, :3].float().cpu(), expected_recon_slice, atol=1e-3)

View File

@@ -1,9 +1,15 @@
import json
import os
import tempfile
import unittest
from unittest.mock import MagicMock, patch
import torch
from transformers import CLIPTextModel, LongformerModel
from diffusers import ConfigMixin
from diffusers.models import AutoModel, UNet2DConditionModel
from diffusers.models.modeling_utils import ModelMixin
class TestAutoModel(unittest.TestCase):
@@ -35,6 +41,45 @@ class TestAutoModel(unittest.TestCase):
)
assert isinstance(model, CLIPTextModel)
    def test_load_dynamic_module_from_local_path_with_subfolder(self):
        """AutoModel loads remote-code models from a local dir when the module file lives in a subfolder."""
        # Minimal custom model source; the config's `auto_map` points at it.
        CUSTOM_MODEL_CODE = (
            "import torch\n"
            "from diffusers import ModelMixin, ConfigMixin\n"
            "from diffusers.configuration_utils import register_to_config\n"
            "\n"
            "class CustomModel(ModelMixin, ConfigMixin):\n"
            "    @register_to_config\n"
            "    def __init__(self, hidden_size=8):\n"
            "        super().__init__()\n"
            "        self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
            "\n"
            "    def forward(self, x):\n"
            "        return self.linear(x)\n"
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            # Layout: tmpdir/custom_model/{modeling.py, config.json, diffusion_pytorch_model.bin}
            subfolder = "custom_model"
            model_dir = os.path.join(tmpdir, subfolder)
            os.makedirs(model_dir)
            with open(os.path.join(model_dir, "modeling.py"), "w") as f:
                f.write(CUSTOM_MODEL_CODE)
            config = {
                "_class_name": "CustomModel",
                "_diffusers_version": "0.0.0",
                "auto_map": {"AutoModel": "modeling.CustomModel"},
                "hidden_size": 8,
            }
            with open(os.path.join(model_dir, "config.json"), "w") as f:
                json.dump(config, f)
            # Empty state dict: weights are initialized by the model itself.
            torch.save({}, os.path.join(model_dir, "diffusion_pytorch_model.bin"))
            model = AutoModel.from_pretrained(tmpdir, subfolder=subfolder, trust_remote_code=True)
        assert model.__class__.__name__ == "CustomModel"
        assert model.config["hidden_size"] == 8
class TestAutoModelFromConfig(unittest.TestCase):
@patch(
@@ -100,3 +145,51 @@ class TestAutoModelFromConfig(unittest.TestCase):
def test_from_config_raises_on_none(self):
with self.assertRaises(ValueError, msg="Please provide a `pretrained_model_name_or_path_or_dict`"):
AutoModel.from_config(None)
class TestRegisterForAutoClass(unittest.TestCase):
    """Tests for `ModelMixin.register_for_auto_class` and the resulting `auto_map` config entry."""

    @staticmethod
    def _save_and_read_config(model):
        """Save the model's config into a temp dir and return it parsed as a dict."""
        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_config(tmp_dir)
            with open(os.path.join(tmp_dir, "config.json"), "r") as f:
                return json.load(f)

    def test_register_for_auto_class_sets_attribute(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        DummyModel.register_for_auto_class("AutoModel")
        self.assertEqual(DummyModel._auto_class, "AutoModel")

    def test_register_for_auto_class_rejects_unsupported(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        with self.assertRaises(ValueError, msg="Only 'AutoModel' is supported"):
            DummyModel.register_for_auto_class("AutoPipeline")

    def test_auto_map_in_saved_config(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        DummyModel.register_for_auto_class("AutoModel")
        config = self._save_and_read_config(DummyModel())
        self.assertIn("auto_map", config)
        self.assertIn("AutoModel", config["auto_map"])
        # The auto_map entry must point to "<module>.<ClassName>" for dynamic loading.
        module_name = DummyModel.__module__.split(".")[-1]
        self.assertEqual(config["auto_map"]["AutoModel"], f"{module_name}.DummyModel")

    def test_no_auto_map_without_register(self):
        class DummyModel(ModelMixin, ConfigMixin):
            config_name = "config.json"

        config = self._save_and_read_config(DummyModel())
        self.assertNotIn("auto_map", config)

View File

@@ -0,0 +1,168 @@
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import torch
from diffusers import HeliosTransformer3DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
MemoryTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class HeliosTransformer3DTesterConfig(BaseModelTesterConfig):
    """Shared config for HeliosTransformer3DModel test classes: tiny dims + dummy inputs."""

    @property
    def model_class(self):
        return HeliosTransformer3DModel

    @property
    def pretrained_model_name_or_path(self):
        return "hf-internal-testing/tiny-helios-base-transformer"

    @property
    def output_shape(self) -> tuple[int, ...]:
        # (channels, frames, height, width) per sample; input and output shapes match.
        return (4, 2, 16, 16)

    @property
    def input_shape(self) -> tuple[int, ...]:
        return (4, 2, 16, 16)

    @property
    def main_input_name(self) -> str:
        return "hidden_states"

    @property
    def generator(self):
        # Fresh, fixed-seed generator per access for reproducible dummy inputs.
        return torch.Generator("cpu").manual_seed(0)

    def get_init_dict(self) -> dict[str, int | list[int] | tuple | str | bool]:
        return {
            "patch_size": (1, 2, 2),
            "num_attention_heads": 2,
            "attention_head_dim": 12,
            "in_channels": 4,
            "out_channels": 4,
            "text_dim": 16,
            "freq_dim": 256,
            "ffn_dim": 32,
            "num_layers": 2,
            "cross_attn_norm": True,
            "qk_norm": "rms_norm_across_heads",
            "rope_dim": (4, 4, 4),
            "has_multi_term_memory_patch": True,
            "guidance_cross_attn": True,
            "zero_history_timestep": True,
            "is_amplify_history": False,
        }

    def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
        """Build latents, timestep, text embeddings, plus the short/mid/long history
        latents and their frame-index tensors that the multi-term memory patch consumes."""
        batch_size = 1
        num_channels = 4
        num_frames = 2
        height = 16
        width = 16
        text_encoder_embedding_dim = 16
        sequence_length = 12
        hidden_states = randn_tensor(
            (batch_size, num_channels, num_frames, height, width),
            generator=self.generator,
            device=torch_device,
        )
        timestep = torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device)
        encoder_hidden_states = randn_tensor(
            (batch_size, sequence_length, text_encoder_embedding_dim),
            generator=self.generator,
            device=torch_device,
        )
        # Index tensors: one entry per (history) frame; all ones here for simplicity.
        indices_hidden_states = torch.ones((batch_size, num_frames)).to(torch_device)
        indices_latents_history_short = torch.ones((batch_size, num_frames - 1)).to(torch_device)
        indices_latents_history_mid = torch.ones((batch_size, num_frames - 1)).to(torch_device)
        # Long-term history spans 4x as many frames as the short/mid buffers.
        indices_latents_history_long = torch.ones((batch_size, (num_frames - 1) * 4)).to(torch_device)
        latents_history_short = randn_tensor(
            (batch_size, num_channels, num_frames - 1, height, width),
            generator=self.generator,
            device=torch_device,
        )
        latents_history_mid = randn_tensor(
            (batch_size, num_channels, num_frames - 1, height, width),
            generator=self.generator,
            device=torch_device,
        )
        latents_history_long = randn_tensor(
            (batch_size, num_channels, (num_frames - 1) * 4, height, width),
            generator=self.generator,
            device=torch_device,
        )
        return {
            "hidden_states": hidden_states,
            "timestep": timestep,
            "encoder_hidden_states": encoder_hidden_states,
            "indices_hidden_states": indices_hidden_states,
            "indices_latents_history_short": indices_latents_history_short,
            "indices_latents_history_mid": indices_latents_history_mid,
            "indices_latents_history_long": indices_latents_history_long,
            "latents_history_short": latents_history_short,
            "latents_history_mid": latents_history_mid,
            "latents_history_long": latents_history_long,
        }
class TestHeliosTransformer3D(HeliosTransformer3DTesterConfig, ModelTesterMixin):
    """Core model tests for Helios Transformer 3D."""

    @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16], ids=["fp16", "bf16"])
    def test_from_save_pretrained_dtype_inference(self, tmp_path, dtype):
        """Overridden only to skip the mixin's low-precision round-trip check."""
        # Skip: fp16/bf16 require very high atol to pass, providing little signal.
        # Dtype preservation is already tested by test_from_save_pretrained_dtype and test_keep_in_fp32_modules.
        pytest.skip("Tolerance requirements too high for meaningful test")
class TestHeliosTransformer3DMemory(HeliosTransformer3DTesterConfig, MemoryTesterMixin):
    """Memory optimization tests for Helios Transformer 3D."""
    # All test methods are inherited from MemoryTesterMixin.
class TestHeliosTransformer3DTraining(HeliosTransformer3DTesterConfig, TrainingTesterMixin):
    """Training tests for Helios Transformer 3D."""

    def test_gradient_checkpointing_is_applied(self):
        # Only the top-level model class is expected to receive gradient checkpointing.
        expected_set = {"HeliosTransformer3DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestHeliosTransformer3DAttention(HeliosTransformer3DTesterConfig, AttentionTesterMixin):
    """Attention processor tests for Helios Transformer 3D."""
    # All test methods are inherited from AttentionTesterMixin.
class TestHeliosTransformer3DCompile(HeliosTransformer3DTesterConfig, TorchCompileTesterMixin):
    """Torch compile tests for Helios Transformer 3D."""

    @pytest.mark.xfail(
        reason="Helios DiT does not compile when deterministic algorithms are used due to https://github.com/pytorch/pytorch/issues/170079"
    )
    def test_torch_compile_recompilation_and_graph_break(self):
        super().test_torch_compile_recompilation_and_graph_break()

View File

@@ -10,6 +10,11 @@ import torch
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -19,7 +24,13 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
)
from diffusers.utils import logging
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, require_accelerator, torch_device
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
torch_device,
)
class ModularPipelineTesterMixin:
@@ -429,6 +440,117 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
    def get_dummy_block_pipe(self):
        """Build a two-block sequential pipeline whose blocks declare pip-style requirements."""

        class DummyBlockOne:
            # keep two arbitrary deps so that we can test warnings.
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            # keep two dependencies that will be available during testing.
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        pipe = SequentialPipelineBlocks.from_blocks_dict(
            {"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
        )
        return pipe
    def get_dummy_conditional_block_pipe(self):
        """Build a ConditionalPipelineBlocks variant with the same requirement declarations."""

        class DummyBlockOne:
            _requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}

        class DummyBlockTwo:
            _requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}

        class DummyConditionalBlocks(ConditionalPipelineBlocks):
            block_classes = [DummyBlockOne, DummyBlockTwo]
            block_names = ["block_one", "block_two"]
            block_trigger_inputs = []

            def select_block(self, **kwargs):
                # Always route to block_one; selection logic is irrelevant to requirement tests.
                return "block_one"

        return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(tmp_path)
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(tmp_path)
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(tmp_path)
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(tmp_path)
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:
@@ -483,8 +605,7 @@ class TestModularModelCardContent:
"blocks_description",
"components_description",
"configs_section",
"inputs_description",
"outputs_description",
"io_specification_section",
"trigger_inputs_section",
"tags",
]
@@ -581,18 +702,19 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(inputs=inputs)
content = generate_modular_model_card_content(blocks)
assert "**Required:**" in content["inputs_description"]
assert "**Optional:**" in content["inputs_description"]
assert "prompt" in content["inputs_description"]
assert "num_steps" in content["inputs_description"]
assert "default: `50`" in content["inputs_description"]
io_section = content["io_specification_section"]
assert "**Inputs:**" in io_section
assert "prompt" in io_section
assert "num_steps" in io_section
assert "*optional*" in io_section
assert "defaults to `50`" in io_section
def test_inputs_description_empty(self):
"""Test handling of pipelines without specific inputs."""
blocks = self.create_mock_blocks(inputs=[])
content = generate_modular_model_card_content(blocks)
assert "No specific inputs defined" in content["inputs_description"]
assert "No specific inputs defined" in content["io_specification_section"]
def test_outputs_description_formatting(self):
"""Test that outputs are correctly formatted."""
@@ -602,15 +724,16 @@ class TestModularModelCardContent:
blocks = self.create_mock_blocks(outputs=outputs)
content = generate_modular_model_card_content(blocks)
assert "images" in content["outputs_description"]
assert "Generated images" in content["outputs_description"]
io_section = content["io_specification_section"]
assert "images" in io_section
assert "Generated images" in io_section
def test_outputs_description_empty(self):
"""Test handling of pipelines without specific outputs."""
blocks = self.create_mock_blocks(outputs=[])
content = generate_modular_model_card_content(blocks)
assert "Standard pipeline outputs" in content["outputs_description"]
assert "Standard pipeline outputs" in content["io_specification_section"]
def test_trigger_inputs_section_with_triggers(self):
"""Test that trigger inputs section is generated when present."""
@@ -628,35 +751,6 @@ class TestModularModelCardContent:
assert content["trigger_inputs_section"] == ""
    def test_blocks_description_with_sub_blocks(self):
        """Test that blocks with sub-blocks are correctly described."""

        class MockBlockWithSubBlocks:
            def __init__(self):
                # Rebinding __class__.__name__ gives this mock a stable display
                # name without declaring a separately named class.
                self.__class__.__name__ = "ParentBlock"
                self.description = "Parent block"
                self.sub_blocks = {
                    "child1": self.create_child_block("ChildBlock1", "Child 1 description"),
                    "child2": self.create_child_block("ChildBlock2", "Child 2 description"),
                }

            def create_child_block(self, name, desc):
                # A fresh ChildBlock class is created per call, so renaming it
                # does not leak between children.
                class ChildBlock:
                    def __init__(self):
                        self.__class__.__name__ = name
                        self.description = desc

                return ChildBlock()

        blocks = self.create_mock_blocks()
        blocks.sub_blocks["parent"] = MockBlockWithSubBlocks()
        content = generate_modular_model_card_content(blocks)
        # The parent entry and both nested children must appear in the description.
        assert "parent" in content["blocks_description"]
        assert "child1" in content["blocks_description"]
        assert "child2" in content["blocks_description"]
def test_model_description_includes_block_count(self):
"""Test that model description includes the number of blocks."""
blocks = self.create_mock_blocks(num_blocks=5)
@@ -715,6 +809,18 @@ class TestLoadComponentsSkipBehavior:
assert pipe.unet is not None
assert getattr(pipe, "vae", None) is None
def test_load_components_selective_loading_incremental(self):
"""Loading a subset of components should not affect already-loaded components."""
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
pipe.load_components(names="unet", torch_dtype=torch.float32)
pipe.load_components(names="text_encoder", torch_dtype=torch.float32)
assert hasattr(pipe, "unet")
assert pipe.unet is not None
assert hasattr(pipe, "text_encoder")
assert pipe.text_encoder is not None
def test_load_components_skips_invalid_pretrained_path(self):
pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
@@ -730,6 +836,112 @@ class TestLoadComponentsSkipBehavior:
assert not hasattr(pipe, "test_component") or pipe.test_component is None
class TestCustomModelSavePretrained:
    """Tests that `save_pretrained` rewrites `modular_model_index.json` correctly for
    components that have no hub provenance (custom/local models)."""

    def test_save_pretrained_updates_index_for_local_model(self, tmp_path):
        """When a component without _diffusers_load_id (custom/local model) is saved,
        modular_model_index.json should point to the save directory."""
        import json

        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
        pipe.load_components(torch_dtype=torch.float32)
        # Simulate a component that has no usable hub load id.
        pipe.unet._diffusers_load_id = "null"
        save_dir = str(tmp_path / "my-pipeline")
        pipe.save_pretrained(save_dir)
        with open(os.path.join(save_dir, "modular_model_index.json")) as f:
            index = json.load(f)
        _library, _cls, unet_spec = index["unet"]
        # The locally-saved unet must resolve to the save directory...
        assert unet_spec["pretrained_model_name_or_path"] == save_dir
        assert unet_spec["subfolder"] == "unet"
        _library, _cls, vae_spec = index["vae"]
        # ...while hub-loaded components keep their original repo reference.
        assert vae_spec["pretrained_model_name_or_path"] == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

    def test_save_pretrained_roundtrip_with_local_model(self, tmp_path):
        """A pipeline with a custom/local model should be saveable and re-loadable with identical outputs."""
        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
        pipe.load_components(torch_dtype=torch.float32)
        pipe.unet._diffusers_load_id = "null"
        original_state_dict = pipe.unet.state_dict()
        save_dir = str(tmp_path / "my-pipeline")
        pipe.save_pretrained(save_dir)
        loaded_pipe = ModularPipeline.from_pretrained(save_dir)
        loaded_pipe.load_components(torch_dtype=torch.float32)
        assert loaded_pipe.unet is not None
        assert loaded_pipe.unet.__class__.__name__ == pipe.unet.__class__.__name__
        loaded_state_dict = loaded_pipe.unet.state_dict()
        # Weights must round-trip exactly: same keys, bitwise-equal tensors.
        assert set(original_state_dict.keys()) == set(loaded_state_dict.keys())
        for key in original_state_dict:
            assert torch.equal(original_state_dict[key], loaded_state_dict[key]), f"Mismatch in {key}"

    def test_save_pretrained_updates_index_for_model_with_no_load_id(self, tmp_path):
        """Replacing a component with a freshly constructed model and saving the pipeline
        should make modular_model_index.json point at the save directory for it."""
        import json

        from diffusers import UNet2DConditionModel

        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
        pipe.load_components(torch_dtype=torch.float32)
        unet = UNet2DConditionModel.from_pretrained(
            "hf-internal-testing/tiny-stable-diffusion-xl-pipe", subfolder="unet"
        )
        # Models loaded outside the modular loader carry no _diffusers_load_id.
        assert not hasattr(unet, "_diffusers_load_id")
        pipe.update_components(unet=unet)
        save_dir = str(tmp_path / "my-pipeline")
        pipe.save_pretrained(save_dir)
        with open(os.path.join(save_dir, "modular_model_index.json")) as f:
            index = json.load(f)
        _library, _cls, unet_spec = index["unet"]
        assert unet_spec["pretrained_model_name_or_path"] == save_dir
        assert unet_spec["subfolder"] == "unet"
        _library, _cls, vae_spec = index["vae"]
        assert vae_spec["pretrained_model_name_or_path"] == "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

    def test_save_pretrained_overwrite_modular_index(self, tmp_path):
        """With overwrite_modular_index=True, all component references should point to the save directory."""
        import json

        pipe = ModularPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-xl-pipe")
        pipe.load_components(torch_dtype=torch.float32)
        save_dir = str(tmp_path / "my-pipeline")
        pipe.save_pretrained(save_dir, overwrite_modular_index=True)
        with open(os.path.join(save_dir, "modular_model_index.json")) as f:
            index = json.load(f)
        for component_name in ["unet", "vae", "text_encoder", "text_encoder_2"]:
            if component_name not in index:
                continue
            _library, _cls, spec = index[component_name]
            assert spec["pretrained_model_name_or_path"] == save_dir, (
                f"{component_name} should point to save dir but got {spec['pretrained_model_name_or_path']}"
            )
            assert spec["subfolder"] == component_name
        # The overwritten index must still yield a loadable pipeline.
        loaded_pipe = ModularPipeline.from_pretrained(save_dir)
        loaded_pipe.load_components(torch_dtype=torch.float32)
        assert loaded_pipe.unet is not None
        assert loaded_pipe.vae is not None
class TestModularPipelineInitFallback:
"""Test that ModularPipeline.__init__ falls back to default_blocks_name when
_blocks_class_name is a base class (e.g. SequentialPipelineBlocks saved by from_blocks_dict)."""

View File

@@ -192,6 +192,156 @@ class TestModularCustomBlocks:
assert len(pipe.components) == 1
assert pipe.component_names[0] == "transformer"
    def test_trust_remote_code_not_propagated_to_external_repo(self):
        """When a modular pipeline repo references a component from an external repo that has custom
        code (auto_map in config), calling load_components(trust_remote_code=True) should NOT
        propagate trust_remote_code to that external component. The external component should fail
        to load."""
        from diffusers import ModularPipeline

        # Source for a model that can only load with trust_remote_code=True.
        # NOTE(review): indentation inside these code strings is reconstructed to
        # standard 4-space Python indent — confirm against the original file.
        CUSTOM_MODEL_CODE = (
            "import torch\n"
            "from diffusers import ModelMixin, ConfigMixin\n"
            "from diffusers.configuration_utils import register_to_config\n"
            "\n"
            "class CustomModel(ModelMixin, ConfigMixin):\n"
            "    @register_to_config\n"
            "    def __init__(self, hidden_size=8):\n"
            "        super().__init__()\n"
            "        self.linear = torch.nn.Linear(hidden_size, hidden_size)\n"
            "\n"
            "    def forward(self, x):\n"
            "        return self.linear(x)\n"
        )
        with tempfile.TemporaryDirectory() as external_repo_dir, tempfile.TemporaryDirectory() as pipeline_repo_dir:
            # Step 1: Create an external model repo with custom code (requires trust_remote_code)
            with open(os.path.join(external_repo_dir, "modeling.py"), "w") as f:
                f.write(CUSTOM_MODEL_CODE)
            config = {
                "_class_name": "CustomModel",
                "_diffusers_version": "0.0.0",
                "auto_map": {"AutoModel": "modeling.CustomModel"},
                "hidden_size": 8,
            }
            with open(os.path.join(external_repo_dir, "config.json"), "w") as f:
                json.dump(config, f)
            # An empty state dict suffices: loading must fail before weights matter.
            torch.save({}, os.path.join(external_repo_dir, "diffusion_pytorch_model.bin"))

            # Step 2: Create a custom block that references the external repo.
            # Define both the class (for direct use) and its code string (for block.py).
            class ExternalRefBlock(ModularPipelineBlocks):
                @property
                def expected_components(self):
                    return [
                        ComponentSpec(
                            "custom_model",
                            AutoModel,
                            pretrained_model_name_or_path=external_repo_dir,
                        )
                    ]

                @property
                def inputs(self) -> List[InputParam]:
                    return [InputParam("prompt", type_hint=str, required=True)]

                @property
                def intermediate_inputs(self) -> List[InputParam]:
                    return []

                @property
                def intermediate_outputs(self) -> List[OutputParam]:
                    return [OutputParam("output", type_hint=str)]

                def __call__(self, components, state: PipelineState) -> PipelineState:
                    block_state = self.get_block_state(state)
                    block_state.output = "test"
                    self.set_block_state(state, block_state)
                    return components, state

            EXTERNAL_REF_BLOCK_CODE_STR = (
                "from typing import List\n"
                "from diffusers import AutoModel\n"
                "from diffusers.modular_pipelines import (\n"
                "    ComponentSpec,\n"
                "    InputParam,\n"
                "    ModularPipelineBlocks,\n"
                "    OutputParam,\n"
                "    PipelineState,\n"
                ")\n"
                "\n"
                "class ExternalRefBlock(ModularPipelineBlocks):\n"
                "    @property\n"
                "    def expected_components(self):\n"
                "        return [\n"
                "            ComponentSpec(\n"
                '                "custom_model",\n'
                "                AutoModel,\n"
                f'                pretrained_model_name_or_path="{external_repo_dir}",\n'
                "            )\n"
                "        ]\n"
                "\n"
                "    @property\n"
                "    def inputs(self) -> List[InputParam]:\n"
                '        return [InputParam("prompt", type_hint=str, required=True)]\n'
                "\n"
                "    @property\n"
                "    def intermediate_inputs(self) -> List[InputParam]:\n"
                "        return []\n"
                "\n"
                "    @property\n"
                "    def intermediate_outputs(self) -> List[OutputParam]:\n"
                '        return [OutputParam("output", type_hint=str)]\n'
                "\n"
                "    def __call__(self, components, state: PipelineState) -> PipelineState:\n"
                "        block_state = self.get_block_state(state)\n"
                '        block_state.output = "test"\n'
                "        self.set_block_state(state, block_state)\n"
                "        return components, state\n"
            )

            # Save the block config, write block.py, then load back via from_pretrained
            block = ExternalRefBlock()
            block.save_pretrained(pipeline_repo_dir)
            # auto_map will reference the module name derived from ExternalRefBlock.__module__,
            # which is "test_modular_pipelines_custom_blocks". Write the code file with that name.
            code_path = os.path.join(pipeline_repo_dir, "test_modular_pipelines_custom_blocks.py")
            with open(code_path, "w") as f:
                f.write(EXTERNAL_REF_BLOCK_CODE_STR)
            block = ModularPipelineBlocks.from_pretrained(pipeline_repo_dir, trust_remote_code=True)
            pipe = block.init_pipeline()
            pipe.save_pretrained(pipeline_repo_dir)

            # Step 3: Load the pipeline from the saved directory.
            loaded_pipe = ModularPipeline.from_pretrained(pipeline_repo_dir, trust_remote_code=True)
            assert loaded_pipe._pretrained_model_name_or_path == pipeline_repo_dir
            assert loaded_pipe._component_specs["custom_model"].pretrained_model_name_or_path == external_repo_dir
            assert getattr(loaded_pipe, "custom_model", None) is None

            # Step 4a: load_components WITHOUT trust_remote_code.
            # It should still fail
            loaded_pipe.load_components()
            assert getattr(loaded_pipe, "custom_model", None) is None

            # Step 4b: load_components with trust_remote_code=True.
            # trust_remote_code should be stripped for the external component, so it fails.
            # The warning should contain guidance about manually loading with trust_remote_code.
            loaded_pipe.load_components(trust_remote_code=True)
            assert getattr(loaded_pipe, "custom_model", None) is None

            # Step 4c: Manually load with AutoModel and update_components — this should work.
            from diffusers import AutoModel

            custom_model = AutoModel.from_pretrained(external_repo_dir, trust_remote_code=True)
            loaded_pipe.update_components(custom_model=custom_model)
            assert getattr(loaded_pipe, "custom_model", None) is not None
def test_custom_block_loads_from_hub(self):
repo_id = "hf-internal-testing/tiny-modular-diffusers-block"
block = ModularPipelineBlocks.from_pretrained(repo_id, trust_remote_code=True)

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import tempfile
import unittest
@@ -306,96 +305,3 @@ class ConfigTester(unittest.TestCase):
result = json.loads(json_string)
assert result["test_file_1"] == config.config.test_file_1.as_posix()
assert result["test_file_2"] == config.config.test_file_2.as_posix()
class SampleObjectTyped(ConfigMixin):
    """Config fixture whose constructor arguments are all type-annotated, used to
    check that the generated dataclass carries int/int/str field annotations."""

    config_name = "config.json"

    @register_to_config
    def __init__(
        self,
        a: int = 2,
        b: int = 5,
        c: str = "hello",
    ):
        pass
class SampleObjectWithIgnore(ConfigMixin):
    """Config fixture with one constructor argument excluded from the config via
    `ignore_for_config`; the generated dataclass must not expose it."""

    config_name = "config.json"
    # "secret" is accepted by __init__ but never registered to the config.
    ignore_for_config = ["secret"]

    @register_to_config
    def __init__(
        self,
        a: int = 2,
        secret: str = "hidden",
    ):
        pass
class DataclassFromConfigTester(unittest.TestCase):
    """Tests for `_get_dataclass_from_config`, which materializes a frozen dataclass
    named `<ClassName>Config` from a plain config dict."""

    def test_get_dataclass_from_config_returns_frozen_dataclass(self):
        source = SampleObject()
        dc = SampleObject._get_dataclass_from_config(dict(source.config))
        self.assertTrue(dataclasses.is_dataclass(dc))
        # Frozen dataclasses must reject attribute assignment.
        with self.assertRaises(dataclasses.FrozenInstanceError):
            dc.a = 99

    def test_get_dataclass_from_config_class_name(self):
        source = SampleObject()
        dc = SampleObject._get_dataclass_from_config(dict(source.config))
        self.assertEqual(type(dc).__name__, "SampleObjectConfig")

    def test_get_dataclass_from_config_values_match_config(self):
        dc = SampleObject._get_dataclass_from_config(dict(SampleObject(a=10, b=20).config))
        self.assertEqual(dc.a, 10)
        self.assertEqual(dc.b, 20)
        self.assertEqual(dc.c, (2, 5))
        self.assertEqual(dc.d, "for diffusion")
        self.assertEqual(dc.e, [1, 3])

    def test_get_dataclass_from_config_from_raw_dict(self):
        # No instance needed: a raw dict is a valid config source.
        dc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
        self.assertEqual(dc.a, 7)
        self.assertEqual(dc.b, 3)
        self.assertEqual(dc.c, "world")

    def test_get_dataclass_from_config_annotations(self):
        dc = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "hi"})
        annotations = {field.name: field.type for field in dataclasses.fields(dc)}
        self.assertIs(annotations["a"], int)
        self.assertIs(annotations["b"], int)
        self.assertIs(annotations["c"], str)

    def test_get_dataclass_from_config_asdict_roundtrip(self):
        dc = SampleObjectTyped._get_dataclass_from_config({"a": 7, "b": 3, "c": "world"})
        self.assertEqual(dataclasses.asdict(dc), {"a": 7, "b": 3, "c": "world"})

    def test_get_dataclass_from_config_ignores_extra_keys(self):
        dc = SampleObjectTyped._get_dataclass_from_config(
            {"a": 1, "b": 2, "c": "hi", "_class_name": "Foo", "extra": 99}
        )
        self.assertEqual(dc.a, 1)
        # Private ("_"-prefixed) and unknown keys are dropped, not carried over.
        self.assertFalse(hasattr(dc, "_class_name"))
        self.assertFalse(hasattr(dc, "extra"))

    def test_get_dataclass_from_config_respects_ignore_for_config(self):
        dc = SampleObjectWithIgnore._get_dataclass_from_config({"a": 5})
        self.assertFalse(hasattr(dc, "secret"))
        self.assertEqual(dc.a, 5)

    def test_get_dataclass_from_config_works_for_scheduler(self):
        # Also works on a real library class, not only on the fixtures above.
        scheduler = DDIMScheduler()
        dc = DDIMScheduler._get_dataclass_from_config(dict(scheduler.config))
        self.assertTrue(dataclasses.is_dataclass(dc))
        self.assertEqual(type(dc).__name__, "DDIMSchedulerConfig")
        self.assertEqual(dc.num_train_timesteps, scheduler.config.num_train_timesteps)

    def test_get_dataclass_from_config_different_values(self):
        first = SampleObjectTyped._get_dataclass_from_config({"a": 1, "b": 2, "c": "x"})
        second = SampleObjectTyped._get_dataclass_from_config({"a": 9, "b": 8, "c": "y"})
        self.assertEqual(first.a, 1)
        self.assertEqual(second.a, 9)

View File

View File

@@ -0,0 +1,172 @@
# Copyright 2025 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import torch
from transformers import AutoConfig, AutoTokenizer, T5EncoderModel
from diffusers import AutoencoderKLWan, HeliosPipeline, HeliosScheduler, HeliosTransformer3DModel
from ...testing_utils import (
backend_empty_cache,
enable_full_determinism,
require_torch_accelerator,
slow,
torch_device,
)
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
enable_full_determinism()
class HeliosPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    """Fast (CPU-sized) smoke tests for HeliosPipeline using tiny dummy components."""

    pipeline_class = HeliosPipeline
    # Helios does not take cross_attention_kwargs, so drop it from the shared params.
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    supports_dduf = False

    def get_dummy_components(self):
        """Build minimal, randomly initialized components for a fast pipeline run."""
        # Re-seed before each randomly initialized module for reproducible weights.
        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        scheduler = HeliosScheduler(stage_range=[0, 1], stages=1, use_dynamic_shifting=True)
        config = AutoConfig.from_pretrained("hf-internal-testing/tiny-random-t5")
        text_encoder = T5EncoderModel(config)
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        torch.manual_seed(0)
        transformer = HeliosTransformer3DModel(
            patch_size=(1, 2, 2),
            num_attention_heads=2,
            attention_head_dim=12,
            in_channels=16,
            out_channels=16,
            text_dim=32,
            freq_dim=256,
            ffn_dim=32,
            num_layers=2,
            cross_attn_norm=True,
            qk_norm="rms_norm_across_heads",
            rope_dim=(4, 4, 4),
            has_multi_term_memory_patch=True,
            guidance_cross_attn=True,
            zero_history_timestep=True,
            is_amplify_history=False,
        )

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        """Return deterministic call kwargs sized for a fast test run."""
        # mps does not support device-local generators; fall back to the global RNG.
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "negative",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 1.0,
            "height": 16,
            "width": 16,
            "num_frames": 9,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        video = pipe(**inputs).frames
        generated_video = video[0]
        self.assertEqual(generated_video.shape, (33, 3, 16, 16))

        # Reference values recorded from a known-good run.
        # fmt: off
        expected_slice = torch.tensor([0.4529, 0.4527, 0.4499, 0.4542, 0.4528, 0.4524, 0.4531, 0.4534, 0.5328,
                                       0.5340, 0.5012, 0.5135, 0.5322, 0.5203, 0.5144, 0.5101])
        # fmt: on

        # Compare only the first and last 8 flattened values against the reference.
        generated_slice = generated_video.flatten()
        generated_slice = torch.cat([generated_slice[:8], generated_slice[-8:]])
        self.assertTrue(torch.allclose(generated_slice, expected_slice, atol=1e-3))

    # Override to set a more lenient max diff threshold.
    def test_save_load_float16(self):
        super().test_save_load_float16(expected_max_diff=0.03)

    @unittest.skip("Test not supported")
    def test_attention_slicing_forward_pass(self):
        pass

    @unittest.skip("Optional components not applicable for Helios")
    def test_save_load_optional_components(self):
        pass
@slow
@require_torch_accelerator
class HeliosPipelineIntegrationTests(unittest.TestCase):
    """Placeholder integration-test suite for Helios (slow, accelerator-only)."""

    prompt = "A painting of a squirrel eating a burger."

    @staticmethod
    def _free_accelerator_memory():
        # Collect Python garbage first so freed tensors can actually be released
        # from the accelerator cache.
        gc.collect()
        backend_empty_cache(torch_device)

    def setUp(self):
        super().setUp()
        self._free_accelerator_memory()

    def tearDown(self):
        super().tearDown()
        self._free_accelerator_memory()

    @unittest.skip("TODO: test needs to be implemented")
    def test_helios(self):
        pass

View File

@@ -74,7 +74,7 @@ if is_torchao_available():
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
@@ -132,7 +132,7 @@ class TorchAoConfigTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoTest(unittest.TestCase):
def tearDown(self):
gc.collect()
@@ -587,7 +587,7 @@ class TorchAoTest(unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoSerializationTest(unittest.TestCase):
model_name = "hf-internal-testing/tiny-flux-pipe"
@@ -698,23 +698,22 @@ class TorchAoSerializationTest(unittest.TestCase):
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device)
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
@property
def quantization_config(self):
return PipelineQuantizationConfig(
quant_mapping={
"transformer": TorchAoConfig(quant_type="int8_weight_only"),
},
quant_mapping={"transformer": TorchAoConfig(Int8WeightOnlyConfig())},
)
@unittest.skip(
"Changing the device of AQT tensor with module._apply (called from doing module.to() in accelerate) does not work "
"when compiling."
)
def test_torch_compile_with_cpu_offload(self):
pipe = self._init_pipeline(self.quantization_config, torch.bfloat16)
pipe.enable_model_cpu_offload()
# No compilation because it fails with:
# RuntimeError: _apply(): Couldn't swap Linear.weight
super().test_torch_compile_with_cpu_offload()
# small resolutions to ensure speedy execution.
pipe("a dog", num_inference_steps=2, max_sequence_length=16, height=256, width=256)
@parameterized.expand([False, True])
@unittest.skip(
@@ -745,7 +744,7 @@ class TorchAoCompileTest(QuantCompileTests, unittest.TestCase):
# Slices for these tests have been obtained on our aws-g6e-xlarge-plus runners
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
@slow
@nightly
class SlowTorchAoTests(unittest.TestCase):
@@ -907,7 +906,7 @@ class SlowTorchAoTests(unittest.TestCase):
@require_torch
@require_torch_accelerator
@require_torchao_version_greater_or_equal("0.7.0")
@require_torchao_version_greater_or_equal("0.14.0")
@slow
@nightly
class SlowTorchAoPreserializedModelTests(unittest.TestCase):