Compare commits


23 Commits

Author SHA1 Message Date
DN6
6ec4dee783 update 2026-03-26 15:25:08 +05:30
DN6
50015c966a update 2026-03-26 15:21:29 +05:30
Kashif Rasul
762ae059fa [LLADA2] documentation fixes (#13333)
documentation fixes
2026-03-25 17:49:31 +05:30
Kashif Rasul
5d207e756e [Discrete Diffusion] Add LLaDA2 pipeline (#13226)
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
- BlockRefinementPipeline: block-wise iterative refinement with confidence-based
  token commitment, supporting editing threshold for LLaDA2.1 models
- LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
- DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p,
  temperature) and prompt/prefix helpers
- compute_confidence_aware_loss: CAP-style training loss
- Examples: sampling scripts for LLaDA2 and block refinement, training scripts
  with Qwen causal LM
- Docs and tests included
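The CAP-style confidence-aware loss mentioned above can be sketched roughly as follows. This is an illustrative approximation only: the function name `confidence_aware_loss_sketch` and the exact weighting are hypothetical, while `lambda_conf` and `conf_temperature` mirror the training-script flags; the real `compute_confidence_aware_loss` in `diffusers.training_utils` may differ.

```python
import torch
import torch.nn.functional as F

def confidence_aware_loss_sketch(logits, labels, mask, lambda_conf=2.0, conf_temperature=0.5):
    # Standard cross-entropy, restricted to the masked (to-be-predicted) positions.
    ce = F.cross_entropy(logits[mask], labels[mask])
    # Confidence term (hypothetical form): sharpen the predicted distribution
    # and penalize low confidence on the correct token.
    probs = F.softmax(logits[mask] / conf_temperature, dim=-1)
    conf = probs.gather(-1, labels[mask].unsqueeze(-1)).squeeze(-1)
    return ce + lambda_conf * (1.0 - conf).mean()
```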

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

Extract the confidence-based token commit logic from BlockRefinementPipeline
into a dedicated BlockRefinementScheduler, following diffusers conventions.

The scheduler owns:
- Transfer schedule computation (get_num_transfer_tokens)
- Timestep management (set_timesteps)
- Step logic: confidence-based mask-filling and optional token editing

The pipeline now delegates scheduling to self.scheduler.step() and accepts
a scheduler parameter in __init__.
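The commit-by-confidence step the scheduler owns can be sketched like this (names and signature are illustrative, not the actual `BlockRefinementScheduler.step()` API): at each step, the most confident predictions at still-masked positions are committed.

```python
import torch

def commit_by_confidence(logits, x, mask_token_id, num_transfer):
    # Predicted token and its confidence at every position.
    probs = torch.softmax(logits, dim=-1)
    confidence, pred = probs.max(dim=-1)
    still_masked = x == mask_token_id
    # Only masked positions compete for commitment.
    confidence = torch.where(still_masked, confidence, torch.full_like(confidence, -1.0))
    # Commit the num_transfer most confident masked tokens.
    idx = confidence.topk(num_transfer, dim=-1).indices
    transfer = torch.zeros_like(still_masked)
    transfer.scatter_(-1, idx, True)
    transfer &= still_masked
    x = torch.where(transfer, pred, x)
    return x, transfer
```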

* test: add unit tests for BlockRefinementScheduler

12 tests covering set_timesteps, get_num_transfer_tokens, step logic
(confidence-based commits, threshold behavior, editing, prompt masking,
batched inputs, tuple output).

* docs: add toctree entries and standalone scheduler doc page

- Add BlockRefinement and LLaDA2 to docs sidebar navigation
- Add BlockRefinementScheduler to schedulers sidebar navigation
- Move scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py

- Add --revision argument for loading model revisions from the Hub
- Replace deprecated torch_dtype with dtype for transformers 5.x compat

* fix: use 1/0 attention mask instead of 0/-inf for LLaDA2 compat

LLaDA2 models expect a boolean-style (1/0) attention mask, not an
additive (0/-inf) mask. The model internally converts to additive,
so passing 0/-inf caused double-masking and gibberish output.
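The two mask conventions look like this (a minimal illustration, not code from the PR):

```python
import torch

# Boolean-style mask: 1 = attend, 0 = ignore. This is what LLaDA2 expects.
bool_mask = torch.tensor([[1, 1, 1, 0]])

# Additive mask: 0 = attend, -inf = ignore. Passing this to a model that
# internally converts 1/0 -> additive effectively masks twice.
additive = torch.zeros_like(bool_mask, dtype=torch.float).masked_fill(
    bool_mask == 0, float("-inf")
)
```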

* refactor: consolidate training scripts into single train_block_refinement.py

- Remove toy train_block_refinement_cap.py (self-contained demo with tiny model)
- Rename train_block_refinement_qwen_cap.py to train_block_refinement.py
  (already works with any causal LM via AutoModelForCausalLM)
- Fix torch_dtype deprecation and update README with correct script names

* fix formatting

* docs: improve LLaDA2 and BlockRefinement documentation

- Add usage examples with real model IDs and working code
- Add recommended parameters table for LLaDA2.1 quality/speed modes
- Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
- Remove misleading config defaults section from BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to recommended model parameters

- threshold: 0.95 -> 0.7 (quality mode)
- max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
- eos_early_stop: False -> True (stop at EOS token)

block_length=32, steps=32, temperature=0.0 were already correct.
editing_threshold remains None (users enable for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

LLaDA2.1 is the current generation. Users with LLaDA2.0 models can
disable editing by passing editing_threshold=None.

* fix: align sampling utilities with official LLaDA2 implementation

- top_p filtering: add shift-right to preserve at least one token above
  threshold (matches official code line 1210)
- temperature ordering: apply scaling before top-k/top-p filtering so
  filtering operates on scaled logits (matches official code lines 1232-1235)
- greedy branch: return argmax directly when temperature=0 without
  filtering (matches official code lines 1226-1230)
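The ordering described above (greedy shortcut at temperature 0, otherwise scale by temperature first and then filter) can be sketched as follows. This is an illustrative sketch, not the exact diffusers implementation; the function name `sample_tokens` is hypothetical.

```python
import torch

def sample_tokens(logits, temperature=0.0, top_k=None, top_p=None, generator=None):
    if temperature == 0.0:
        # Greedy branch: return argmax directly, no filtering.
        return logits.argmax(dim=-1)
    # Apply temperature scaling BEFORE top-k/top-p so filtering sees scaled logits.
    logits = logits / temperature
    if top_k is not None:
        kth = logits.topk(top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    if top_p is not None:
        sorted_logits, sorted_idx = logits.sort(dim=-1, descending=True)
        cum = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        remove = cum > top_p
        # Shift right so at least one token always survives the threshold.
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        remove_orig = torch.zeros_like(remove).scatter(-1, sorted_idx, remove)
        logits = logits.masked_fill(remove_orig, float("-inf"))
    probs = logits.softmax(dim=-1)
    return torch.multinomial(probs, 1, generator=generator).squeeze(-1)
```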

* refactor: remove duplicate prompt encoding, reuse mixin's _prepare_input_ids

LLaDA2Pipeline._prepare_prompt_ids was a near-copy of
DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate
and call the mixin method directly. Also simplify _extract_input_ids
since we always pass return_dict=True.

* formatting

* fix: replace deprecated torch_dtype with dtype in examples and docstrings

- Update EXAMPLE_DOC_STRING to use dtype= and LLaDA2.1-mini model ID
- Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline

* cleanup

* fix readme

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* removed DiscreteDiffusionPipelineMixin

* add support for 2d masks for flash attn

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix issues from review

* added tests

* formatting

* add check_eos_finished to scheduler

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix renaming issues and types

* remove duplicate check

* Update docs/source/en/api/pipelines/llada2.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-25 16:17:50 +05:30
Sayak Paul
e358ddcce6 fix to device and to dtype tests. (#13323) 2026-03-25 11:47:02 +05:30
Sayak Paul
153fcbc5a8 fix klein lora loading. (#13313) 2026-03-25 07:51:35 +05:30
Beinsezii
da6718f080 ZImageTransformer2D: Only build attention mask if seqlens are not equal (#12955) 2026-03-24 06:06:50 -10:00
Alexey Kirillov
832676d35e Use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING (#13320)
refactor: use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING

Co-authored-by: Alexkkir <alexkkir@gmail.coom>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-24 17:49:50 +05:30
Dhruv Nair
7bbd96da5d [CI] Update fetching pipelines for latest HF Hub Version (#13322)
update
2026-03-24 16:42:32 +05:30
Dhruv Nair
62777fa819 Fix unguarded torchvision import in Cosmos (#13321)
update
2026-03-24 16:00:24 +05:30
Sayak Paul
f1fd515257 [tests] fix lora logging tests for models. (#13318)
* fix lora logging tests for models.

* make style
2026-03-24 15:48:03 +05:30
Cheung Ka Wai
afdda57f61 Fix the attention mask in ulysses SP for QwenImage (#13278)
* fix mask in SP

* change the modification to qwen specific

* drop xfail since qwen-image mask is fixed

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-24 02:12:50 -07:00
YangKai0616
5fc2bd2c8f Stabilize low-precision custom autoencoder RMS normalization (#13316)
* Stabilize low-precision custom autoencoder RMS normalization

* Add fp8/4

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-24 02:00:05 -07:00
Sayak Paul
6350a7690a [chore] properly deprecate src.diffusers.utils.testing_utils. (#13314)
properly deprecate src.diffusers.utils.testing_utils.
2026-03-24 10:54:35 +05:30
Cheung Ka Wai
9d4c9dcf21 change QwenImageTransformer UT to batch inputs (#13312)
* UT expands to batch inputs

* update according to suggestion

* update according to suggestion 2

* fix CI

* update according to suggestion 3

* clean line
2026-03-24 08:56:40 +05:30
ddavidchick
ef309a1bb0 Add KVAE 1.0 (#13033)
* add kvae2d

* add kvae3d video

* add docs for kvae2d and kvae3d video

* style fixes

* fix kvae3d docs

* fix normalzation

* fix kvae video for code style

* fix kvae video

* kvae minor fixes

* add gradient ckpting for kvaes

* get rid of inplace ops kvae video

* add tests for KVAEs

* kvae2d normalization style change

* kvaes fix style

* update dummy_pt_objects test for kvaes

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
2026-03-23 12:56:49 -10:00
Charles
b9761ce5a2 [export] Add export-safe LRU cache helper (#13290)
* [core] Add export-safe LRU cache helper

* torch version check!

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-23 18:10:07 +05:30
Dhruv Nair
52558b45d8 [CI] Flux2 Model Test Refactor (#13071)
* update

* update

* update

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-23 16:56:08 +05:30
Sayak Paul
c02c17c6ee [tests] test load_components in modular (#13245)
* test load_components.

* fix

* fix

* u[

* up
2026-03-21 09:41:48 +05:30
Sayak Paul
a9855c4204 [tests] fix audioldm2 tests. (#13293)
fix audioldm2 tests.
2026-03-20 20:53:21 +05:30
Sayak Paul
0b35834351 [core] fa4 support. (#13280)
* start fa4 support.

* up

* specify minimum version
2026-03-20 17:28:09 +05:30
Sayak Paul
522b523e40 [ci] hoping to fix is_flaky with wanvace. (#13294)
* hoping to fix is_flaky with wanvace.

* revert changes in src/diffusers/utils/testing_utils.py and propagate them to tests/testing_utils.py.

* up
2026-03-20 16:02:16 +05:30
Dhruv Nair
e9b9f25f67 [CI] Update transformer version in release tests (#13296)
update
2026-03-20 11:40:06 +05:30
56 changed files with 5929 additions and 295 deletions

View File

@@ -4,6 +4,7 @@
name: (Release) Fast GPU Tests on main
on:
workflow_dispatch:
push:
branches:
- "v*.*.*-release"
@@ -33,6 +34,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -74,6 +76,7 @@ jobs:
run: |
uv pip install -e ".[quality]"
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -125,6 +128,7 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
@@ -175,6 +179,7 @@ jobs:
uv pip install -e ".[quality]"
uv pip install peft@git+https://github.com/huggingface/peft.git
uv pip uninstall accelerate && uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
@@ -232,6 +237,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -274,6 +280,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |
python utils/print_env.py
@@ -316,6 +323,7 @@ jobs:
- name: Install dependencies
run: |
uv pip install -e ".[quality,training]"
uv pip uninstall transformers huggingface_hub && uv pip install --prerelease allow -U transformers@git+https://github.com/huggingface/transformers.git
- name: Environment
run: |

View File

@@ -446,6 +446,10 @@
title: AutoencoderKLHunyuanVideo
- local: api/models/autoencoder_kl_hunyuan_video15
title: AutoencoderKLHunyuanVideo15
- local: api/models/autoencoder_kl_kvae
title: AutoencoderKLKVAE
- local: api/models/autoencoder_kl_kvae_video
title: AutoencoderKLKVAEVideo
- local: api/models/autoencoderkl_audio_ltx_2
title: AutoencoderKLLTX2Audio
- local: api/models/autoencoderkl_ltx_2
@@ -666,6 +670,10 @@
- local: api/pipelines/z_image
title: Z-Image
title: Image
- sections:
- local: api/pipelines/llada2
title: LLaDA2
title: Text
- sections:
- local: api/pipelines/allegro
title: Allegro
@@ -714,6 +722,8 @@
- sections:
- local: api/schedulers/overview
title: Overview
- local: api/schedulers/block_refinement
title: BlockRefinementScheduler
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
- local: api/schedulers/ddim_cogvideox

View File

@@ -0,0 +1,32 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAE
The 2D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAE
vae = AutoencoderKLKVAE.from_pretrained("kandinskylab/KVAE-2D-1.0", subfolder="diffusers", torch_dtype=torch.bfloat16)
```
## AutoencoderKLKVAE
[[autodoc]] AutoencoderKLKVAE
- decode
- all

View File

@@ -0,0 +1,33 @@
<!-- Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. -->
# AutoencoderKLKVAEVideo
The 3D variational autoencoder (VAE) model with KL loss.
The model can be loaded with the following code snippet.
```python
import torch
from diffusers import AutoencoderKLKVAEVideo
vae = AutoencoderKLKVAEVideo.from_pretrained("kandinskylab/KVAE-3D-1.0", subfolder="diffusers", torch_dtype=torch.float16)
```
## AutoencoderKLKVAEVideo
[[autodoc]] AutoencoderKLKVAEVideo
- decode
- all

View File

@@ -0,0 +1,90 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
steps.
## Usage
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
model_id = "inclusionAI/LLaDA2.1-mini"
model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
output = pipe(
prompt="Write a short poem about the ocean.",
gen_length=256,
block_length=32,
num_inference_steps=32,
threshold=0.7,
editing_threshold=0.5,
max_post_steps=16,
temperature=0.0,
)
print(output.texts[0])
```
## Callbacks
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
window.
```py
def on_step_end(pipe, step, timestep, callback_kwargs):
block_x = callback_kwargs["block_x"]
# Inspect or modify `block_x` here.
return {"block_x": block_x}
out = pipe(
prompt="Write a short poem.",
callback_on_step_end=on_step_end,
callback_on_step_end_tensor_inputs=["block_x"],
)
```
## Recommended parameters
LLaDA2.1 models support two modes:
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
|------|-------------|---------------------|------------------|
| Quality | 0.7 | 0.5 | 16 |
| Speed | 0.5 | `None` | 16 |
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing. LLaDA2.0 models do not support editing, so always disable it for them.
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
## LLaDA2Pipeline
[[autodoc]] LLaDA2Pipeline
- all
- __call__
## LLaDA2PipelineOutput
[[autodoc]] pipelines.LLaDA2PipelineOutput

View File

@@ -63,6 +63,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [LLaDA2](llada2) | text2text |
| [Lumina-T2X](lumina) | text2image |
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
| [MultiDiffusion](panorama) | text2image |

View File

@@ -0,0 +1,25 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BlockRefinementScheduler
The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
token with high confidence.
This scheduler is used by [`LLaDA2Pipeline`].
## BlockRefinementScheduler
[[autodoc]] BlockRefinementScheduler
## BlockRefinementSchedulerOutput
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput

View File

@@ -143,6 +143,7 @@ Refer to the table below for a complete list of available attention backends and
| `flash_varlen` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention |
| `flash_varlen_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention from kernels |
| `aiter` | [AI Tensor Engine for ROCm](https://github.com/ROCm/aiter) | FlashAttention for AMD ROCm |
| `flash_4_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-4 |
| `_flash_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 |
| `_flash_varlen_3` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | Variable length FlashAttention-3 |
| `_flash_3_hub` | [FlashAttention](https://github.com/Dao-AILab/flash-attention) | FlashAttention-3 from kernels |

View File

@@ -0,0 +1,50 @@
# Discrete Token Diffusion (Experimental)
This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.
## LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.
### Train
The training script uses confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--text_column text \
--output_dir llada2-output \
--max_train_steps 1000 \
--prompt_length 32 \
--block_length 32 \
--lambda_conf 2.0 \
--conf_temperature 0.5
```
If you don't want to download a dataset, you can use random-token data:
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--output_dir llada2-output \
--use_dummy_data \
--num_dummy_samples 2048
```
### Sample
```bash
python examples/discrete_diffusion/sample_llada2.py \
--model_id inclusionAI/LLaDA2.1-mini \
--prompt "Write a short poem about the ocean." \
--gen_length 256 \
--num_inference_steps 32 \
--threshold 0.7 \
--editing_threshold 0.5 \
--max_post_steps 16 \
--use_chat_template \
--add_generation_prompt
```

View File

@@ -0,0 +1,263 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample script for LLaDA2-style discrete diffusion text generation.
This script demonstrates how to use the LLaDA2Pipeline for text generation
using block-wise iterative refinement.
Example usage:
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
from diffusers.hooks import apply_group_offloading
def main():
parser = argparse.ArgumentParser(
description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
)
parser.add_argument(
"--model_id",
type=str,
default="inclusionAI/LLaDA2.0-mini",
help="HuggingFace model ID or path to local model.",
)
parser.add_argument(
"--prompt",
type=str,
default="Why does Camus think that Sisyphus is happy?",
help="Text prompt to generate from.",
)
parser.add_argument(
"--gen_length",
type=int,
default=2048,
help="Number of tokens to generate.",
)
parser.add_argument(
"--block_length",
type=int,
default=32,
help="Size of each generation block.",
)
parser.add_argument(
"--num_inference_steps",
type=int,
default=32,
help="Number of refinement steps per block.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Sampling temperature (0.0 for greedy).",
)
parser.add_argument(
"--top_p",
type=float,
default=None,
help="Nucleus sampling probability threshold.",
)
parser.add_argument(
"--top_k",
type=int,
default=None,
help="Top-k sampling parameter.",
)
parser.add_argument(
"--threshold",
type=float,
default=0.95,
help="Confidence threshold for committing tokens.",
)
parser.add_argument(
"--editing_threshold",
type=float,
default=None,
help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
)
parser.add_argument(
"--max_post_steps",
type=int,
default=0,
help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
)
parser.add_argument(
"--sampling_method",
type=str,
default="multinomial",
choices=["auto", "greedy", "multinomial"],
help="Sampling method for block refinement.",
)
parser.add_argument(
"--eos_early_stop",
action="store_true",
help="Stop generation early when EOS token is generated.",
)
parser.add_argument(
"--use_chat_template",
action="store_true",
help="Use the tokenizer chat template for the prompt.",
)
parser.add_argument(
"--add_generation_prompt",
action="store_true",
help="Add the generation prompt when using the chat template.",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run inference on.",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
help="Model dtype.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility.",
)
parser.add_argument(
"--offload",
type=str,
default=None,
choices=["group", "sequential"],
help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Model revision (branch, tag, or commit hash) to load from the Hub.",
)
args = parser.parse_args()
# Parse dtype
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
torch_dtype = dtype_map[args.dtype]
print(f"Loading model: {args.model_id}")
tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)
# Load model with appropriate memory settings based on offload strategy
if args.offload == "group":
# For group offloading, load to CPU first then apply hooks
print("Using group offloading for memory efficiency...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Apply group offloading with CUDA streams for better performance
onload_device = torch.device(args.device)
offload_device = torch.device("cpu")
apply_group_offloading(
model,
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True,
)
elif args.offload == "sequential":
# For sequential offloading, load to CPU first
print("Using sequential CPU offloading (slower but lower memory)...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Sequential offloading will be applied via pipeline
else:
# Default: use device_map="auto" for automatic memory management
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
device_map="auto",
low_cpu_mem_usage=True,
revision=args.revision,
)
model.eval()
# Create pipeline
scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
# Apply sequential CPU offload if requested
if args.offload == "sequential":
pipe.enable_sequential_cpu_offload()
# Set up generator for reproducibility
generator = None
if args.seed is not None:
generator = torch.Generator(device=args.device).manual_seed(args.seed)
print(f"\nPrompt: {args.prompt}")
print(
f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
)
print("-" * 50)
# Generate
output = pipe(
prompt=args.prompt,
use_chat_template=args.use_chat_template,
add_generation_prompt=args.add_generation_prompt,
gen_length=args.gen_length,
block_length=args.block_length,
num_inference_steps=args.num_inference_steps,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
threshold=args.threshold,
editing_threshold=args.editing_threshold,
max_post_steps=args.max_post_steps,
sampling_method=args.sampling_method,
eos_early_stop=args.eos_early_stop,
generator=generator,
)
print("\nGenerated text:")
print(output.texts[0])
print(f"\nGenerated {output.sequences.shape[1]} tokens")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,321 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
from dataclasses import asdict, dataclass
from typing import Dict, Optional
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler
from diffusers import BlockRefinementScheduler
from diffusers.training_utils import compute_confidence_aware_loss
logger = get_logger(__name__)
@dataclass
class TrainConfig:
    model_name_or_path: str
    dataset_name: str
    dataset_config_name: Optional[str]
    text_column: str
    cache_dir: Optional[str]
    use_dummy_data: bool
    num_dummy_samples: int
    output_dir: str
    seed: int
    max_train_steps: int
    checkpointing_steps: int
    logging_steps: int
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    weight_decay: float
    lr_scheduler: str
    lr_warmup_steps: int
    max_length: int
    prompt_length: int
    block_length: int
    lambda_conf: float
    conf_temperature: float
def parse_args() -> TrainConfig:
    parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")
    parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
    parser.add_argument("--dataset_name", type=str, default="wikitext")
    parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
    parser.add_argument("--text_column", type=str, default="text")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
    parser.add_argument("--num_dummy_samples", type=int, default=2048)
    parser.add_argument("--output_dir", type=str, default="block-refinement-output")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_train_steps", type=int, default=1000)
    parser.add_argument("--checkpointing_steps", type=int, default=500)
    parser.add_argument("--logging_steps", type=int, default=50)
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
    )
    parser.add_argument("--lr_warmup_steps", type=int, default=100)
    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--prompt_length", type=int, default=32)
    parser.add_argument("--block_length", type=int, default=32)
    parser.add_argument("--lambda_conf", type=float, default=2.0)
    parser.add_argument("--conf_temperature", type=float, default=0.5)
    args = parser.parse_args()
    return TrainConfig(**vars(args))
def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
    texts = examples[text_column]
    texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
    return tokenizer(texts, truncation=True, padding=False, max_length=max_length)
class RandomTokenDataset(torch.utils.data.Dataset):
    def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
        self.num_samples = int(num_samples)
        self.seq_len = int(seq_len)
        self.vocab_size = int(vocab_size)
        self.pad_token_id = int(pad_token_id)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        del idx
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}
def main():
    cfg = parse_args()

    if cfg.prompt_length >= cfg.max_length:
        raise ValueError("`prompt_length` must be < `max_length`.")
    if cfg.block_length <= 0:
        raise ValueError("`block_length` must be > 0.")

    project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        project_config=project_config,
    )
    if accelerator.is_main_process:
        os.makedirs(cfg.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    set_seed(cfg.seed)
    logger.info("Training configuration: %s", asdict(cfg))

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
    model.resize_token_embeddings(len(tokenizer))
    if load_dtype == torch.float32:
        model.to(dtype=torch.float32)

    mask_token_id = int(tokenizer.mask_token_id)

    if cfg.use_dummy_data:
        dataset = RandomTokenDataset(
            num_samples=cfg.num_dummy_samples,
            seq_len=cfg.max_length,
            vocab_size=len(tokenizer),
            pad_token_id=int(tokenizer.pad_token_id),
        )
        train_dataloader = DataLoader(
            dataset,
            shuffle=True,
            batch_size=cfg.per_device_train_batch_size,
            drop_last=True,
        )
    else:
        raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
        if "train" not in raw_datasets:
            raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")
        with accelerator.main_process_first():
            tokenized = raw_datasets["train"].map(
                lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
                batched=True,
                remove_columns=raw_datasets["train"].column_names,
                desc="Tokenizing",
            )
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
        train_dataloader = DataLoader(
            tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
    num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_warmup_steps,
        num_training_steps=cfg.max_train_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)

    global_step = 0
    model.train()
    for _epoch in range(num_train_epochs):
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))

                gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
                noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
                    input_ids,
                    attention_mask,
                    prompt_length=cfg.prompt_length,
                    block_length=cfg.block_length,
                    mask_token_id=mask_token_id,
                    generator=gen,
                )

                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
                )
                logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
                logits_rev = model(
                    input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
                ).logits

                logits = logits.clone()
                logits[..., mask_token_id] = torch.finfo(logits.dtype).min
                logits_rev = logits_rev.clone()
                logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min

                valid = attention_mask.to(dtype=torch.bool)
                masked = masked & valid
                masked_rev = masked_rev & valid

                labels = input_ids.clone()
                labels[~masked] = -100
                labels_rev = input_ids.clone()
                labels_rev[~masked_rev] = -100

                weights = masked.to(dtype=logits.dtype)
                weights_rev = masked_rev.to(dtype=logits.dtype)

                loss, loss_sft, loss_conf = compute_confidence_aware_loss(
                    logits,
                    labels,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights,
                )
                loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
                    logits_rev,
                    labels_rev,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights_rev,
                )
                total_loss = loss + loss_rev

                accelerator.backward(total_loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                global_step += 1
                if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
                    logger.info(
                        "step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
                        global_step,
                        total_loss.item(),
                        (loss_sft + loss_sft_rev).item(),
                        (loss_conf + loss_conf_rev).item(),
                        lr_scheduler.get_last_lr()[0],
                    )
                    print(
                        f"step={global_step} loss={total_loss.item():.4f} "
                        f"sft={(loss_sft + loss_sft_rev).item():.4f} "
                        f"conf={(loss_conf + loss_conf_rev).item():.4f} "
                        f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
                    )
                if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
                        os.makedirs(save_dir, exist_ok=True)
                        accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
                        tokenizer.save_pretrained(save_dir)
                if global_step >= cfg.max_train_steps:
                    break
        if global_step >= cfg.max_train_steps:
            break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final_dir = os.path.join(cfg.output_dir, "final")
        os.makedirs(final_dir, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
        tokenizer.save_pretrained(final_dir)

    logger.info("Done.")


if __name__ == "__main__":
    main()
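`compute_confidence_aware_loss` is provided by `diffusers.training_utils`; the sketch below is an illustrative stand-in for a CAP-style objective (a cross-entropy SFT term plus a confidence-sharpening term), not the library implementation. The function name and the exact form of the confidence term are assumptions for illustration:

```python
import torch
import torch.nn.functional as F


def confidence_aware_loss_sketch(logits, labels, lambda_conf=2.0, temperature=0.5):
    # Standard SFT term: cross-entropy over supervised (non -100) positions.
    vocab = logits.size(-1)
    loss_sft = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=-100)
    # Confidence term (illustrative): push the temperature-sharpened max
    # probability toward 1 on supervised positions, so commit-by-confidence
    # decoding can trust the model's own confidence estimates.
    probs = torch.softmax(logits / temperature, dim=-1)
    max_prob = probs.max(dim=-1).values
    supervised = (labels != -100).to(logits.dtype)
    loss_conf = ((1.0 - max_prob) * supervised).sum() / supervised.sum().clamp(min=1.0)
    return loss_sft + lambda_conf * loss_conf, loss_sft, loss_conf
```

As in the training loop above, the total loss is the SFT term plus `lambda_conf` times the confidence term, and ignored positions (`-100`) contribute to neither.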

View File

@@ -193,6 +193,8 @@ else:
"AutoencoderKLHunyuanImageRefiner",
"AutoencoderKLHunyuanVideo",
"AutoencoderKLHunyuanVideo15",
"AutoencoderKLKVAE",
"AutoencoderKLKVAEVideo",
"AutoencoderKLLTX2Audio",
"AutoencoderKLLTX2Video",
"AutoencoderKLLTXVideo",
@@ -342,6 +344,8 @@ else:
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
"BlockRefinementScheduler",
"BlockRefinementSchedulerOutput",
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
@@ -578,6 +582,8 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LLaDA2Pipeline",
"LLaDA2PipelineOutput",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ConditionPipeline",
@@ -975,6 +981,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            AutoencoderKLHunyuanImageRefiner,
            AutoencoderKLHunyuanVideo,
            AutoencoderKLHunyuanVideo15,
            AutoencoderKLKVAE,
            AutoencoderKLKVAEVideo,
            AutoencoderKLLTX2Audio,
            AutoencoderKLLTX2Video,
            AutoencoderKLLTXVideo,
@@ -1120,6 +1128,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
        from .quantizers import DiffusersQuantizer
        from .schedulers import (
            AmusedScheduler,
            BlockRefinementScheduler,
            BlockRefinementSchedulerOutput,
            CMStochasticIterativeScheduler,
            CogVideoXDDIMScheduler,
            CogVideoXDPMScheduler,
@@ -1335,6 +1345,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            LDMTextToImagePipeline,
            LEditsPPPipelineStableDiffusion,
            LEditsPPPipelineStableDiffusionXL,
            LLaDA2Pipeline,
            LLaDA2PipelineOutput,
            LongCatImageEditPipeline,
            LongCatImagePipeline,
            LTX2ConditionPipeline,

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import functools
import inspect
from dataclasses import dataclass
from typing import Type
@@ -32,7 +31,7 @@ from ..models._modeling_parallel import (
gather_size_by_comm,
)
from ..utils import get_logger
from ..utils.torch_utils import maybe_allow_in_graph, unwrap_module
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph, unwrap_module
from .hooks import HookRegistry, ModelHook
@@ -327,7 +326,7 @@ class PartitionAnythingSharder:
        return tensor


@functools.lru_cache(maxsize=64)
@lru_cache_unless_export(maxsize=64)
def _fill_gather_shapes(shape: tuple[int], gather_dims: tuple[int], dim: int, world_size: int) -> list[list[int]]:
    gather_shapes = []
    for i in range(world_size):

View File

@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
return converted_state_dict
def _convert_kohya_flux2_lora_to_diffusers(state_dict):
    def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")

        # scale weight by alpha and dim
        rank = down_weight.shape[0]
        default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
        scale = alpha / rank

        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
        ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
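The power-of-two loop above balances the overall `alpha / rank` LoRA scale across the down (`lora_A`) and up (`lora_B`) factors so that neither matrix is multiplied by an extremely small number, while their product stays exactly `alpha / rank`. Isolated as a helper (name hypothetical), the logic reads:

```python
def split_lora_scale(alpha: float, rank: int):
    # The overall scaling applied to the LoRA delta is alpha / rank.
    scale = alpha / rank
    # Rebalance in powers of two so scale_down and scale_up stay close in
    # magnitude; scale_down * scale_up == alpha / rank at every iteration.
    scale_down, scale_up = scale, 1.0
    while scale_down * 2 < scale_up:
        scale_down *= 2
        scale_up /= 2
    return scale_down, scale_up
```

Powers of two are exactly representable in floating point, so the rebalancing changes the stored weights without changing the effective LoRA delta.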
    def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
        if sds_key + ".lora_down.weight" not in sds_sd:
            return
        down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
        up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
        sd_lora_rank = down_weight.shape[0]

        default_alpha = torch.tensor(
            sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
        )
        alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
        scale = alpha / sd_lora_rank

        scale_down = scale
        scale_up = 1.0
        while scale_down * 2 < scale_up:
            scale_down *= 2
            scale_up /= 2

        down_weight = down_weight * scale_down
        up_weight = up_weight * scale_up

        num_splits = len(ait_keys)
        if dims is None:
            dims = [up_weight.shape[0] // num_splits] * num_splits
        else:
            assert sum(dims) == up_weight.shape[0]

        # check if upweight is sparse
        is_sparse = False
        if sd_lora_rank % num_splits == 0:
            ait_rank = sd_lora_rank // num_splits
            is_sparse = True
            i = 0
            for j in range(len(dims)):
                for k in range(len(dims)):
                    if j == k:
                        continue
                    is_sparse = is_sparse and torch.all(
                        up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
                    )
                i += dims[j]
            if is_sparse:
                logger.info(f"weight is sparse: {sds_key}")

        ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
        ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
        if not is_sparse:
            ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
            ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))})  # noqa: C416
        else:
            ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))})  # noqa: C416
            i = 0
            for j in range(len(dims)):
                ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
                i += dims[j]
    # Detect number of blocks from keys
    num_double_layers = 0
    num_single_layers = 0
    for key in state_dict.keys():
        if key.startswith("lora_unet_double_blocks_"):
            block_idx = int(key.split("_")[4])
            num_double_layers = max(num_double_layers, block_idx + 1)
        elif key.startswith("lora_unet_single_blocks_"):
            block_idx = int(key.split("_")[4])
            num_single_layers = max(num_single_layers, block_idx + 1)

    ait_sd = {}
    for i in range(num_double_layers):
        # Attention projections
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_attn_proj",
            f"transformer.transformer_blocks.{i}.attn.to_out.0",
        )
        _convert_to_ai_toolkit_cat(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.to_q",
                f"transformer.transformer_blocks.{i}.attn.to_k",
                f"transformer.transformer_blocks.{i}.attn.to_v",
            ],
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_attn_proj",
            f"transformer.transformer_blocks.{i}.attn.to_add_out",
        )
        _convert_to_ai_toolkit_cat(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_attn_qkv",
            [
                f"transformer.transformer_blocks.{i}.attn.add_q_proj",
                f"transformer.transformer_blocks.{i}.attn.add_k_proj",
                f"transformer.transformer_blocks.{i}.attn.add_v_proj",
            ],
        )
        # MLP layers (Flux2 uses ff.linear_in/linear_out)
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_mlp_0",
            f"transformer.transformer_blocks.{i}.ff.linear_in",
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_img_mlp_2",
            f"transformer.transformer_blocks.{i}.ff.linear_out",
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_mlp_0",
            f"transformer.transformer_blocks.{i}.ff_context.linear_in",
        )
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_double_blocks_{i}_txt_mlp_2",
            f"transformer.transformer_blocks.{i}.ff_context.linear_out",
        )

    for i in range(num_single_layers):
        # Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_single_blocks_{i}_linear1",
            f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
        )
        # Single blocks: linear2 -> attn.to_out
        _convert_to_ai_toolkit(
            state_dict,
            ait_sd,
            f"lora_unet_single_blocks_{i}_linear2",
            f"transformer.single_transformer_blocks.{i}.attn.to_out",
        )

    # Handle optional extra keys
    extra_mappings = {
        "lora_unet_img_in": "transformer.x_embedder",
        "lora_unet_txt_in": "transformer.context_embedder",
        "lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
        "lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
        "lora_unet_final_layer_linear": "transformer.proj_out",
    }
    for sds_key, ait_key in extra_mappings.items():
        _convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)

    remaining_keys = list(state_dict.keys())
    if remaining_keys:
        logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")

    return ait_sd
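Block counts are recovered purely from the Kohya key names; a minimal standalone sketch of the parsing convention used in the detection loop above (helper name is illustrative):

```python
def count_blocks(keys):
    # Kohya keys look like "lora_unet_double_blocks_<i>_...": splitting on
    # "_" puts the block index in the fifth field (index 4).
    num_double = num_single = 0
    for key in keys:
        if key.startswith("lora_unet_double_blocks_"):
            num_double = max(num_double, int(key.split("_")[4]) + 1)
        elif key.startswith("lora_unet_single_blocks_"):
            num_single = max(num_single, int(key.split("_")[4]) + 1)
    return num_double, num_single
```

Taking the maximum index plus one tolerates missing blocks and keys arriving in any order.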
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
    """
    Convert non-diffusers ZImage LoRA state dict to diffusers format.

View File

@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
    _convert_bfl_flux_control_lora_to_diffusers,
    _convert_fal_kontext_lora_to_diffusers,
    _convert_hunyuan_video_lora_to_diffusers,
    _convert_kohya_flux2_lora_to_diffusers,
    _convert_kohya_flux_lora_to_diffusers,
    _convert_musubi_wan_lora_to_diffusers,
    _convert_non_diffusers_flux2_lora_to_diffusers,
@@ -5673,6 +5674,13 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        is_kohya = any(".lora_down.weight" in k for k in state_dict)
        if is_kohya:
            state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
            # Kohya already takes care of scaling the LoRA parameters with alpha.
            out = (state_dict, metadata) if return_lora_metadata else state_dict
            return out

        is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
        if is_peft_format:
            state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}

View File

@@ -15,6 +15,7 @@
import inspect
import json
import os
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Literal
@@ -44,33 +45,13 @@ from .unet_loader_utils import _maybe_expand_lora_scales
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = {
    "UNet2DConditionModel": _maybe_expand_lora_scales,
    "UNetMotionModel": _maybe_expand_lora_scales,
    "SD3Transformer2DModel": lambda model_cls, weights: weights,
    "FluxTransformer2DModel": lambda model_cls, weights: weights,
    "CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
    "ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
    "HeliosTransformer3DModel": lambda model_cls, weights: weights,
    "MochiTransformer3DModel": lambda model_cls, weights: weights,
    "HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
    "LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
    "SanaTransformer2DModel": lambda model_cls, weights: weights,
    "AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
    "Lumina2Transformer2DModel": lambda model_cls, weights: weights,
    "WanTransformer3DModel": lambda model_cls, weights: weights,
    "CogView4Transformer2DModel": lambda model_cls, weights: weights,
    "HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
    "HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
    "WanVACETransformer3DModel": lambda model_cls, weights: weights,
    "ChromaTransformer2DModel": lambda model_cls, weights: weights,
    "ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
    "QwenImageTransformer2DModel": lambda model_cls, weights: weights,
    "Flux2Transformer2DModel": lambda model_cls, weights: weights,
    "ZImageTransformer2DModel": lambda model_cls, weights: weights,
    "LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
    "LTX2TextConnectors": lambda model_cls, weights: weights,
}
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
    lambda: (lambda model_cls, weights: weights),
    {
        "UNet2DConditionModel": _maybe_expand_lora_scales,
        "UNetMotionModel": _maybe_expand_lora_scales,
    },
)
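The `defaultdict` rewrite means any model class that is not explicitly listed falls back to the identity scaling function, so new transformers no longer need their own entry. A small illustration of the pattern (the `expand_scales` handler is a stand-in, not the real `_maybe_expand_lora_scales`):

```python
from collections import defaultdict


def expand_scales(model_cls, weights):
    # Stand-in for a model-specific expansion such as _maybe_expand_lora_scales.
    return {"expanded": weights}


scale_fns = defaultdict(
    lambda: (lambda model_cls, weights: weights),  # identity for unlisted models
    {"UNet2DConditionModel": expand_scales},
)

# Listed class uses its special handler; any other class gets the identity.
print(scale_fns["UNet2DConditionModel"]("UNet2DConditionModel", 0.5))  # {'expanded': 0.5}
print(scale_fns["SomeNewTransformer"]("SomeNewTransformer", 0.5))      # 0.5
```

Note that the factory takes no arguments, so it is wrapped in an outer `lambda` that returns the identity function on each missing-key lookup.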
class PeftAdapterMixin:

View File

@@ -40,6 +40,8 @@ if is_torch_available():
_import_structure["autoencoders.autoencoder_kl_hunyuanimage"] = ["AutoencoderKLHunyuanImage"]
_import_structure["autoencoders.autoencoder_kl_hunyuanimage_refiner"] = ["AutoencoderKLHunyuanImageRefiner"]
_import_structure["autoencoders.autoencoder_kl_hunyuanvideo15"] = ["AutoencoderKLHunyuanVideo15"]
_import_structure["autoencoders.autoencoder_kl_kvae"] = ["AutoencoderKLKVAE"]
_import_structure["autoencoders.autoencoder_kl_kvae_video"] = ["AutoencoderKLKVAEVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
_import_structure["autoencoders.autoencoder_kl_ltx2"] = ["AutoencoderKLLTX2Video"]
_import_structure["autoencoders.autoencoder_kl_ltx2_audio"] = ["AutoencoderKLLTX2Audio"]
@@ -161,6 +163,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
            AutoencoderKLHunyuanImageRefiner,
            AutoencoderKLHunyuanVideo,
            AutoencoderKLHunyuanVideo15,
            AutoencoderKLKVAE,
            AutoencoderKLKVAEVideo,
            AutoencoderKLLTX2Audio,
            AutoencoderKLLTX2Video,
            AutoencoderKLLTXVideo,

View File

@@ -49,7 +49,7 @@ from ..utils import (
is_xformers_version,
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS
from ..utils.torch_utils import maybe_allow_in_graph
from ..utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from ._modeling_parallel import gather_size_by_comm
@@ -229,6 +229,7 @@ class AttentionBackendName(str, Enum):
    FLASH_HUB = "flash_hub"
    FLASH_VARLEN = "flash_varlen"
    FLASH_VARLEN_HUB = "flash_varlen_hub"
    FLASH_4_HUB = "flash_4_hub"
    _FLASH_3 = "_flash_3"
    _FLASH_VARLEN_3 = "_flash_varlen_3"
    _FLASH_3_HUB = "_flash_3_hub"
@@ -358,6 +359,11 @@ _HUB_KERNELS_REGISTRY: dict["AttentionBackendName", _HubKernelConfig] = {
function_attr="sageattn",
version=1,
),
AttentionBackendName.FLASH_4_HUB: _HubKernelConfig(
repo_id="kernels-staging/flash-attn4",
function_attr="flash_attn_func",
version=0,
),
}
@@ -521,6 +527,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
        AttentionBackendName._FLASH_3_HUB,
        AttentionBackendName._FLASH_3_VARLEN_HUB,
        AttentionBackendName.SAGE_HUB,
        AttentionBackendName.FLASH_4_HUB,
    ]:
        if not is_kernels_available():
            raise RuntimeError(
@@ -531,6 +538,11 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12. Please update with `pip install -U kernels`."
)
if backend == AttentionBackendName.FLASH_4_HUB and not is_kernels_available(">=", "0.12.3"):
raise RuntimeError(
f"Backend '{backend.value}' needs to be used with a `kernels` version of at least 0.12.3. Please update with `pip install -U kernels`."
)
elif backend == AttentionBackendName.AITER:
if not _CAN_USE_AITER_ATTN:
raise RuntimeError(
@@ -575,7 +587,7 @@ def _check_attention_backend_requirements(backend: AttentionBackendName) -> None
)
@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _prepare_for_flash_attn_or_sage_varlen_without_mask(
    batch_size: int,
    seq_len_q: int,
@@ -2676,6 +2688,37 @@ def _flash_attention_3_varlen_hub(
    return (out, lse) if return_lse else out


@_AttentionBackendRegistry.register(
    AttentionBackendName.FLASH_4_HUB,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],
    supports_context_parallel=False,
)
def _flash_attention_4_hub(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attn_mask: torch.Tensor | None = None,
    scale: float | None = None,
    is_causal: bool = False,
    return_lse: bool = False,
    _parallel_config: "ParallelConfig" | None = None,
) -> torch.Tensor:
    if attn_mask is not None:
        raise ValueError("`attn_mask` is not supported for flash-attn 4.")

    func = _HUB_KERNELS_REGISTRY[AttentionBackendName.FLASH_4_HUB].kernel_fn
    out = func(
        q=query,
        k=key,
        v=value,
        softmax_scale=scale,
        causal=is_causal,
    )
    if isinstance(out, tuple):
        return (out[0], out[1]) if return_lse else out[0]
    return out


@_AttentionBackendRegistry.register(
    AttentionBackendName._FLASH_VARLEN_3,
    constraints=[_check_device, _check_qkv_dtype_bf16_or_fp16, _check_shape],

View File

@@ -9,6 +9,8 @@ from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
from .autoencoder_kl_hunyuanimage import AutoencoderKLHunyuanImage
from .autoencoder_kl_hunyuanimage_refiner import AutoencoderKLHunyuanImageRefiner
from .autoencoder_kl_hunyuanvideo15 import AutoencoderKLHunyuanVideo15
from .autoencoder_kl_kvae import AutoencoderKLKVAE
from .autoencoder_kl_kvae_video import AutoencoderKLKVAEVideo
from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
from .autoencoder_kl_ltx2 import AutoencoderKLLTX2Video
from .autoencoder_kl_ltx2_audio import AutoencoderKLLTX2Audio

View File

@@ -87,7 +87,14 @@ class HunyuanImageRefinerRMS_norm(nn.Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
            t in str(x.dtype) for t in ("float4_", "float8_")
        )
        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
            x.dtype
        )
        return normalized * self.scale * self.gamma + self.bias
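The change upcasts half-precision (and the narrow `float8_`/`float4_` family) inputs to fp32 before `F.normalize`, then casts the result back, since the norm reduction can lose precision when computed directly in bf16/fp16. A minimal repro of the pattern:

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, dtype=torch.bfloat16)

# Normalize in fp32 for reduced-precision dtypes, then cast back to the
# input dtype so downstream layers see the original precision.
needs_fp32 = x.dtype in (torch.float16, torch.bfloat16)
y = F.normalize(x.float() if needs_fp32 else x, dim=-1).to(x.dtype)
```

The output keeps the caller's dtype, so the scale/gamma multiplication that follows in the module behaves exactly as before.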
class HunyuanImageRefinerAttnBlock(nn.Module):

View File

@@ -87,7 +87,14 @@ class HunyuanVideo15RMS_norm(nn.Module):
        self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0

    def forward(self, x):
        return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
        needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
            t in str(x.dtype) for t in ("float4_", "float8_")
        )
        normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
            x.dtype
        )
        return normalized * self.scale * self.gamma + self.bias
class HunyuanVideo15AttnBlock(nn.Module):

View File

@@ -0,0 +1,802 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils.accelerate_utils import apply_forward_hook
from ..activations import get_activation
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
class KVAEResnetBlock2D(nn.Module):
    r"""
    A ResNet block with optional guidance.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer. If `None`, same as `in_channels`.
        conv_shortcut (`bool`, *optional*, defaults to `False`):
            If `True` and `in_channels` is not equal to `out_channels`, add a 3x3 nn.Conv2d layer for the skip connection.
        temb_channels (`int`, *optional*, defaults to `512`): The number of channels in the timestep embedding.
        zq_ch (`int`, *optional*, defaults to `None`): Guidance channels for normalization.
        add_conv (`bool`, *optional*, defaults to `False`):
            If `True`, add a conv2d layer for normalization.
        act_fn (`str`, *optional*, defaults to `"swish"`): The activation function to use.
    """
    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: Optional[int] = None,
        conv_shortcut: bool = False,
        temb_channels: int = 512,
        zq_ch: Optional[int] = None,
        add_conv: bool = False,
        act_fn: str = "swish",
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut
        self.nonlinearity = get_activation(act_fn)

        if zq_ch is None:
            self.norm1 = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.norm1 = KVAEDecoderSpatialNorm2D(in_channels, zq_channels=zq_ch, add_conv=add_conv)
        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=(1, 1), padding_mode="replicate"
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)

        if zq_ch is None:
            self.norm2 = nn.GroupNorm(num_channels=out_channels, num_groups=32, eps=1e-6, affine=True)
        else:
            self.norm2 = KVAEDecoderSpatialNorm2D(out_channels, zq_channels=zq_ch, add_conv=add_conv)
        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=(1, 1),
            padding_mode="replicate",
        )

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    padding=(1, 1),
                    padding_mode="replicate",
                )
            else:
                self.nin_shortcut = nn.Conv2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                )
    def forward(self, x: torch.Tensor, temb: torch.Tensor, zq: Optional[torch.Tensor] = None) -> torch.Tensor:
        h = x
        if zq is None:
            h = self.norm1(h)
        else:
            h = self.norm1(h, zq)
        h = self.nonlinearity(h)
        h = self.conv1(h)

        if temb is not None:
            # Broadcast the projected timestep embedding over the two spatial dims.
            h = h + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]

        if zq is None:
            h = self.norm2(h)
        else:
            h = self.norm2(h, zq)
        h = self.nonlinearity(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h
class KVAEPXSDownsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
A Downsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, default to `2`): The downsampling factor.
"""
super().__init__()
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: (bchw)
pxs_interm = self.unshuffle(x)
b, c, h, w = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b, c // self.factor**2, self.factor**2, h, w)
pxs_out = torch.mean(pxs_interm_view, dim=2)
conv_out = self.spatial_conv(x)
# adding it all together
out = conv_out + pxs_out
return self.linear(out)
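The shuffle branch above is numerically just 2x2 average pooling: `PixelUnshuffle` moves each 2x2 spatial block into `factor**2` consecutive channels, so the mean over those channels recovers the block average. A minimal standalone check in plain PyTorch (independent of the classes in this file):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

factor = 2
x = torch.randn(1, 3, 8, 8)

# PixelUnshuffle groups the factor**2 sub-positions of each input channel
# contiguously, so the view/mean below averages every 2x2 spatial block.
pxs = nn.PixelUnshuffle(factor)(x)                      # (1, 12, 4, 4)
b, c, h, w = pxs.shape
pxs_mean = pxs.view(b, c // factor**2, factor**2, h, w).mean(dim=2)

assert torch.allclose(pxs_mean, F.avg_pool2d(x, factor), atol=1e-6)
```

The module sums this parameter-free pooled path with a learned strided convolution, so the average acts as a stable baseline that the conv refines.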
class KVAEPXSUpsample(nn.Module):
def __init__(self, in_channels: int, factor: int = 2):
r"""
An Upsampling module.
Args:
in_channels (`int`): The number of channels in the input.
factor (`int`, *optional*, defaults to `2`): The upsampling factor.
"""
super().__init__()
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = nn.Conv2d(
in_channels, in_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode="reflect"
)
self.linear = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
repeated = x.repeat_interleave(self.factor**2, dim=1)
pxs_interm = self.shuffle(repeated)
image_like_ups = F.interpolate(x, scale_factor=2, mode="nearest")
conv_out = self.spatial_conv(image_like_ups)
# adding it all together
out = conv_out + pxs_interm
return self.linear(out)
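Symmetrically, the `repeat_interleave` + `PixelShuffle` path in the upsampler is exactly nearest-neighbor upsampling: every position of a 2x2 output block receives the same repeated channel value. A standalone check:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

factor = 2
x = torch.randn(1, 3, 4, 4)

# Each channel is repeated factor**2 times, so PixelShuffle writes the same
# value into every position of the corresponding 2x2 output block.
shuffled = nn.PixelShuffle(factor)(x.repeat_interleave(factor**2, dim=1))

assert torch.allclose(shuffled, F.interpolate(x, scale_factor=factor, mode="nearest"))
```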
class KVAEDecoderSpatialNorm2D(nn.Module):
r"""
A spatially conditioned normalization module for the 2D decoder.
Args:
in_channels (`int`): The number of channels in the input.
zq_channels (`int`): The number of channels in the guidance.
add_conv (`bool`, *optional*, defaults to `False`):
If `True`, applies a 3x3 conv2d to the guidance tensor before modulation.
"""
def __init__(
self,
in_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = nn.GroupNorm(num_channels=in_channels, num_groups=32, eps=1e-6, affine=True)
self.add_conv = add_conv
if add_conv:
self.conv = nn.Conv2d(
in_channels=zq_channels,
out_channels=zq_channels,
kernel_size=3,
padding=(1, 1),
padding_mode="replicate",
)
self.conv_y = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
self.conv_b = nn.Conv2d(
in_channels=zq_channels,
out_channels=in_channels,
kernel_size=1,
)
def forward(self, f: torch.Tensor, zq: torch.Tensor) -> torch.Tensor:
f_first = f
f_first_size = f_first.shape[2:]
zq = F.interpolate(zq, size=f_first_size, mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
class KVAEEncoder2D(nn.Module):
r"""
A 2D encoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int` or `List[int]`):
The number of ResNet blocks per resolution level (an `int` is broadcast to all levels).
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of output channels.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels for the last block.
act_fn (`str`, *optional*, defaults to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
double_z: bool = True,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
if isinstance(num_res_blocks, int):
self.num_res_blocks = [num_res_blocks] * self.num_resolutions
else:
self.num_res_blocks = num_res_blocks
self.nonlinearity = get_activation(act_fn)
self.in_channels = in_channels
self.conv_in = nn.Conv2d(
in_channels=in_channels,
out_channels=self.ch,
kernel_size=3,
padding=(1, 1),
)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks[i_level]):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level < self.num_resolutions - 1:
down.downsample = KVAEPXSDownsample(in_channels=block_in)  # NOTE: downsampling preserves the channel count
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
)
# end
self.norm_out = nn.GroupNorm(num_channels=block_in, num_groups=32, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(
in_channels=block_in,
out_channels=2 * z_channels if double_z else z_channels,
kernel_size=3,
padding=(1, 1),
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
# timestep embedding
temb = None
# downsampling
h = self.conv_in(x)
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks[i_level]):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.down[i_level].block[i_block], h, temb)
else:
h = self.down[i_level].block[i_block](h, temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb)
else:
h = self.mid.block_1(h, temb)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class KVAEDecoder2D(nn.Module):
r"""
A 2D decoder module.
Args:
ch (`int`): The base number of channels in multiresolution blocks.
out_ch (`int`): The number of output channels.
ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
num_res_blocks (`int`): The number of ResNet blocks per resolution level.
in_channels (`int`): The number of channels in the input.
z_channels (`int`): The number of channels in the latent input.
give_pre_end (`bool`, *optional*, defaults to `False`):
If `True`, exit the forward pass early and return the penultimate feature map.
zq_ch (`int`, *optional*, defaults to `None`):
The number of channels in the guidance. If `None`, falls back to `z_channels`.
add_conv (`bool`, *optional*, defaults to `False`): If `True`, add a conv2d layer to the ResNet normalization layers.
act_fn (`str`, *optional*, defaults to `"swish"`): The activation function to use.
"""
def __init__(
self,
*,
ch: int,
out_ch: int,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int,
in_channels: int,
z_channels: int,
give_pre_end: bool = False,
zq_ch: Optional[int] = None,
add_conv: bool = False,
act_fn: str = "swish",
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.nonlinearity = get_activation(act_fn)
if zq_ch is None:
zq_ch = z_channels
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = nn.Conv2d(
in_channels=z_channels, out_channels=block_in, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
# middle
self.mid = nn.Module()
self.mid.block_1 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAEResnetBlock2D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = KVAEPXSUpsample(in_channels=block_in)
self.up.insert(0, up)
self.norm_out = KVAEDecoderSpatialNorm2D(block_in, zq_ch, add_conv=add_conv)
self.conv_out = nn.Conv2d(
in_channels=block_in, out_channels=out_ch, kernel_size=3, padding=(1, 1), padding_mode="replicate"
)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor) -> torch.Tensor:
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
zq = z
h = self.conv_in(z)
# middle
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, zq)
else:
h = self.mid.block_1(h, temb, zq)
h = self.mid.block_2(h, temb, zq)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.up[i_level].block[i_block], h, temb, zq)
else:
h = self.up[i_level].block[i_block](h, temb, zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h, zq)
h = self.nonlinearity(h)
h = self.conv_out(h)
return h
class AutoencoderKLKVAE(ModelMixin, AutoencoderMixin, ConfigMixin):
r"""
A VAE model with KL loss for encoding images into latents and decoding latent representations into images.
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
channels (int, *optional*, defaults to 128): The base number of channels in multiresolution blocks.
num_enc_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in encoder multiresolution layers.
num_dec_blocks (int, *optional*, defaults to 2):
The number of Resnet blocks in decoder multiresolution layers.
z_channels (int, *optional*, defaults to 16): Number of channels in the latent space.
double_z (`bool`, *optional*, defaults to `True`):
Whether to double the number of output channels of the encoder.
ch_mult (`Tuple[int, ...]`, *optional*, defaults to `(1, 2, 4, 8)`):
The channel multipliers in multiresolution blocks.
sample_size (`int`, *optional*, defaults to `1024`): Sample input size.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
channels: int = 128,
num_enc_blocks: int = 2,
num_dec_blocks: int = 2,
z_channels: int = 16,
double_z: bool = True,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
sample_size: int = 1024,
):
super().__init__()
# pass init params to Encoder
self.encoder = KVAEEncoder2D(
in_channels=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_enc_blocks,
z_channels=z_channels,
double_z=double_z,
)
# pass init params to Decoder
self.decoder = KVAEDecoder2D(
out_ch=in_channels,
ch=channels,
ch_mult=ch_mult,
num_res_blocks=num_dec_blocks,
in_channels=None,
z_channels=z_channels,
)
self.use_slicing = False
self.use_tiling = False
# only relevant if vae tiling is enabled
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_sample_min_size = sample_size
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.ch_mult) - 1)))
self.tile_overlap_factor = 0.25
def _encode(self, x: torch.Tensor) -> torch.Tensor:
batch_size, num_channels, height, width = x.shape
if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
return self._tiled_encode(x)
enc = self.encoder(x)
return enc
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
return self.tiled_decode(z, return_dict=return_dict)
dec = self.decoder(z)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
@apply_forward_hook
def decode(
self, z: torch.FloatTensor, return_dict: bool = True, generator=None
) -> Union[DecoderOutput, torch.FloatTensor]:
"""
Decode a batch of images.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z).sample
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[2], b.shape[2], blend_extent)
for y in range(blend_extent):
b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for x in range(blend_extent):
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b
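The blend helpers cross-fade the overlapping edges of adjacent tiles with linear weights. A small numeric check of the horizontal ramp, using a copy of `blend_h` on constant tiles (standalone, plain PyTorch):

```python
import torch

def blend_h(a, b, blend_extent):
    # Linearly cross-fades the right edge of tile `a` into the left edge of `b`.
    blend_extent = min(a.shape[3], b.shape[3], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
    return b

a = torch.ones(1, 1, 2, 8)
b = torch.zeros(1, 1, 2, 8)
out = blend_h(a, b, blend_extent=4)

# Weights ramp from the `a` edge (1.0) toward pure `b` (0.0) across the overlap.
assert torch.allclose(out[0, 0, 0, :4], torch.tensor([1.0, 0.75, 0.5, 0.25]))
assert torch.all(out[0, 0, :, 4:] == 0)
```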
def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`torch.Tensor`:
The latent representation of the encoded images.
"""
overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
row_limit = self.tile_latent_min_size - blend_extent
# Split the image into overlapping `tile_sample_min_size` x `tile_sample_min_size` tiles and encode them separately.
rows = []
for i in range(0, x.shape[2], overlap_size):
row = []
for j in range(0, x.shape[3], overlap_size):
tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
tile = self.encoder(tile)
row.append(tile)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
enc = torch.cat(result_rows, dim=2)
return enc
def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`:
If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
returned.
"""
overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
row_limit = self.tile_sample_min_size - blend_extent
# Split z into overlapping `tile_latent_min_size` x `tile_latent_min_size` tiles and decode them separately.
# The tiles overlap to avoid seams between tiles.
rows = []
for i in range(0, z.shape[2], overlap_size):
row = []
for j in range(0, z.shape[3], overlap_size):
tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
decoded = self.decoder(tile)
row.append(decoded)
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent)
result_row.append(tile[:, :, :row_limit, :row_limit])
result_rows.append(torch.cat(result_row, dim=3))
dec = torch.cat(result_rows, dim=2)
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `False`):
Whether to sample from the posterior.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
"""
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
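For orientation: with the default `ch_mult=(1, 2, 4, 8)` the encoder downsamples at every level except the last, giving a spatial compression factor of `2 ** (len(ch_mult) - 1)`, which is how `tile_latent_min_size` above is derived. A quick arithmetic check (the values below mirror the class defaults, not a live model instance):

```python
# Spatial compression implied by the default AutoencoderKLKVAE config.
sample_size = 1024
ch_mult = (1, 2, 4, 8)
latent_min_size = int(sample_size / (2 ** (len(ch_mult) - 1)))
assert latent_min_size == 128  # a 1024x1024 image maps to a 128x128 latent grid
```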


@@ -0,0 +1,954 @@
# Copyright 2025 The Kandinsky Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Dict, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders.single_file_model import FromOriginalModelMixin
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
def nonlinearity(x: torch.Tensor) -> torch.Tensor:
return F.silu(x)
# =============================================================================
# Base layers
# =============================================================================
class KVAESafeConv3d(nn.Conv3d):
r"""
A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM.
"""
def forward(self, input: torch.Tensor, write_to: torch.Tensor = None) -> torch.Tensor:
memory_count = input.numel() * input.element_size() / (10**9)
if memory_count > 3:
kernel_size = self.kernel_size[0]
part_num = math.ceil(memory_count / 2)
input_chunks = torch.chunk(input, part_num, dim=2)
if write_to is None:
output = []
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
output.append(super().forward(z))
return torch.cat(output, dim=2)
else:
time_offset = 0
for i, chunk in enumerate(input_chunks):
if i == 0 or kernel_size == 1:
z = torch.clone(chunk)
else:
z = torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
z_time = z.size(2) - (kernel_size - 1)
write_to[:, :, time_offset : time_offset + z_time] = super().forward(z)
time_offset += z_time
return write_to
else:
if write_to is None:
return super().forward(input)
else:
write_to[...] = super().forward(input)
return write_to
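`KVAESafeConv3d` stays equivalent to the full convolution because each chunk re-prepends the last `kernel_size - 1` frames of the previous chunk, so every temporal window the full conv sees is also seen by some chunk. This holds as long as the conv applies no temporal padding (as in this file, where padding is applied externally). A standalone sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Temporal padding must be 0 so that stitching overlapping chunks along the
# time axis reproduces the full convolution (spatial padding is fine).
conv = nn.Conv3d(4, 4, kernel_size=3, padding=(0, 1, 1))
x = torch.randn(1, 4, 16, 8, 8)

kernel_size = conv.kernel_size[0]
outputs, z = [], None
for i, chunk in enumerate(torch.chunk(x, 4, dim=2)):
    # Re-prepend the last (kernel_size - 1) frames of the previous chunk so
    # no temporal window is lost at the chunk boundary.
    z = chunk if i == 0 else torch.cat([z[:, :, -kernel_size + 1 :], chunk], dim=2)
    outputs.append(conv(z))
chunked = torch.cat(outputs, dim=2)

assert torch.allclose(chunked, conv(x), atol=1e-5)
```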
class KVAECausalConv3d(nn.Module):
r"""
A 3D causal convolution layer.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor) -> torch.Tensor:
padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, self.time_pad, 0)
input_padded = F.pad(input, padding_3d, mode="replicate")
return self.conv(input_padded)
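Causality here comes purely from the asymmetric pad: the time axis is padded only on the past side, so output frame `t` never sees frames after `t`. A standalone check (plain PyTorch, mirroring the 3x3x3 case) that perturbing the last input frame leaves all earlier output frames untouched:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv3d(2, 2, kernel_size=3)  # no built-in padding, as in the module
x = torch.randn(1, 2, 6, 5, 5)

def causal_conv(inp):
    # Pad H/W symmetrically but the time axis only on the past side,
    # mirroring KVAECausalConv3d.forward for a 3x3x3 kernel.
    return conv(F.pad(inp, (1, 1, 1, 1, 2, 0), mode="replicate"))

y = causal_conv(x)
x_future = x.clone()
x_future[:, :, -1] += 10.0  # perturb only the last frame

# All output frames except the last are unchanged: no future leakage.
assert torch.allclose(y[:, :, :-1], causal_conv(x_future)[:, :, :-1])
```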
class KVAECachedCausalConv3d(nn.Module):
r"""
A 3D causal convolution layer with caching for temporal processing.
"""
def __init__(
self,
chan_in: int,
chan_out: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Tuple[int, int, int] = (1, 1, 1),
dilation: Tuple[int, int, int] = (1, 1, 1),
**kwargs,
):
super().__init__()
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
self.height_pad = height_kernel_size // 2
self.width_pad = width_kernel_size // 2
self.time_pad = time_kernel_size - 1
self.time_kernel_size = time_kernel_size
self.stride = stride
self.conv = KVAESafeConv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
def forward(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
t_stride = self.stride[0]
padding_3d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad, 0, 0)
input_parallel = F.pad(input, padding_3d, mode="replicate")
if cache["padding"] is None:
first_frame = input_parallel[:, :, :1]
time_pad_shape = list(first_frame.shape)
time_pad_shape[2] = self.time_pad
padding = first_frame.expand(time_pad_shape)
else:
padding = cache["padding"]
out_size = list(input.shape)
out_size[1] = self.conv.out_channels
if t_stride == 2:
out_size[2] = (input.size(2) + 1) // 2
output = torch.empty(tuple(out_size), dtype=input.dtype, device=input.device)
offset_out = math.ceil(padding.size(2) / t_stride)
offset_in = offset_out * t_stride - padding.size(2)
if offset_out > 0:
padding_poisoned = torch.cat(
[padding, input_parallel[:, :, : offset_in + self.time_kernel_size - t_stride]], dim=2
)
output[:, :, :offset_out] = self.conv(padding_poisoned)
if offset_out < output.size(2):
output[:, :, offset_out:] = self.conv(input_parallel[:, :, offset_in:])
pad_offset = (
offset_in
+ t_stride * math.trunc((input_parallel.size(2) - offset_in - self.time_kernel_size) / t_stride)
+ t_stride
)
cache["padding"] = torch.clone(input_parallel[:, :, pad_offset:])
return output
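The cache generalizes the causal pad to streaming: the first call synthesizes the pad from the first frame, and later calls reuse the last `time_pad` input frames stored in the cache, so chunk-by-chunk processing matches the offline result. A 1D sketch of the stride-1 case (a simplified analogue, not the module itself):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
conv = nn.Conv1d(2, 2, kernel_size=3)
time_pad = conv.kernel_size[0] - 1

def full(x):
    # Offline causal conv: replicate-pad the past side once.
    return conv(F.pad(x, (time_pad, 0), mode="replicate"))

def streaming(x, cache):
    # First call: synthesize the pad from the first frame; later calls reuse
    # the last `time_pad` input frames stored in the cache.
    if cache["padding"] is None:
        padding = x[:, :, :1].expand(-1, -1, time_pad)
    else:
        padding = cache["padding"]
    z = torch.cat([padding, x], dim=2)
    cache["padding"] = z[:, :, -time_pad:].clone()
    return conv(z)

x = torch.randn(1, 2, 10)
cache = {"padding": None}
chunked = torch.cat([streaming(x[:, :, :4], cache), streaming(x[:, :, 4:], cache)], dim=2)
assert torch.allclose(chunked, full(x), atol=1e-5)
```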
class KVAECachedGroupNorm(nn.Module):
r"""
GroupNorm with caching support for temporal processing.
"""
def __init__(self, in_channels: int):
super().__init__()
self.norm_layer = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
def forward(self, x: torch.Tensor, cache: Dict = None) -> torch.Tensor:
out = self.norm_layer(x)
if cache is not None and cache.get("mean") is None and cache.get("var") is None:
# sentinel values: only their presence is checked to detect the first call
cache["mean"] = 1
cache["var"] = 1
return out
# =============================================================================
# Cached layers
# =============================================================================
class KVAECachedSpatialNorm3D(nn.Module):
r"""
Spatially conditioned normalization for decoder with caching.
"""
def __init__(
self,
f_channels: int,
zq_channels: int,
add_conv: bool = False,
):
super().__init__()
self.norm_layer = KVAECachedGroupNorm(f_channels)
self.add_conv = add_conv
if add_conv:
self.conv = KVAECachedCausalConv3d(chan_in=zq_channels, chan_out=zq_channels, kernel_size=3)
self.conv_y = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
self.conv_b = KVAESafeConv3d(zq_channels, f_channels, kernel_size=1)
def forward(self, f: torch.Tensor, zq: torch.Tensor, cache: Dict) -> torch.Tensor:
if cache["norm"].get("mean") is None and cache["norm"].get("var") is None:
f_first, f_rest = f[:, :, :1], f[:, :, 1:]
f_first_size, f_rest_size = f_first.shape[-3:], f_rest.shape[-3:]
zq_first, zq_rest = zq[:, :, :1], zq[:, :, 1:]
zq_first = F.interpolate(zq_first, size=f_first_size, mode="nearest")
if zq.size(2) > 1:
zq_rest_splits = torch.split(zq_rest, 32, dim=1)
interpolated_splits = [
F.interpolate(split, size=f_rest_size, mode="nearest") for split in zq_rest_splits
]
zq_rest = torch.cat(interpolated_splits, dim=1)
zq = torch.cat([zq_first, zq_rest], dim=2)
else:
zq = zq_first
else:
f_size = f.shape[-3:]
zq_splits = torch.split(zq, 32, dim=1)
interpolated_splits = [F.interpolate(split, size=f_size, mode="nearest") for split in zq_splits]
zq = torch.cat(interpolated_splits, dim=1)
if self.add_conv:
zq = self.conv(zq, cache["add_conv"])
norm_f = self.norm_layer(f, cache["norm"])
norm_f = norm_f * self.conv_y(zq)
norm_f = norm_f + self.conv_b(zq)
return norm_f
class KVAECachedResnetBlock3D(nn.Module):
r"""
A 3D ResNet block with caching.
"""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 0,
zq_ch: Optional[int] = None,
add_conv: bool = False,
gather_norm: bool = False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if zq_ch is None:
self.norm1 = KVAECachedGroupNorm(in_channels)
else:
self.norm1 = KVAECachedSpatialNorm3D(in_channels, zq_ch, add_conv=add_conv)
self.conv1 = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
if zq_ch is None:
self.norm2 = KVAECachedGroupNorm(out_channels)
else:
self.norm2 = KVAECachedSpatialNorm3D(out_channels, zq_ch, add_conv=add_conv)
self.conv2 = KVAECachedCausalConv3d(chan_in=out_channels, chan_out=out_channels, kernel_size=3)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=out_channels, kernel_size=3)
else:
self.nin_shortcut = KVAESafeConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: torch.Tensor, layer_cache: Dict, zq: torch.Tensor = None) -> torch.Tensor:
h = x
if zq is None:
# Encoder path - norm takes cache
h = self.norm1(h, cache=layer_cache["norm1"])
else:
# Decoder path - spatial norm takes zq and cache
h = self.norm1(h, zq, cache=layer_cache["norm1"])
h = F.silu(h)
h = self.conv1(h, cache=layer_cache["conv1"])
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None, None]
if zq is None:
h = self.norm2(h, cache=layer_cache["norm2"])
else:
h = self.norm2(h, zq, cache=layer_cache["norm2"])
h = F.silu(h)
h = self.conv2(h, cache=layer_cache["conv2"])
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x, cache=layer_cache["conv_shortcut"])
else:
x = self.nin_shortcut(x)
return x + h
class KVAECachedPXSDownsample(nn.Module):
r"""
A 3D downsampling layer using PixelUnshuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.unshuffle = nn.PixelUnshuffle(self.factor)
self.s_pool = nn.AvgPool3d((1, 2, 2), (1, 2, 2))
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(2, 1, 1), dilation=(1, 1, 1)
)
self.linear = nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_downsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
pxs_input = input.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
# pxs_input = rearrange(input, 'b c t h w -> (b t) c h w')
pxs_interm = self.unshuffle(pxs_input)
b_it, c_it, h_it, w_it = pxs_interm.shape
pxs_interm_view = pxs_interm.view(b_it, c_it // self.factor**2, self.factor**2, h_it, w_it)
pxs_out = torch.mean(pxs_interm_view, dim=2)
pxs_out = pxs_out.view(b, t, -1, h_it, w_it).permute(0, 2, 1, 3, 4)
# pxs_out = rearrange(pxs_out, '(b t) c h w -> b c t h w', t=input.size(2))
conv_out = self.spatial_conv(input)
return conv_out + pxs_out
def temporal_downsample(self, input: torch.Tensor, cache: list) -> torch.Tensor:
b, c, t, h, w = input.shape
permuted = input.permute(0, 3, 4, 1, 2).reshape(b * h * w, c, t)
if cache[0]["padding"] is None:
first, rest = permuted[..., :1], permuted[..., 1:]
if rest.size(-1) > 0:
rest_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
full_interp = torch.cat([first, rest_interp], dim=-1)
else:
full_interp = first
else:
rest = permuted
if rest.size(-1) > 0:
full_interp = F.avg_pool1d(rest, kernel_size=2, stride=2)
t_new = full_interp.size(-1)
full_interp = full_interp.view(b, h, w, c, t_new).permute(0, 3, 4, 1, 2)
conv_out = self.temporal_conv(input, cache[0])
return conv_out + full_interp
def forward(self, x: torch.Tensor, cache: list) -> torch.Tensor:
out = self.spatial_downsample(x)
if self.temporal_compress:
out = self.temporal_downsample(out, cache=cache)
return self.linear(out)
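On the first chunk, the parameter-free temporal path keeps frame 0 untouched and averages the remaining frames in causal pairs, so `T` frames become `1 + (T - 1) // 2`. A tiny numeric check:

```python
import torch
import torch.nn.functional as F

# Frames 0..6 along the last axis; frame 0 is kept, the rest are pair-averaged.
x = torch.arange(7.0).view(1, 1, 7)
first, rest = x[..., :1], x[..., 1:]
pooled = torch.cat([first, F.avg_pool1d(rest, kernel_size=2, stride=2)], dim=-1)

assert pooled.squeeze().tolist() == [0.0, 1.5, 3.5, 5.5]
```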
class KVAECachedPXSUpsample(nn.Module):
r"""
A 3D upsampling layer using PixelShuffle with caching.
"""
def __init__(self, in_channels: int, compress_time: bool, factor: int = 2):
super().__init__()
self.temporal_compress = compress_time
self.factor = factor
self.shuffle = nn.PixelShuffle(self.factor)
self.spatial_conv = KVAESafeConv3d(
in_channels,
in_channels,
kernel_size=(1, 3, 3),
stride=(1, 1, 1),
padding=(0, 1, 1),
padding_mode="reflect",
)
if self.temporal_compress:
self.temporal_conv = KVAECachedCausalConv3d(
in_channels, in_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), dilation=(1, 1, 1)
)
self.linear = KVAESafeConv3d(in_channels, in_channels, kernel_size=1, stride=1)
def spatial_upsample(self, input: torch.Tensor) -> torch.Tensor:
b, c, t, h, w = input.shape
input_view = input.permute(0, 2, 1, 3, 4).reshape(b, t * c, h, w)
input_interp = F.interpolate(input_view, scale_factor=2, mode="nearest")
input_interp = input_interp.view(b, t, c, 2 * h, 2 * w).permute(0, 2, 1, 3, 4)
out = self.spatial_conv(input_interp)
return input_interp + out
def temporal_upsample(self, input: torch.Tensor, cache: Dict) -> torch.Tensor:
time_factor = 2 if input.size(2) > 1 else 1
repeated = input.repeat_interleave(time_factor, dim=2)
if cache["padding"] is None:
tail = repeated[..., time_factor - 1 :, :, :]
else:
tail = repeated
conv_out = self.temporal_conv(tail, cache)
return conv_out + tail
def forward(self, x: torch.Tensor, cache: Dict) -> torch.Tensor:
if self.temporal_compress:
x = self.temporal_upsample(x, cache)
s_out = self.spatial_upsample(x)
to = torch.empty_like(s_out)
lin_out = self.linear(s_out, write_to=to)
return lin_out
# =============================================================================
# Cached Encoder/Decoder
# =============================================================================
class KVAECachedEncoder3D(nn.Module):
r"""
Cached 3D Encoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
in_channels: int = 3,
z_channels: int = 16,
double_z: bool = True,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.temporal_compress_level = int(np.log2(temporal_compress_times))
self.conv_in = KVAECachedCausalConv3d(chan_in=in_channels, chan_out=self.ch, kernel_size=3)
in_ch_mult = (1,) + tuple(ch_mult)
self.down = nn.ModuleList()
block_in = ch
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
dropout=dropout,
temb_channels=self.temb_ch,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
if i_level < self.temporal_compress_level:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=True)
else:
down.downsample = KVAECachedPXSDownsample(block_in, compress_time=False)
self.down.append(down)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
)
self.norm_out = KVAECachedGroupNorm(block_in)
self.conv_out = KVAECachedCausalConv3d(
chan_in=block_in, chan_out=2 * z_channels if double_z else z_channels, kernel_size=3
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
h = self.conv_in(x, cache=cache_dict["conv_in"])
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.down[i_level].block[i_block], h, temb, cache_dict[i_level][i_block]
)
else:
h = self.down[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block])
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
if i_level != self.num_resolutions - 1:
h = self.down[i_level].downsample(h, cache=cache_dict[i_level]["down"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"])
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"])
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"])
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"])
h = self.norm_out(h, cache=cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache=cache_dict["conv_out"])
return h
class KVAECachedDecoder3D(nn.Module):
r"""
Cached 3D Decoder for KVAE.
"""
def __init__(
self,
ch: int = 128,
out_ch: int = 3,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
dropout: float = 0.0,
z_channels: int = 16,
zq_ch: Optional[int] = None,
add_conv: bool = False,
temporal_compress_times: int = 4,
):
super().__init__()
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.temporal_compress_level = int(np.log2(temporal_compress_times))
if zq_ch is None:
zq_ch = z_channels
block_in = ch * ch_mult[self.num_resolutions - 1]
self.conv_in = KVAECachedCausalConv3d(chan_in=z_channels, chan_out=block_in, kernel_size=3)
self.mid = nn.Module()
self.mid.block_1 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.mid.block_2 = KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
KVAECachedResnetBlock3D(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
zq_ch=zq_ch,
add_conv=add_conv,
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
if i_level < self.num_resolutions - self.temporal_compress_level:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=False)
else:
up.upsample = KVAECachedPXSUpsample(block_in, compress_time=True)
self.up.insert(0, up)
self.norm_out = KVAECachedSpatialNorm3D(block_in, zq_ch, add_conv=add_conv)
self.conv_out = KVAECachedCausalConv3d(chan_in=block_in, chan_out=out_ch, kernel_size=3)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor, cache_dict: Dict) -> torch.Tensor:
temb = None
zq = z
h = self.conv_in(z, cache_dict["conv_in"])
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(self.mid.block_1, h, temb, cache_dict["mid_1"], zq)
h = self._gradient_checkpointing_func(self.mid.block_2, h, temb, cache_dict["mid_2"], zq)
else:
h = self.mid.block_1(h, temb, layer_cache=cache_dict["mid_1"], zq=zq)
h = self.mid.block_2(h, temb, layer_cache=cache_dict["mid_2"], zq=zq)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
if torch.is_grad_enabled() and self.gradient_checkpointing:
h = self._gradient_checkpointing_func(
self.up[i_level].block[i_block], h, temb, cache_dict[i_level][i_block], zq
)
else:
h = self.up[i_level].block[i_block](h, temb, layer_cache=cache_dict[i_level][i_block], zq=zq)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h, zq)
if i_level != 0:
h = self.up[i_level].upsample(h, cache_dict[i_level]["up"])
h = self.norm_out(h, zq, cache_dict["norm_out"])
h = nonlinearity(h)
h = self.conv_out(h, cache_dict["conv_out"])
return h
# =============================================================================
# Main AutoencoderKL class
# =============================================================================
class AutoencoderKLKVAEVideo(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
r"""
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
[KVAE](https://github.com/kandinskylab/kvae-1).
This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
all models (such as downloading or saving).
Parameters:
ch (`int`, *optional*, defaults to 128): Base channel count.
ch_mult (`Tuple[int]`, *optional*, defaults to `(1, 2, 4, 8)`): Channel multipliers per level.
num_res_blocks (`int`, *optional*, defaults to 2): Number of residual blocks per level.
in_channels (`int`, *optional*, defaults to 3): Number of input channels.
out_ch (`int`, *optional*, defaults to 3): Number of output channels.
z_channels (`int`, *optional*, defaults to 16): Number of latent channels.
temporal_compress_times (`int`, *optional*, defaults to 4): Temporal compression factor.
"""
_supports_gradient_checkpointing = True
_no_split_modules = ["KVAECachedResnetBlock3D"]
@register_to_config
def __init__(
self,
ch: int = 128,
ch_mult: Tuple[int, ...] = (1, 2, 4, 8),
num_res_blocks: int = 2,
in_channels: int = 3,
out_ch: int = 3,
z_channels: int = 16,
temporal_compress_times: int = 4,
):
super().__init__()
self.encoder = KVAECachedEncoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
in_channels=in_channels,
z_channels=z_channels,
double_z=True,
temporal_compress_times=temporal_compress_times,
)
self.decoder = KVAECachedDecoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
out_ch=out_ch,
z_channels=z_channels,
temporal_compress_times=temporal_compress_times,
)
self.use_slicing = False
self.use_tiling = False
def _make_encoder_cache(self) -> Dict:
"""Create empty cache for cached encoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_enc"),
"mid_2": make_dict("resblock_enc"),
"norm_out": make_dict("norm_enc"),
"conv_out": make_dict("conv"),
}
# Encoder uses num_res_blocks per level
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_enc", p=self.config.num_res_blocks)
return cache
def _make_decoder_cache(self) -> Dict:
"""Create empty cache for decoder."""
def make_dict(name, p=None):
if name == "conv":
return {"padding": None}
layer, module = name.split("_")
if layer == "norm":
if module == "enc":
return {"mean": None, "var": None}
else:
return {"norm": make_dict("norm_enc"), "add_conv": make_dict("conv")}
elif layer == "resblock":
return {
"norm1": make_dict(f"norm_{module}"),
"norm2": make_dict(f"norm_{module}"),
"conv1": make_dict("conv"),
"conv2": make_dict("conv"),
"conv_shortcut": make_dict("conv"),
}
elif layer.isdigit():
out_dict = {"down": [make_dict("conv"), make_dict("conv")], "up": make_dict("conv")}
for i in range(p):
out_dict[i] = make_dict(f"resblock_{module}")
return out_dict
cache = {
"conv_in": make_dict("conv"),
"mid_1": make_dict("resblock_dec"),
"mid_2": make_dict("resblock_dec"),
"norm_out": make_dict("norm_dec"),
"conv_out": make_dict("conv"),
}
for i in range(len(self.config.ch_mult)):
cache[i] = make_dict(f"{i}_dec", p=self.config.num_res_blocks + 1)
return cache
def enable_slicing(self) -> None:
r"""Enable sliced VAE decoding."""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""Disable sliced VAE decoding."""
self.use_slicing = False
def _encode(self, x: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
# Cached encoder processes by segments
cache = self._make_encoder_cache()
split_list = [seg_len + 1]
n_frames = x.size(2) - (seg_len + 1)
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
latent = []
for chunk in torch.split(x, split_list, dim=2):
enc = self.encoder(chunk, cache)
sample, _ = torch.chunk(enc, 2, dim=1)
latent.append(sample)
return torch.cat(latent, dim=2)
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
"""
Encode a batch of videos into latents.
Args:
x (`torch.Tensor`): Input batch of videos with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
Returns:
The latent representations of the encoded videos.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
# For cached encoder, we already did the split in _encode
h_double = torch.cat([h, torch.zeros_like(h)], dim=1)
posterior = DiagonalGaussianDistribution(h_double)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def _decode(self, z: torch.Tensor, seg_len: int = 16) -> torch.Tensor:
cache = self._make_decoder_cache()
temporal_compress = self.config.temporal_compress_times
split_list = [seg_len + 1]
n_frames = temporal_compress * (z.size(2) - 1) - seg_len
while n_frames > 0:
split_list.append(seg_len)
n_frames -= seg_len
split_list[-1] += n_frames
split_list = [math.ceil(size / temporal_compress) for size in split_list]
recs = []
for chunk in torch.split(z, split_list, dim=2):
out = self.decoder(chunk, cache)
recs.append(out)
return torch.cat(recs, dim=2)
@apply_forward_hook
def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
"""
Decode a batch of videos.
Args:
z (`torch.Tensor`): Input batch of latent vectors with shape (B, C, T, H, W).
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
Returns:
[`~models.vae.DecoderOutput`] or `tuple`: Decoded video.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z)
if not return_dict:
return (decoded,)
return DecoderOutput(sample=decoded)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[DecoderOutput, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)
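The segment planning in `_encode`/`_decode` above can be isolated as plain arithmetic. A minimal sketch using the defaults from this file (`seg_len=16`, `temporal_compress_times=4`); the function names are illustrative, not part of the model:

```python
import math

def encoder_split_list(num_frames: int, seg_len: int = 16) -> list[int]:
    # Mirrors AutoencoderKLKVAEVideo._encode: the first segment carries one
    # extra (causal) frame, later segments are seg_len long, and the final
    # adjustment trims the overshoot so the segments sum to num_frames.
    split_list = [seg_len + 1]
    n_frames = num_frames - (seg_len + 1)
    while n_frames > 0:
        split_list.append(seg_len)
        n_frames -= seg_len
    split_list[-1] += n_frames  # n_frames <= 0 here
    return split_list

def decoder_split_list(latent_frames: int, temporal_compress: int = 4, seg_len: int = 16) -> list[int]:
    # Mirrors _decode: plan the split in pixel-frame units, then convert each
    # segment back to latent frames with a ceiling division.
    split_list = [seg_len + 1]
    n_frames = temporal_compress * (latent_frames - 1) - seg_len
    while n_frames > 0:
        split_list.append(seg_len)
        n_frames -= seg_len
    split_list[-1] += n_frames
    return [math.ceil(size / temporal_compress) for size in split_list]
```

Both variants keep the first segment one frame longer so the causal convolution caches are primed on the initial chunk.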

View File

@@ -105,7 +105,14 @@ class QwenImageRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
class QwenImageUpsample(nn.Upsample):

View File

@@ -196,7 +196,14 @@ class WanRMS_norm(nn.Module):
self.bias = nn.Parameter(torch.zeros(shape)) if bias else 0.0
def forward(self, x):
return F.normalize(x, dim=(1 if self.channel_first else -1)) * self.scale * self.gamma + self.bias
needs_fp32_normalize = x.dtype in (torch.float16, torch.bfloat16) or any(
t in str(x.dtype) for t in ("float4_", "float8_")
)
normalized = F.normalize(x.float() if needs_fp32_normalize else x, dim=(1 if self.channel_first else -1)).to(
x.dtype
)
return normalized * self.scale * self.gamma + self.bias
class WanUpsample(nn.Upsample):
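The dtype gate added to both `QwenImageRMS_norm` and `WanRMS_norm` above can be isolated as a predicate. A minimal sketch that operates on the dtype's string name so it runs without torch; the real check compares `torch.dtype` objects directly:

```python
def needs_fp32_normalize(dtype_name: str) -> bool:
    # Mirrors the diff: half precision (float16/bfloat16) and narrow float
    # formats (float8_*, float4_*) are upcast to float32 for F.normalize,
    # then the result is cast back to the input dtype.
    return dtype_name in ("torch.float16", "torch.bfloat16") or any(
        tag in dtype_name for tag in ("float4_", "float8_")
    )
```

The substring match on `float4_`/`float8_` covers the family of sub-byte float variants (e.g. `float8_e4m3fn`) without enumerating each one.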

View File

@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from math import prod
from typing import Any
@@ -25,7 +24,7 @@ import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import apply_lora_scale, deprecate, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ...utils.torch_utils import lru_cache_unless_export, maybe_allow_in_graph
from .._modeling_parallel import ContextParallelInput, ContextParallelOutput
from ..attention import AttentionMixin, FeedForward
from ..attention_dispatch import dispatch_attention_fn
@@ -307,7 +306,7 @@ class QwenEmbedRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=128)
@lru_cache_unless_export(maxsize=128)
def _compute_video_freqs(
self, frame: int, height: int, width: int, idx: int = 0, device: torch.device = None
) -> torch.Tensor:
@@ -428,7 +427,7 @@ class QwenEmbedLayer3DRope(nn.Module):
return vid_freqs, txt_freqs
@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_video_freqs(self, frame, height, width, idx=0, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -450,7 +449,7 @@ class QwenEmbedLayer3DRope(nn.Module):
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
return freqs.clone().contiguous()
@functools.lru_cache(maxsize=None)
@lru_cache_unless_export(maxsize=None)
def _compute_condition_freqs(self, frame, height, width, device: torch.device = None):
seq_lens = frame * height * width
pos_freqs = self.pos_freqs.to(device) if device is not None else self.pos_freqs
@@ -934,6 +933,7 @@ class QwenImageTransformer2DModel(
batch_size, image_seq_len = hidden_states.shape[:2]
image_mask = torch.ones((batch_size, image_seq_len), dtype=torch.bool, device=hidden_states.device)
joint_attention_mask = torch.cat([encoder_hidden_states_mask, image_mask], dim=1)
joint_attention_mask = joint_attention_mask[:, None, None, :]
block_attention_kwargs["attention_mask"] = joint_attention_mask
for index_block, block in enumerate(self.transformer_blocks):
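The swap from `functools.lru_cache` to `lru_cache_unless_export` above avoids cache lookups during graph export, where memoized tensor results break tracing. A minimal sketch of the idea under an assumed module-level flag; the real helper in `diffusers.utils.torch_utils` detects the export state itself:

```python
import functools

_EXPORTING = False  # hypothetical flag; the real check inspects torch export/tracing state

def lru_cache_unless_export(maxsize=128):
    # Behave like functools.lru_cache normally, but call through uncached
    # while exporting, so tracing records the real computation every time.
    def decorator(fn):
        cached = functools.lru_cache(maxsize=maxsize)(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if _EXPORTING:
                return fn(*args, **kwargs)
            return cached(*args, **kwargs)

        return wrapper

    return decorator
```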

View File

@@ -788,9 +788,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
if all(seq == max_seqlen for seq in item_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
@@ -871,9 +874,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
if all(seq == max_seqlen for seq in unified_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
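The `ZImageTransformer2DModel` change above skips building a boolean mask entirely when every item already fills `max_seqlen`, letting attention backends take the unmasked fast path. A minimal sketch with plain Python lists standing in for tensors:

```python
def build_attn_mask(item_seqlens: list[int], max_seqlen: int):
    # Mirrors the diff: a uniform batch needs no mask at all (return None);
    # otherwise mark the valid prefix of each row True.
    if all(seq == max_seqlen for seq in item_seqlens):
        return None
    return [[j < seq for j in range(max_seqlen)] for seq in item_seqlens]
```

Returning `None` rather than an all-True mask matters because many attention kernels dispatch to a faster code path when no mask is supplied.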

View File

@@ -285,6 +285,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
_import_structure["ltx"] = [
"LTXPipeline",
"LTXImageToVideoPipeline",
@@ -728,6 +729,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
from .ltx import (
LTXConditionPipeline,

View File

@@ -324,17 +324,18 @@ class AudioLDM2Pipeline(DiffusionPipeline):
`inputs_embeds (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`):
The sequence of generated hidden-states.
"""
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
max_new_tokens = max_new_tokens if max_new_tokens is not None else self.language_model.config.max_new_tokens
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
if hasattr(self.language_model, "_get_initial_cache_position"):
cache_position_kwargs = {}
if is_transformers_version("<", "4.52.1"):
cache_position_kwargs["input_ids"] = inputs_embeds
else:
cache_position_kwargs["seq_length"] = inputs_embeds.shape[0]
cache_position_kwargs["device"] = (
self.language_model.device if getattr(self, "language_model", None) is not None else self.device
)
cache_position_kwargs["model_kwargs"] = model_kwargs
model_kwargs = self.language_model._get_initial_cache_position(**cache_position_kwargs)
for _ in range(max_new_tokens):
# prepare model inputs
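The `AudioLDM2Pipeline` fix above wraps the call to transformers' private `_get_initial_cache_position` in a `hasattr` guard, since the helper does not exist on every transformers version. A minimal sketch of the pattern with hypothetical stand-in models:

```python
class LegacyLM:
    """Hypothetical stand-in for a model version without the private helper."""

class ModernLM:
    """Hypothetical stand-in for a model version that exposes it."""

    def _get_initial_cache_position(self, model_kwargs=None, **kwargs):
        model_kwargs = dict(model_kwargs or {})
        model_kwargs["cache_position"] = 0
        return model_kwargs

def init_cache_kwargs(model, model_kwargs):
    # Only call the private helper when this version defines it;
    # otherwise pass model_kwargs through untouched.
    if hasattr(model, "_get_initial_cache_position"):
        return model._get_initial_cache_position(model_kwargs=model_kwargs)
    return model_kwargs
```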

View File

@@ -16,22 +16,29 @@ from typing import Callable
import numpy as np
import torch
import torchvision
import torchvision.transforms
import torchvision.transforms.functional
from transformers import AutoTokenizer, Qwen2_5_VLForConditionalGeneration
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput
from ...models import AutoencoderKLWan, CosmosTransformer3DModel
from ...schedulers import UniPCMultistepScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils import (
is_cosmos_guardrail_available,
is_torch_xla_available,
is_torchvision_available,
logging,
replace_example_docstring,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput
if is_torchvision_available():
import torchvision.transforms.functional
if is_cosmos_guardrail_available():
from cosmos_guardrail import CosmosSafetyChecker
else:

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
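The new `llada2/__init__.py` above follows diffusers' `_LazyModule` convention: attribute access triggers the submodule import on first use. A minimal sketch of the mechanism using PEP 562 module `__getattr__`; `make_lazy_module` is an illustrative helper, not diffusers API:

```python
import importlib
import types

def make_lazy_module(name: str, import_structure: dict) -> types.ModuleType:
    # Map each exported attribute back to the submodule that defines it.
    attr_to_submodule = {attr: sub for sub, attrs in import_structure.items() for attr in attrs}
    module = types.ModuleType(name)

    def __getattr__(attr):
        sub = attr_to_submodule.get(attr)
        if sub is None:
            raise AttributeError(attr)
        # Import the submodule only now, on first access, and cache the result
        # on the module so subsequent lookups skip __getattr__ entirely.
        value = getattr(importlib.import_module(f"{name}.{sub}"), attr)
        setattr(module, attr, value)
        return value

    module.__getattr__ = __getattr__
    return module
```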

View File

@@ -0,0 +1,491 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
import torch
from tqdm.auto import tqdm
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...schedulers import BlockRefinementScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
>>> model_id = "inclusionAI/LLaDA2.1-mini"
>>> model = AutoModelForCausalLM.from_pretrained(
... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
... )
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
>>> scheduler = BlockRefinementScheduler()
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
>>> print(output.texts[0])
```
"""
@dataclass
class LLaDA2PipelineOutput(BaseOutput):
sequences: torch.LongTensor
texts: list[str] | None = None
class LLaDA2Pipeline(DiffusionPipeline):
r"""
Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement.
This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each
refinement step, it samples candidate tokens for the active block and commits a subset based on confidence.
The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq,
vocab_size]`.
"""
model: Any
scheduler: BlockRefinementScheduler
tokenizer: Any
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
def __init__(
self,
model: Any,
scheduler: BlockRefinementScheduler,
tokenizer: Any | None = None,
):
super().__init__()
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None
self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None
@property
def num_timesteps(self):
return self._num_timesteps
# --- Prompt encoding ---
def _prepare_input_ids(
self,
*,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
use_chat_template: bool,
add_generation_prompt: bool,
chat_template_kwargs: dict[str, Any] | None,
) -> torch.LongTensor:
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
if input_ids is not None:
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if input_ids.ndim != 2:
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
return input_ids
if self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and prompt is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if messages is None and prompt is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
chat_template_kwargs = chat_template_kwargs or {}
if messages is not None:
encoded = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
if isinstance(prompt, list):
raise ValueError("`prompt` must be a string when `use_chat_template=True`.")
encoded = self.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"]
def check_inputs(
self,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
gen_length: int,
block_length: int,
num_inference_steps: int,
minimal_topk: int,
threshold: float,
sampling_method: str,
output_type: str,
callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None,
callback_on_step_end_tensor_inputs: list[str] | None,
):
# Input source validation
if prompt is None and messages is None and input_ids is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
if prompt is not None and messages is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if input_ids is not None:
if input_ids.ndim not in (1, 2):
raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
if prompt is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
# Generation parameter validation
if gen_length <= 0:
raise ValueError(f"`gen_length` must be > 0, got {gen_length}.")
if block_length <= 0:
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
if minimal_topk <= 0:
raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.")
if threshold < 0.0:
raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.")
if sampling_method not in {"auto", "greedy", "multinomial"}:
raise ValueError(
f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}."
)
if output_type not in {"seq", "text"}:
raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.")
# Callback validation
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: str | list[str] | None = None,
messages: list[dict[str, str]] | None = None,
input_ids: torch.LongTensor | None = None,
use_chat_template: bool = True,
add_generation_prompt: bool = True,
gen_length: int = 2048,
block_length: int = 32,
num_inference_steps: int = 32,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "multinomial",
threshold: float = 0.7,
editing_threshold: float | None = 0.5,
max_post_steps: int = 16,
minimal_topk: int = 1,
eos_early_stop: bool = True,
eos_token_id: int | None = None,
mask_token_id: int | None = None,
generator: torch.Generator | None = None,
output_type: str = "text",
return_dict: bool = True,
callback_on_step_end: Callable[[int, int, dict], None]
| PipelineCallback
| MultiPipelineCallbacks
| None = None,
callback_on_step_end_tensor_inputs: list[str] | None = None,
) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]:
"""
Generate text with block-wise refinement.
Args:
prompt (`str` or `List[str]`, *optional*):
Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is
available, the prompt is wrapped in a chat message before tokenization.
messages (`List[Dict[str, str]]`, *optional*):
Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt`
when provided. Requires a tokenizer with `apply_chat_template`.
input_ids (`torch.LongTensor`, *optional*):
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
use_chat_template (`bool`, defaults to `True`):
Whether to wrap the prompt in a chat template.
add_generation_prompt (`bool`, defaults to `True`):
Whether to add the generation prompt when using chat templates.
gen_length (`int`):
Number of tokens to generate.
block_length (`int`):
Block size for refinement.
num_inference_steps (`int`):
Number of refinement steps per block.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`):
Confidence threshold for committing tokens.
editing_threshold (`float`, *optional*):
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
negative value to disable editing. Defaults to `0.5`.
max_post_steps (`int`):
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
used when `editing_threshold` is enabled. Defaults to `16`.
minimal_topk (`int`):
Minimum number of tokens to commit per step.
eos_early_stop (`bool`):
Whether to stop after committing EOS in a block.
eos_token_id (`int`, *optional*):
EOS token ID to use for early stopping.
mask_token_id (`int`, *optional*):
Mask token ID to use for the template.
generator (`torch.Generator`, *optional*):
RNG for sampling.
output_type (`str`, defaults to `"text"`):
Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw
token ID sequences only.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple.
callback_on_step_end (`Callable` or `PipelineCallback`, *optional*):
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
timestep: int, callback_kwargs: Dict)`.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
`confidence`, `active_block`.
Examples:
"""
# 1. Check inputs early
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["block_x"]
self.check_inputs(
prompt=prompt,
messages=messages,
input_ids=input_ids,
gen_length=gen_length,
block_length=block_length,
num_inference_steps=num_inference_steps,
minimal_topk=minimal_topk,
threshold=threshold,
sampling_method=sampling_method,
output_type=output_type,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# 2. Prepare input IDs from prompt/messages/input_ids
prompt_ids = self._prepare_input_ids(
prompt=prompt,
messages=messages,
input_ids=input_ids,
use_chat_template=use_chat_template,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=None,
)
device = self._execution_device
if prompt_ids.ndim == 1:
prompt_ids = prompt_ids.unsqueeze(0)
prompt_ids = prompt_ids.to(device=device)
batch_size, prompt_length = prompt_ids.shape
if eos_token_id is None:
eos_token_id = self.eos_token_id
if mask_token_id is None:
mask_token_id = self.mask_token_id
if mask_token_id is None:
raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).")
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 3. Build attention mask and position IDs
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
total_length = num_blocks * block_length
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
# 4. Prepare latents (fully masked sequence)
x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long)
if prompt_length > 0:
x[:, :prompt_length] = prompt_ids
prefill_blocks = prompt_length // block_length
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
global_step = 0
# 5. Block-wise refinement loop
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config["position"] = 0
block_progress_bar_config["desc"] = "Blocks"
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
current_window_end = (num_block + 1) * block_length
block_x = x[:, :current_window_end]
block_attn_mask = attn_mask[:, :current_window_end]
block_position_ids = position_ids[:, :current_window_end]
# Identify which positions in the block are prompt (non-editable).
block_start_pos = num_block * block_length
prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
if block_start_pos < prompt_length:
prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
prompt_mask_in_block[:prompt_end_in_block] = True
post_steps = 0
step_idx = 0
should_continue = True
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
progress_bar = self.progress_bar(total=num_inference_steps)
while should_continue:
block_tokens = block_x[:, -block_length:]
masks_remaining = (block_tokens == mask_token_id).any()
if not masks_remaining:
post_steps += 1
logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits
block_logits = logits[:, -block_length:, :]
scheduler_output = self.scheduler.step(
model_output=block_logits,
timestep=step_idx,
sample=block_tokens,
mask_token_id=mask_token_id,
temperature=temperature,
top_p=top_p,
top_k=top_k,
sampling_method=sampling_method,
threshold=threshold,
editing_threshold=editing_threshold,
minimal_topk=minimal_topk,
prompt_mask=prompt_mask_in_block,
generator=generator,
return_dict=True,
)
transfer_index = scheduler_output.transfer_index
editing_transfer_index = scheduler_output.editing_transfer_index
final_transfer = transfer_index | editing_transfer_index
if final_transfer.any():
block_x[:, -block_length:] = scheduler_output.prev_sample
if eos_early_stop and eos_token_id is not None:
finished = self.scheduler.check_eos_finished(
cur_x=block_x,
sampled_tokens=scheduler_output.sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_token_id,
mask_token_id=mask_token_id,
prompt_length=prompt_length,
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals().get(k)
callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
block_x = callback_outputs.pop("block_x", block_x)
global_step += 1
if masks_remaining:
step_idx += 1
progress_bar.update(1)
should_continue = self.scheduler.check_block_should_continue(
step_idx=step_idx,
masks_remaining=masks_remaining,
editing_enabled=editing_enabled,
editing_transfer_index=editing_transfer_index,
post_steps=post_steps,
max_post_steps=max_post_steps,
finished=finished,
)
progress_bar.close()
x[:, :current_window_end] = block_x
if eos_early_stop and finished.all():
break
# 6. Post-process output
generated = x[:, : prompt_length + gen_length]
sequences = generated[:, prompt_length:]
if eos_token_id is not None and batch_size == 1:
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
if len(eos_positions) > 0:
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
texts = None
if output_type == "text" and self.tokenizer is not None:
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
if not return_dict:
return sequences.to(device=device), texts
return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts)
__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
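The window arithmetic in step 3 of `__call__` (block count, prefill blocks, step budget) can be sketched in plain Python. This is an illustrative reimplementation of the integer arithmetic only, not pipeline code; the concrete prompt length used below is an assumption for the example.

```python
def block_geometry(prompt_length: int, gen_length: int, block_length: int,
                   num_inference_steps: int) -> dict:
    # Ceiling division: enough equal-size blocks to cover prompt + generation.
    num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
    total_length = num_blocks * block_length
    # Blocks fully inside the prompt are prefilled and skipped by the loop.
    prefill_blocks = prompt_length // block_length
    # Upper bound on refinement steps, mirroring `_num_timesteps`.
    num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
    return {"num_blocks": num_blocks, "total_length": total_length,
            "prefill_blocks": prefill_blocks, "num_timesteps": num_timesteps}

geo = block_geometry(prompt_length=45, gen_length=2048, block_length=32,
                     num_inference_steps=32)
```

For example, a 45-token prompt with `gen_length=2048` and `block_length=32` yields 66 blocks (2112 positions), one of which is prefilled and skipped.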


@@ -40,6 +40,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
@@ -145,6 +146,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
from .scheduling_amused import AmusedScheduler
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler


@@ -0,0 +1,460 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class BlockRefinementSchedulerOutput(BaseOutput):
"""
Output class for block refinement scheduling.
Args:
prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Updated block tokens after the current refinement step.
transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were committed (mask-filling).
editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were edited (non-mask replacement).
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Sampled token IDs from the model logits.
sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
Probabilities of the sampled tokens.
"""
prev_sample: torch.LongTensor
transfer_index: torch.BoolTensor
editing_transfer_index: torch.BoolTensor
sampled_tokens: torch.LongTensor
sampled_probs: torch.Tensor
class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
"""
Scheduler for block-wise iterative refinement (commit-by-confidence).
At each step, the scheduler samples candidate tokens from model logits and commits those with the highest
confidence. The number of tokens to commit per step is determined by evenly distributing the block length across
the number of refinement steps.
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
"""
order = 1
@register_to_config
def __init__(
self,
block_length: int = 32,
num_inference_steps: int = 32,
threshold: float = 0.95,
editing_threshold: float | None = None,
minimal_topk: int = 1,
):
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
self._transfer_schedule: torch.LongTensor | None = None
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
device=device if device is not None else "cpu"
)
def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor:
"""Evenly distribute `block_length` token commits across `num_inference_steps` steps."""
if num_inference_steps <= 0:
return torch.zeros((0,), dtype=torch.long)
base = block_length // num_inference_steps
remainder = block_length % num_inference_steps
out = torch.full((num_inference_steps,), base, dtype=torch.long)
out[:remainder] += 1
return out
# --- SAR sampling utilities ---
@staticmethod
def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor:
"""Nucleus (top-p) logit filtering."""
if top_p is None or top_p >= 1.0:
return logits
if not (0.0 < top_p <= 1.0):
raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.")
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cumulative_probs = sorted_probs.cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs > float(top_p)
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min)
filtered = logits.scatter(-1, sorted_indices, sorted_logits)
return filtered
@staticmethod
def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
"""Top-k logit filtering."""
if top_k is None or top_k <= 0:
return logits
if top_k >= logits.shape[-1]:
return logits
values, _ = torch.topk(logits, k=top_k, dim=-1)
min_keep = values[..., -1, None]
return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min)
@staticmethod
def _sample_from_logits(
logits: torch.Tensor,
*,
temperature: float,
top_k: int | None,
top_p: float | None,
generator: torch.Generator | None,
use_multinomial: bool,
) -> tuple[torch.LongTensor, torch.Tensor]:
"""Sample tokens from logits with temperature scaling, top-k, and top-p."""
if temperature < 0:
raise ValueError(f"`temperature` must be >= 0, got {temperature}.")
vocab_size = logits.shape[-1]
flat_logits = logits.reshape(-1, vocab_size)
if temperature == 0.0 or not use_multinomial:
probs = torch.softmax(flat_logits.float(), dim=-1)
token = flat_logits.argmax(dim=-1, keepdim=True)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
scaled = flat_logits
if temperature != 1.0:
scaled = flat_logits / temperature
filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k)
filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p)
probs = torch.softmax(filtered.float(), dim=-1)
token = torch.multinomial(probs, num_samples=1, generator=generator)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
def step(
self,
model_output: torch.Tensor,
timestep: int | torch.Tensor,
sample: torch.LongTensor,
*,
mask_token_id: int,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "auto",
threshold: float | None = None,
editing_threshold: float | None = None,
minimal_topk: int | None = None,
prompt_mask: torch.BoolTensor | None = None,
generator: torch.Generator | None = None,
return_dict: bool = True,
) -> (
BlockRefinementSchedulerOutput
| tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]
):
"""
Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing
ones.
Args:
model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
Raw logits from the model for the current block.
timestep (`int` or `torch.Tensor`):
Current step index within the block's refinement schedule.
sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Current block token IDs (contains mask tokens for uncommitted positions).
mask_token_id (`int`):
Token ID used for masked positions.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`, *optional*):
Confidence threshold for committing tokens. Defaults to config value.
editing_threshold (`float`, *optional*):
Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
config value.
minimal_topk (`int`, *optional*):
Minimum tokens to commit per step. Defaults to config value.
prompt_mask (`torch.BoolTensor`, *optional*):
Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions.
generator (`torch.Generator`, *optional*):
RNG for sampling.
return_dict (`bool`):
Whether to return a `BlockRefinementSchedulerOutput` or a tuple.
"""
if threshold is None:
threshold = float(self.config.threshold)
if editing_threshold is None:
editing_threshold = self.config.editing_threshold
if minimal_topk is None:
minimal_topk = self.config.minimal_topk
# Sample from logits
use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0)
sampled_tokens, sampled_probs = self._sample_from_logits(
model_output,
temperature=temperature,
top_k=top_k,
top_p=top_p,
generator=generator,
use_multinomial=use_multinomial,
)
batch_size, block_length = sample.shape
active_block = sample == mask_token_id
masks_remaining = active_block.any()
if isinstance(timestep, torch.Tensor):
step_index = int(timestep.item())
else:
step_index = int(timestep)
# --- Mask-filling transfer ---
transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if masks_remaining and self._transfer_schedule is not None:
clamped_step = min(step_index, len(self._transfer_schedule) - 1)
num_to_transfer = int(self._transfer_schedule[clamped_step].item())
confidence = torch.where(
active_block,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
for b in range(batch_size):
high_conf = confidence[b] > threshold
if high_conf.sum().item() >= num_to_transfer:
transfer_index[b] = high_conf
else:
k = min(num_to_transfer, int(active_block[b].sum().item()))
if k > 0:
_, idx = torch.topk(confidence[b], k=k)
transfer_index[b, idx] = True
# --- Editing transfer (non-mask, non-prompt positions) ---
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if editing_enabled:
if prompt_mask is None:
prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool)
editable = (~active_block) & (~prompt_mask.unsqueeze(0))
editing_conf = torch.where(
editable,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
high_conf_edit = editing_conf > float(editing_threshold)
token_changed = sampled_tokens != sample
editing_transfer_index = high_conf_edit & token_changed & editable
# Apply transfers
final_transfer = transfer_index | editing_transfer_index
prev_sample = sample.clone()
if final_transfer.any():
prev_sample[final_transfer] = sampled_tokens[final_transfer]
if not return_dict:
return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs
return BlockRefinementSchedulerOutput(
prev_sample=prev_sample,
transfer_index=transfer_index,
editing_transfer_index=editing_transfer_index,
sampled_tokens=sampled_tokens,
sampled_probs=sampled_probs,
)
@staticmethod
def check_eos_finished(
cur_x: torch.LongTensor,
sampled_tokens: torch.LongTensor,
final_transfer: torch.BoolTensor,
finished: torch.BoolTensor,
eos_token_id: int,
mask_token_id: int,
prompt_length: int,
) -> torch.BoolTensor:
"""
Update per-batch finished flags when EOS tokens are committed.
Args:
cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Current full sequence including all blocks up to the current window.
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Tokens sampled by the scheduler in this step.
final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Combined mask of committed and edited positions.
finished (`torch.BoolTensor` of shape `(batch_size,)`):
Current per-batch finished flags.
eos_token_id (`int`):
EOS token ID.
mask_token_id (`int`):
Mask token ID.
prompt_length (`int`):
Number of prompt tokens at the start of the sequence.
Returns:
`torch.BoolTensor`: Updated finished flags.
"""
batch_size = cur_x.shape[0]
for b in range(batch_size):
if finished[b]:
continue
eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item()
if not eos_in_commits:
continue
eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
if len(eos_pos[0]) == 0:
continue
eos_pos = int(eos_pos[0][0].item())
if eos_pos < prompt_length:
continue
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
finished[b] = True
return finished
def check_block_should_continue(
self,
step_idx: int,
masks_remaining: bool,
editing_enabled: bool,
editing_transfer_index: torch.BoolTensor,
post_steps: int,
max_post_steps: int,
finished: torch.BoolTensor,
) -> bool:
"""
Determine whether the inner refinement loop should continue for the current block.
Args:
step_idx (`int`):
Current refinement step index within this block.
masks_remaining (`bool`):
Whether any mask tokens remain in the block.
editing_enabled (`bool`):
Whether editing mode is active.
editing_transfer_index (`torch.BoolTensor`):
Which tokens were edited in this step.
post_steps (`int`):
Number of post-mask editing steps taken so far.
max_post_steps (`int`):
Maximum allowed post-mask editing steps.
finished (`torch.BoolTensor`):
Per-batch finished flags (from EOS detection).
Returns:
`bool`: `True` if refinement should continue, `False` to break.
"""
if finished.all():
return False
if not masks_remaining and not editing_enabled:
return False
if not masks_remaining and not editing_transfer_index.any():
return False
if masks_remaining and step_idx >= self.num_inference_steps:
return False
if not masks_remaining and post_steps > max_post_steps:
return False
return True
def add_noise(
self,
original_samples: torch.LongTensor,
attention_mask: torch.LongTensor,
*,
prompt_length: int,
block_length: int,
mask_token_id: int,
generator: torch.Generator | None = None,
) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
"""
Apply the forward (noising) process for semi-autoregressive block masking.
For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with
`mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in
one are the unmasked positions in the other.
Args:
original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Clean token IDs.
attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Padding mask (1 for valid, 0 for padding).
prompt_length (`int`):
Number of leading prompt tokens to keep unmasked.
block_length (`int`):
Block size for masking.
mask_token_id (`int`):
Token ID to use for masked positions.
generator (`torch.Generator`, *optional*):
RNG for reproducibility.
Returns:
`tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`:
`(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their
corresponding boolean masks.
"""
batch_size, seq_len = original_samples.shape
device = original_samples.device
noisy = original_samples.clone()
noisy_rev = original_samples.clone()
masked = torch.zeros_like(original_samples, dtype=torch.bool)
masked_rev = torch.zeros_like(original_samples, dtype=torch.bool)
valid = attention_mask.to(dtype=torch.bool)
for block_start in range(prompt_length, seq_len, block_length):
block_end = min(seq_len, block_start + block_length)
seg_len = block_end - block_start
if seg_len <= 0:
continue
p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
seg = seg & valid[:, block_start:block_end]
seg_rev = (~seg) & valid[:, block_start:block_end]
masked[:, block_start:block_end] = seg
masked_rev[:, block_start:block_end] = seg_rev
noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy)
noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev)
return noisy, noisy_rev, masked, masked_rev
__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
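The scheduler's three core rules — the even transfer schedule, nucleus (top-p) filtering, and commit-by-confidence — can be sketched in pure Python. This is a deliberately simplified illustration (single sequence, no batch dimension, no `-inf` masking of inactive positions), not the tensor implementation above.

```python
def transfer_schedule(block_length: int, num_steps: int) -> list[int]:
    # Evenly distribute block_length commits across num_steps steps;
    # the first `remainder` steps each take one extra token.
    base, rem = divmod(block_length, num_steps)
    return [base + 1 if i < rem else base for i in range(num_steps)]

def top_p_filter(probs: list[float], top_p: float) -> list[float]:
    # Keep the smallest set of most-likely tokens whose cumulative
    # probability exceeds top_p (the crossing token is kept), renormalize.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    keep, cum = set(), 0.0
    for i in order:
        keep.add(i)
        cum += probs[i]
        if cum > top_p:
            break
    total = sum(probs[i] for i in keep)
    return [p / total if i in keep else 0.0 for i, p in enumerate(probs)]

def commit_indices(confidence: list[float], num_to_transfer: int,
                   threshold: float) -> set[int]:
    # Commit every position above the threshold; if that falls short of the
    # per-step quota, fall back to the top-k most confident positions.
    high = {i for i, c in enumerate(confidence) if c > threshold}
    if len(high) >= num_to_transfer:
        return high
    ranked = sorted(range(len(confidence)), key=lambda i: confidence[i],
                    reverse=True)
    return set(ranked[:num_to_transfer])
```

With `block_length=32` and 5 steps, the schedule is `[7, 7, 6, 6, 6]`; with confidences `[0.99, 0.2, 0.8, 0.1]`, a quota of 2, and threshold 0.95, only position 0 clears the threshold, so the top-2 fallback commits positions `{0, 2}`.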


@@ -11,6 +11,7 @@ from typing import Any, Iterable
import numpy as np
import torch
import torch.nn.functional as F
if getattr(torch, "distributed", None) is not None:
@@ -109,6 +110,92 @@ def compute_snr(noise_scheduler, timesteps):
return snr
def compute_confidence_aware_loss(
logits: torch.Tensor,
labels: torch.Tensor,
*,
lambda_conf: float = 0.0,
temperature: float = 1.0,
per_token_weights: torch.Tensor | None = None,
ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes a confidence-aware training loss for token classification-style heads.
This loss combines:
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
Args:
logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
are excluded from both losses.
lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
sharpen the distribution and change the strength of the confidence gradients.
per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
"""
if logits.ndim < 2:
raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
if labels.shape != logits.shape[:-1]:
raise ValueError(
f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
)
if temperature <= 0:
raise ValueError(f"`temperature` must be > 0, got {temperature}.")
valid = labels.ne(ignore_index)
if per_token_weights is None:
weights = torch.ones_like(labels, dtype=logits.dtype)
else:
if per_token_weights.shape != labels.shape:
raise ValueError(
f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
)
weights = per_token_weights.to(dtype=logits.dtype)
# Supervised CE (optionally weighted).
vocab_size = logits.shape[-1]
per_token_nll = F.cross_entropy(
logits.reshape(-1, vocab_size),
labels.reshape(-1),
reduction="none",
ignore_index=ignore_index,
).reshape_as(labels)
denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft
# Confidence loss: penalize entropy only where prediction is already correct.
if lambda_conf == 0.0:
loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
return loss_sft, loss_sft, loss_conf
with torch.no_grad():
pred = logits.argmax(dim=-1)
correct = valid & pred.eq(labels)
scaled_logits = logits.float()
if temperature != 1.0:
scaled_logits = scaled_logits / float(temperature)
probs = torch.softmax(scaled_logits, dim=-1)
eps = torch.finfo(probs.dtype).tiny
log_probs = torch.log(probs.clamp_min(eps))
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf
loss = loss_sft + float(lambda_conf) * loss_conf
return loss, loss_sft, loss_conf
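The two terms of `compute_confidence_aware_loss` can be illustrated with a torch-free toy: mean cross-entropy over all tokens, plus an entropy penalty applied only where the model's argmax already matches the label. This sketch omits `ignore_index`, per-token weights, and the entropy temperature for brevity; it is an illustration of the structure, not the implementation above.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]

def confidence_aware_loss(logits_per_token, labels, lambda_conf=0.1):
    # loss_sft: mean cross-entropy over all tokens.
    # loss_conf: mean entropy over tokens already predicted correctly,
    # pushing the model to be *confident* where it is right.
    nlls, entropies = [], []
    for logits, label in zip(logits_per_token, labels):
        probs = softmax(logits)
        nlls.append(-math.log(probs[label]))
        pred = max(range(len(probs)), key=probs.__getitem__)
        if pred == label:
            entropies.append(-sum(p * math.log(p) for p in probs if p > 0))
    loss_sft = sum(nlls) / len(nlls)
    loss_conf = sum(entropies) / max(len(entropies), 1)
    return loss_sft + lambda_conf * loss_conf, loss_sft, loss_conf
```

A correctly predicted token contributes its entropy to `loss_conf`, so the combined loss exceeds plain cross-entropy whenever `lambda_conf > 0` and at least one prediction is already correct.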
def resolve_interpolation_mode(interpolation_type: str):
"""
Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The


@@ -521,6 +521,36 @@ class AutoencoderKLHunyuanVideo15(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAE(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLKVAEVideo(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class AutoencoderKLLTX2Audio(metaclass=DummyObject):
_backends = ["torch"]
@@ -2488,6 +2518,36 @@ class AmusedScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BlockRefinementScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BlockRefinementSchedulerOutput(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CMStochasticIterativeScheduler(metaclass=DummyObject):
_backends = ["torch"]


@@ -2222,6 +2222,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2PipelineOutput(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LongCatImageEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]


@@ -29,6 +29,7 @@ from numpy.linalg import norm
from packaging import version
from .constants import DIFFUSERS_REQUEST_TIMEOUT
from .deprecation_utils import deprecate
from .import_utils import (
BACKENDS_MAPPING,
is_accelerate_available,
@@ -67,9 +68,11 @@ else:
global_rng = random.Random()
logger = get_logger(__name__)
logger.warning(
"diffusers.utils.testing_utils' is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. "
deprecate(
"diffusers.utils.testing_utils",
"1.0.0",
"diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
"Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`. ",
)
_required_peft_version = is_peft_available() and version.parse(
version.parse(importlib.metadata.version("peft")).base_version
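The hunk above replaces a raw `logger.warning` with the repo's `deprecate` helper, which bundles the deprecated feature name, the removal version, and the message into one call. A minimal analog of that shape (illustrative only, not diffusers' actual implementation, which also compares against the installed version):

```python
import warnings


def deprecate(feature, removal_version, message):
    # Emit a FutureWarning tagged with the version in which `feature` goes away.
    warnings.warn(f"{message} (will be removed in version {removal_version})", FutureWarning, stacklevel=2)


# Usage mirroring the hunk above:
deprecate(
    "diffusers.utils.testing_utils",
    "1.0.0",
    "diffusers.utils.testing_utils is deprecated and will be removed in a future version. "
    "Determinism and device backend utilities have been moved to `diffusers.utils.torch_utils`.",
)
```

The benefit over an ad-hoc warning is uniformity: every deprecation carries a machine-readable removal version that can be grepped for before a release.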


@@ -19,11 +19,16 @@ from __future__ import annotations
import functools
import os
from typing import Callable, ParamSpec, TypeVar
from . import logging
from .import_utils import is_torch_available, is_torch_mlu_available, is_torch_npu_available, is_torch_version
T = TypeVar("T")
P = ParamSpec("P")
if is_torch_available():
import torch
from torch.fft import fftn, fftshift, ifftn, ifftshift
@@ -333,5 +338,23 @@ def disable_full_determinism():
torch.use_deterministic_algorithms(False)
@functools.wraps(functools.lru_cache)
def lru_cache_unless_export(maxsize=128, typed=False):
def outer_wrapper(fn: Callable[P, T]):
cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)
if is_torch_version("<", "2.7.0"):
return cached
@functools.wraps(fn)
def inner_wrapper(*args: P.args, **kwargs: P.kwargs):
if torch.compiler.is_exporting():
return fn(*args, **kwargs)
return cached(*args, **kwargs)
return inner_wrapper
return outer_wrapper
if is_torch_available():
torch_device = get_device()
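`lru_cache_unless_export` caches results normally but bypasses the cache whenever `torch.compiler.is_exporting()` is true, so cached values never leak into an export trace. The same shape, torch-free, with the export check abstracted into a predicate (the names here are illustrative, not diffusers API):

```python
import functools


def lru_cache_unless(predicate, maxsize=128, typed=False):
    # Like functools.lru_cache, but call the wrapped function uncached while predicate() is True.
    def outer_wrapper(fn):
        cached = functools.lru_cache(maxsize=maxsize, typed=typed)(fn)

        @functools.wraps(fn)
        def inner_wrapper(*args, **kwargs):
            if predicate():
                return fn(*args, **kwargs)  # uncached path, e.g. during torch.export
            return cached(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper


state = {"bypass": False}
calls = []


@lru_cache_unless(lambda: state["bypass"])
def square(x):
    calls.append(x)
    return x * x
```

The predicate is evaluated on every call, so flipping the flag mid-run immediately switches between cached and uncached execution, exactly what the export case needs.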


@@ -28,7 +28,6 @@ from diffusers.utils.import_utils import is_peft_available
from ..testing_utils import (
floats_tensor,
is_flaky,
require_peft_backend,
require_peft_version_greater,
skip_mps,
@@ -46,7 +45,6 @@ from .utils import PeftLoraLoaderMixinTests # noqa: E402
@require_peft_backend
@skip_mps
@is_flaky(max_attempts=10, description="very flaky class")
class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = WanVACEPipeline
scheduler_cls = FlowMatchEulerDiscreteScheduler
@@ -73,8 +71,8 @@ class WanVACELoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
"base_dim": 3,
"z_dim": 4,
"dim_mult": [1, 1, 1, 1],
"latents_mean": torch.randn(4).numpy().tolist(),
"latents_std": torch.randn(4).numpy().tolist(),
"latents_mean": [-0.7571, -0.7089, -0.9113, -0.7245],
"latents_std": [2.8184, 1.4541, 2.3275, 2.6558],
"num_res_blocks": 1,
"temperal_downsample": [False, True, True],
}


@@ -0,0 +1,73 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAE
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAE
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_config(self):
return {
"in_channels": 3,
"channels": 32,
"num_enc_blocks": 1,
"num_dec_blocks": 1,
"z_channels": 4,
"double_z": True,
"ch_mult": (1, 2),
"sample_size": 32,
}
@property
def dummy_input(self):
batch_size = 2
num_channels = 3
sizes = (32, 32)
image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
return {"sample": image}
@property
def input_shape(self):
return (3, 32, 32)
@property
def output_shape(self):
return (3, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAEEncoder2D",
"KVAEDecoder2D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


@@ -0,0 +1,118 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from diffusers import AutoencoderKLKVAEVideo
from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
enable_full_determinism()
class AutoencoderKLKVAEVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAEVideo
main_input_name = "sample"
base_precision = 1e-2
def get_autoencoder_kl_kvae_video_config(self):
return {
"ch": 32,
"ch_mult": (1, 2),
"num_res_blocks": 1,
"in_channels": 3,
"out_ch": 3,
"z_channels": 4,
"temporal_compress_times": 2,
}
@property
def dummy_input(self):
batch_size = 2
num_frames = 3 # satisfies (T-1) % temporal_compress_times == 0 with temporal_compress_times=2
num_channels = 3
sizes = (16, 16)
video = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
return {"sample": video}
@property
def input_shape(self):
return (3, 3, 16, 16)
@property
def output_shape(self):
return (3, 3, 16, 16)
def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAECachedEncoder3D",
"KVAECachedDecoder3D",
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_model_parallelism(self):
pass
@unittest.skip(
"Multi-GPU inference is not supported due to the stateful cache_dict passing through the forward pass."
)
def test_sharded_checkpoints_device_map(self):
pass
def _run_nondeterministic(self, fn):
# reflection_pad3d_backward_out_cuda has no deterministic CUDA implementation;
# temporarily relax the requirement for training tests that do backward passes.
import torch
torch.use_deterministic_algorithms(False)
try:
fn()
finally:
torch.use_deterministic_algorithms(True)
def test_training(self):
self._run_nondeterministic(super().test_training)
def test_ema_training(self):
self._run_nondeterministic(super().test_ema_training)
@unittest.skip(
"Gradient checkpointing recomputes the forward pass, but the model uses a stateful cache_dict "
"that is mutated during the first forward. On recomputation the cache is already populated, "
"causing a different execution path and numerically different gradients. "
"GC still reduces peak memory usage; gradient correctness in the presence of GC is a known limitation."
)
def test_effective_gradient_checkpointing(self):
pass
def test_layerwise_casting_training(self):
self._run_nondeterministic(super().test_layerwise_casting_training)


@@ -481,6 +481,8 @@ class LoraHotSwappingForModelTesterMixin:
# ensure that enable_lora_hotswap is called before loading the first adapter
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
@@ -488,21 +490,31 @@ class LoraHotSwappingForModelTesterMixin:
msg = (
"It is recommended to call `enable_lora_hotswap` before loading the first adapter to avoid recompilation."
)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="warn")
assert any(msg in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
def test_enable_lora_hotswap_called_after_adapter_added_ignore(self, caplog):
# check possibility to ignore the error/warning
import logging
from diffusers.utils import logging as diffusers_logging
lora_config = self._get_lora_config(8, 8, target_modules=["to_q"])
init_dict = self.get_init_dict()
model = self.model_class(**init_dict).to(torch_device)
model.add_adapter(lora_config)
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
diffusers_logging.enable_propagation()
try:
with caplog.at_level(logging.WARNING):
model.enable_lora_hotswap(target_rank=32, check_compiled="ignore")
assert len(caplog.records) == 0
finally:
diffusers_logging.disable_propagation()
def test_enable_lora_hotswap_wrong_check_compiled_argument_raises(self):
# check that wrong argument value raises an error
@@ -518,20 +530,26 @@ class LoraHotSwappingForModelTesterMixin:
# check the error and log
import logging
from diffusers.utils import logging as diffusers_logging
# at the moment, PEFT requires the 2nd adapter to target the same or a subset of layers
target_modules0 = ["to_q"]
target_modules1 = ["to_q", "to_k"]
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
diffusers_logging.enable_propagation()
try:
with pytest.raises(RuntimeError): # peft raises RuntimeError
with caplog.at_level(logging.ERROR):
self._check_model_hotswap(
tmp_path,
do_compile=True,
rank0=8,
rank1=8,
target_modules0=target_modules0,
target_modules1=target_modules1,
)
assert any("Hotswapping adapter0 was unsuccessful" in record.message for record in caplog.records)
finally:
diffusers_logging.disable_propagation()
@pytest.mark.parametrize("rank0,rank1", [(11, 11), (7, 13), (13, 7)])
@require_torch_version_greater("2.7.1")
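The hotswap tests above wrap `caplog` capture in `enable_propagation()` / `disable_propagation()` because pytest's `caplog` hooks handlers high in the logger hierarchy, and a library logger with propagation disabled never reaches them. A self-contained sketch of that failure mode using stdlib logging, where a root-level handler plays the role of `caplog`:

```python
import logging

lib_logger = logging.getLogger("mylib")
lib_logger.setLevel(logging.WARNING)
lib_logger.addHandler(logging.NullHandler())  # library convention: no output by default
lib_logger.propagate = False  # keep records out of the root logger, like diffusers' default

root_records = []


class ListHandler(logging.Handler):
    def emit(self, record):
        root_records.append(record.getMessage())


logging.getLogger().addHandler(ListHandler())

# With propagation off, the root-level handler never sees the record.
lib_logger.warning("hidden warning")

# Re-enable propagation only around the capture, then restore it (mirrors the try/finally above).
lib_logger.propagate = True
try:
    lib_logger.warning("visible warning")
finally:
    lib_logger.propagate = False
```

The try/finally matters: if the assertion inside the capture fails, propagation is still switched back off, so one failing test cannot change logging behavior for the rest of the suite.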


@@ -22,6 +22,7 @@ import torch.distributed as dist
import torch.multiprocessing as mp
from diffusers.models._modeling_parallel import ContextParallelConfig
from diffusers.models.attention_dispatch import AttentionBackendName, _AttentionBackendRegistry
from ...testing_utils import (
is_context_parallel,
@@ -160,16 +161,21 @@ def _custom_mesh_worker(
@require_torch_multi_accelerator
class ContextParallelTesterMixin:
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_inference(self, cp_type):
def test_context_parallel_inference(self, cp_type, batch_size: int = 1):
if not torch.distributed.is_available():
pytest.skip("torch.distributed is not available.")
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
inputs_dict = self.get_dummy_inputs(batch_size=batch_size)
# Move all tensors to CPU for multiprocessing
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in inputs_dict.items()}
@@ -194,6 +200,10 @@ class ContextParallelTesterMixin:
f"Context parallel inference failed: {return_dict.get('error', 'Unknown error')}"
)
@pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
def test_context_parallel_batch_inputs(self, cp_type):
self.test_context_parallel_inference(cp_type, batch_size=2)
@pytest.mark.parametrize(
"cp_type,mesh_shape,mesh_dim_names",
[
@@ -209,6 +219,11 @@ class ContextParallelTesterMixin:
if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
if cp_type == "ring_degree":
active_backend, _ = _AttentionBackendRegistry.get_active_backend()
if active_backend == AttentionBackendName.NATIVE:
pytest.skip("Ring attention is not supported with the native attention backend.")
world_size = 2
init_dict = self.get_init_dict()
inputs_dict = {k: v.cpu() if isinstance(v, torch.Tensor) else v for k, v in self.get_dummy_inputs().items()}


@@ -41,7 +41,6 @@ from ..testing_utils import (
ModelOptCompileTesterMixin,
ModelOptTesterMixin,
ModelTesterMixin,
PyramidAttentionBroadcastTesterMixin,
QuantoCompileTesterMixin,
QuantoTesterMixin,
SingleFileTesterMixin,
@@ -151,8 +150,7 @@ class FluxTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": [4, 4, 8],
}
def get_dummy_inputs(self) -> dict[str, torch.Tensor]:
batch_size = 1
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
height = width = 4
num_latent_channels = 4
num_image_channels = 3
@@ -219,6 +217,10 @@ class TestFluxTransformerMemory(FluxTransformerTesterConfig, MemoryTesterMixin):
class TestFluxTransformerTraining(FluxTransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"FluxTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFluxTransformerAttention(FluxTransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux Transformer."""
@@ -412,10 +414,6 @@ class TestFluxTransformerBitsAndBytesCompile(FluxTransformerTesterConfig, BitsAn
"""BitsAndBytes + compile tests for Flux Transformer."""
class TestFluxTransformerPABCache(FluxTransformerTesterConfig, PyramidAttentionBroadcastTesterMixin):
"""PyramidAttentionBroadcast cache tests for Flux Transformer."""
class TestFluxTransformerFBCCache(FluxTransformerTesterConfig, FirstBlockCacheTesterMixin):
"""FirstBlockCache tests for Flux Transformer."""


@@ -13,48 +13,94 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers import Flux2Transformer2DModel, attention_backend
from diffusers import Flux2Transformer2DModel
from diffusers.models.transformers.transformer_flux2 import (
Flux2KVAttnProcessor,
Flux2KVCache,
Flux2KVLayerCache,
Flux2KVParallelSelfAttnProcessor,
)
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
from ..testing_utils import (
AttentionTesterMixin,
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ContextParallelTesterMixin,
GGUFCompileTesterMixin,
GGUFTesterMixin,
LoraHotSwappingForModelTesterMixin,
LoraTesterMixin,
MemoryTesterMixin,
ModelTesterMixin,
TorchAoCompileTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
main_input_name = "hidden_states"
# We override the items here because the transformer under consideration is small.
model_split_percents = [0.7, 0.6, 0.6]
# Skip setting testing with default: AttnProcessor
uses_custom_attn_processor = True
class Flux2TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return Flux2Transformer2DModel
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def output_shape(self):
def input_shape(self) -> tuple[int, int]:
return (16, 4)
def prepare_dummy_input(self, height=4, width=4):
batch_size = 1
@property
def model_split_percents(self) -> list:
# We override the items here because the transformer under consideration is small.
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
# Skip attention-processor tests that set the default AttnProcessor
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
"num_single_layers": 1,
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256, # Hardcoded in original code
"axes_dims_rope": [4, 4, 4, 4],
}
def get_dummy_inputs(self, height: int = 4, width: int = 4, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
@@ -82,8 +128,286 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
"guidance": guidance,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
class TestFlux2Transformer(Flux2TransformerTesterConfig, ModelTesterMixin):
pass
class TestFlux2TransformerMemory(Flux2TransformerTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for Flux2 Transformer."""
class TestFlux2TransformerTraining(Flux2TransformerTesterConfig, TrainingTesterMixin):
"""Training tests for Flux2 Transformer."""
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestFlux2TransformerAttention(Flux2TransformerTesterConfig, AttentionTesterMixin):
"""Attention processor tests for Flux2 Transformer."""
class TestFlux2TransformerContextParallel(Flux2TransformerTesterConfig, ContextParallelTesterMixin):
"""Context Parallel inference tests for Flux2 Transformer."""
class TestFlux2TransformerLoRA(Flux2TransformerTesterConfig, LoraTesterMixin):
"""LoRA adapter tests for Flux2 Transformer."""
class TestFlux2TransformerLoRAHotSwap(Flux2TransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for Flux2 Transformer."""
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for LoRA hotswap tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerCompile(Flux2TransformerTesterConfig, TorchCompileTesterMixin):
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
"""Override to support dynamic height/width for compilation tests."""
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
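The dummy inputs above build 4-axis `(t, h, w, l)` rotary position IDs with `torch.cartesian_prod`: image tokens vary over the `(h, w)` axes while text tokens vary only along the sequence axis `l`. The same construction with stdlib `itertools.product` (illustrative; axis order matches the snippet, with the last axis varying fastest):

```python
from itertools import product

height, width, sequence_length = 4, 4, 48

# Image IDs: one row per latent token, columns are (t, h, w, l); t and l are fixed at 0.
image_ids = [list(coords) for coords in product(range(1), range(height), range(width), range(1))]

# Text IDs: only the sequence axis l varies.
text_ids = [list(coords) for coords in product(range(1), range(1), range(1), range(sequence_length))]
```

In the tests these rows are then broadcast to the batch dimension with `unsqueeze(0).expand(batch_size, -1, -1)`, giving identical position IDs for every sample in the batch.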
class TestFlux2TransformerBitsAndBytes(Flux2TransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for Flux2 Transformer."""
class TestFlux2TransformerTorchAo(Flux2TransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for Flux2 Transformer."""
class TestFlux2TransformerGGUF(Flux2TransformerTesterConfig, GGUFTesterMixin):
"""GGUF quantization tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerTorchAoCompile(Flux2TransformerTesterConfig, TorchAoCompileTesterMixin):
"""TorchAO + compile tests for Flux2 Transformer."""
class TestFlux2TransformerGGUFCompile(Flux2TransformerTesterConfig, GGUFCompileTesterMixin):
"""GGUF + compile tests for Flux2 Transformer."""
@property
def gguf_filename(self):
return "https://huggingface.co/unsloth/FLUX.2-dev-GGUF/blob/main/flux2-dev-Q2_K.gguf"
@property
def torch_dtype(self):
return torch.bfloat16
def get_dummy_inputs(self):
"""Override to provide inputs matching the real FLUX2 model dimensions.
Flux2 defaults: in_channels=128, joint_attention_dim=15360
"""
batch_size = 1
height = 64
width = 64
sequence_length = 512
hidden_states = randn_tensor(
(batch_size, height * width, 128), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, 15360), generator=self.generator, device=torch_device, dtype=self.torch_dtype
)
# Flux2 uses 4D image/text IDs (t, h, w, l)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device, self.torch_dtype)
guidance = torch.tensor([3.5]).to(torch_device, self.torch_dtype)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class Flux2TransformerKVCacheTesterConfig(BaseModelTesterConfig):
num_ref_tokens = 4
@property
def model_class(self):
return Flux2Transformer2DModel
@property
def output_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def input_shape(self) -> tuple[int, int]:
return (16, 4)
@property
def model_split_percents(self) -> list:
return [0.7, 0.6, 0.6]
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def uses_custom_attn_processor(self) -> bool:
return True
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict[str, int | list[int]]:
return {
"patch_size": 1,
"in_channels": 4,
"num_layers": 1,
@@ -91,72 +415,210 @@ class Flux2TransformerTests(ModelTesterMixin, unittest.TestCase):
"attention_head_dim": 16,
"num_attention_heads": 2,
"joint_attention_dim": 32,
"timestep_guidance_channels": 256, # Hardcoded in original code
"timestep_guidance_channels": 256,
"axes_dims_rope": [4, 4, 4, 4],
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
batch_size = 1
num_latent_channels = 4
sequence_length = 48
embedding_dim = 32
num_ref_tokens = self.num_ref_tokens
# TODO (Daniel, Sayak): We can remove this test.
def test_flux2_consistency(self, seed=0):
torch.manual_seed(seed)
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
ref_hidden_states = randn_tensor(
(batch_size, num_ref_tokens, num_latent_channels), generator=self.generator, device=torch_device
)
img_hidden_states = randn_tensor(
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
)
hidden_states = torch.cat([ref_hidden_states, img_hidden_states], dim=1)
torch.manual_seed(seed)
model = self.model_class(**init_dict)
# state_dict = model.state_dict()
# for key, param in state_dict.items():
# print(f"{key} | {param.shape}")
# torch.save(state_dict, "/raid/daniel_gu/test_flux2_params/diffusers.pt")
encoder_hidden_states = randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
)
ref_t_coords = torch.arange(1)
ref_h_coords = torch.arange(num_ref_tokens)
ref_w_coords = torch.arange(1)
ref_l_coords = torch.arange(1)
ref_ids = torch.cartesian_prod(ref_t_coords, ref_h_coords, ref_w_coords, ref_l_coords)
ref_ids = ref_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
t_coords = torch.arange(1)
h_coords = torch.arange(height)
w_coords = torch.arange(width)
l_coords = torch.arange(1)
image_ids = torch.cartesian_prod(t_coords, h_coords, w_coords, l_coords)
image_ids = image_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
image_ids = torch.cat([ref_ids, image_ids], dim=1)
text_t_coords = torch.arange(1)
text_h_coords = torch.arange(1)
text_w_coords = torch.arange(1)
text_l_coords = torch.arange(sequence_length)
text_ids = torch.cartesian_prod(text_t_coords, text_h_coords, text_w_coords, text_l_coords)
text_ids = text_ids.unsqueeze(0).expand(batch_size, -1, -1).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
guidance = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"encoder_hidden_states": encoder_hidden_states,
"img_ids": image_ids,
"txt_ids": text_ids,
"timestep": timestep,
"guidance": guidance,
}
class TestFlux2TransformerKVCache(Flux2TransformerKVCacheTesterConfig):
"""KV cache tests for Flux2 Transformer."""
def test_kv_layer_cache_store_and_get(self):
cache = Flux2KVLayerCache()
k = torch.randn(1, 4, 2, 16)
v = torch.randn(1, 4, 2, 16)
cache.store(k, v)
k_out, v_out = cache.get()
assert torch.equal(k, k_out)
assert torch.equal(v, v_out)
def test_kv_layer_cache_get_before_store_raises(self):
cache = Flux2KVLayerCache()
try:
cache.get()
assert False, "Expected RuntimeError"
except RuntimeError:
pass
def test_kv_layer_cache_clear(self):
cache = Flux2KVLayerCache()
cache.store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.k_ref is None
assert cache.v_ref is None
def test_kv_cache_structure(self):
num_double = 3
num_single = 2
cache = Flux2KVCache(num_double, num_single)
assert len(cache.double_block_caches) == num_double
assert len(cache.single_block_caches) == num_single
assert cache.num_ref_tokens == 0
for i in range(num_double):
assert isinstance(cache.get_double(i), Flux2KVLayerCache)
for i in range(num_single):
assert isinstance(cache.get_single(i), Flux2KVLayerCache)
def test_kv_cache_clear(self):
cache = Flux2KVCache(2, 1)
cache.num_ref_tokens = 4
cache.get_double(0).store(torch.randn(1, 4, 2, 16), torch.randn(1, 4, 2, 16))
cache.clear()
assert cache.num_ref_tokens == 0
assert cache.get_double(0).k_ref is None
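The tests above pin down a small contract for the KV caches: `store`/`get` round-trips the reference keys and values, `get` before `store` raises a RuntimeError, and `clear` resets both the per-layer slots and `num_ref_tokens`. A minimal pure-Python analog consistent with those assertions (the real `Flux2KVLayerCache`/`Flux2KVCache` hold torch tensors; the class names here are illustrative):

```python
class KVLayerCache:
    def __init__(self):
        self.k_ref = None
        self.v_ref = None

    def store(self, k, v):
        self.k_ref, self.v_ref = k, v

    def get(self):
        if self.k_ref is None or self.v_ref is None:
            raise RuntimeError("KV cache is empty; run an extract pass (store) first.")
        return self.k_ref, self.v_ref

    def clear(self):
        self.k_ref = None
        self.v_ref = None


class KVCache:
    def __init__(self, num_double_blocks, num_single_blocks):
        self.double_block_caches = [KVLayerCache() for _ in range(num_double_blocks)]
        self.single_block_caches = [KVLayerCache() for _ in range(num_single_blocks)]
        self.num_ref_tokens = 0

    def get_double(self, i):
        return self.double_block_caches[i]

    def get_single(self, i):
        return self.single_block_caches[i]

    def clear(self):
        for cache in self.double_block_caches + self.single_block_caches:
            cache.clear()
        self.num_ref_tokens = 0
```

The split into double/single block lists mirrors the model structure: each transformer block gets its own layer cache, filled once in "extract" mode and read back in "cached" mode.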
def _set_kv_attn_processors(self, model):
for block in model.transformer_blocks:
block.attn.set_processor(Flux2KVAttnProcessor())
for block in model.single_transformer_blocks:
block.attn.set_processor(Flux2KVParallelSelfAttnProcessor())
@torch.no_grad()
def test_extract_mode_returns_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self._set_kv_attn_processors(model)
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
assert output.kv_cache is not None
assert isinstance(output.kv_cache, Flux2KVCache)
assert output.kv_cache.num_ref_tokens == self.num_ref_tokens
for layer_cache in output.kv_cache.double_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
for layer_cache in output.kv_cache.single_block_caches:
assert layer_cache.k_ref is not None
assert layer_cache.v_ref is not None
@torch.no_grad()
def test_extract_mode_output_shape(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self._set_kv_attn_processors(model)
height, width = 4, 4
output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
assert output.sample.shape == (1, height * width, 4)
@torch.no_grad()
def test_cached_mode_uses_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
self._set_kv_attn_processors(model)
height, width = 4, 4
extract_output = model(
**self.get_dummy_inputs(height=height, width=width),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
)
base_config = Flux2TransformerTesterConfig()
cached_inputs = base_config.get_dummy_inputs(height=height, width=width)
cached_output = model(
**cached_inputs,
kv_cache=extract_output.kv_cache,
kv_cache_mode="cached",
)
assert cached_output.sample.shape == (1, height * width, 4)
assert cached_output.kv_cache is None
def test_gradient_checkpointing_is_applied(self):
expected_set = {"Flux2Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
@torch.no_grad()
def test_extract_return_dict_false(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
output = model(
**self.get_dummy_inputs(),
kv_cache_mode="extract",
num_ref_tokens=self.num_ref_tokens,
ref_fixed_timestep=0.0,
return_dict=False,
)
assert isinstance(output, tuple)
assert len(output) == 2
assert isinstance(output[1], Flux2KVCache)
@torch.no_grad()
def test_no_kv_cache_mode_returns_no_cache(self):
model = self.model_class(**self.get_init_dict())
model.to(torch_device)
model.eval()
base_config = Flux2TransformerTesterConfig()
output = model(**base_config.get_dummy_inputs())
assert output.kv_cache is None
class Flux2TransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
class Flux2TransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
model_class = Flux2Transformer2DModel
different_shapes_for_compilation = [(4, 4), (4, 8), (8, 8)]
def prepare_init_args_and_inputs_for_common(self):
return Flux2TransformerTests().prepare_init_args_and_inputs_for_common()
def prepare_dummy_input(self, height, width):
return Flux2TransformerTests().prepare_dummy_input(height=height, width=width)
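The store/get/clear contract these tests exercise can be sketched as a minimal per-layer cache. This is an illustrative stand-in for `Flux2KVLayerCache`, not the diffusers implementation:

```python
import torch

class KVLayerCache:
    """Minimal per-layer key/value cache: store once, read many, clear to reset."""

    def __init__(self):
        self.k_ref = None
        self.v_ref = None

    def store(self, k: torch.Tensor, v: torch.Tensor) -> None:
        # Keep references to the reference-token keys/values for later reuse.
        self.k_ref = k
        self.v_ref = v

    def get(self):
        # Reading before any store is a programming error, mirroring the test above.
        if self.k_ref is None or self.v_ref is None:
            raise RuntimeError("Cache is empty; call store() before get().")
        return self.k_ref, self.v_ref

    def clear(self) -> None:
        self.k_ref = None
        self.v_ref = None
```

A full KV cache would then hold one such object per double/single transformer block, plus a `num_ref_tokens` counter, as the structure tests assert.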

View File

@@ -14,6 +14,7 @@
import warnings
import pytest
import torch
from diffusers import QwenImageTransformer2DModel
@@ -77,8 +78,7 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
"axes_dims_rope": (8, 4, 4),
}
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
num_latent_channels = embedding_dim = 16
height = width = 4
sequence_length = 8
@@ -106,9 +106,10 @@ class QwenImageTransformerTesterConfig(BaseModelTesterConfig):
class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixin):
@pytest.mark.parametrize("batch_size", [1, 2])
def test_infers_text_seq_len_from_mask(self, batch_size):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -122,7 +123,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert isinstance(per_sample_len, torch.Tensor)
assert int(per_sample_len.max().item()) == 2
assert normalized_mask.dtype == torch.bool
assert normalized_mask.sum().item() == 2 * batch_size
assert rope_text_seq_len >= inputs["encoder_hidden_states"].shape[1]
inputs["encoder_hidden_states_mask"] = normalized_mask
@@ -139,7 +140,7 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
)
assert int(per_sample_len2.max().item()) == 8
assert normalized_mask2.sum().item() == 5 * batch_size
rope_text_seq_len_none, per_sample_len_none, normalized_mask_none = compute_text_seq_len_from_mask(
inputs["encoder_hidden_states"], None
@@ -149,9 +150,10 @@ class TestQwenImageTransformer(QwenImageTransformerTesterConfig, ModelTesterMixi
assert per_sample_len_none is None
assert normalized_mask_none is None
@pytest.mark.parametrize("batch_size", [1, 2])
def test_non_contiguous_attention_mask(self, batch_size):
init_dict = self.get_init_dict()
inputs = self.get_dummy_inputs(batch_size=batch_size)
model = self.model_class(**init_dict).to(torch_device)
encoder_hidden_states_mask = inputs["encoder_hidden_states_mask"].clone()
@@ -284,6 +286,14 @@ class TestQwenImageTransformerLoRA(QwenImageTransformerTesterConfig, LoraTesterM
class TestQwenImageTransformerLoRAHotSwap(QwenImageTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
"""LoRA hot-swapping tests for QwenImage Transformer."""
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_linear(self):
super().test_hotswapping_compiled_model_linear()
@pytest.mark.xfail(True, reason="Recompilation issues.", strict=True)
def test_hotswapping_compiled_model_both_linear_and_other(self):
super().test_hotswapping_compiled_model_both_linear_and_other()
@property
def different_shapes_for_compilation(self):
return [(4, 4), (4, 8), (8, 8)]
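The mask semantics these parametrized tests pin down can be sketched as follows. This is an illustrative reimplementation inferred purely from the assertions (per-sample length is the last attended position plus one, the normalized mask is boolean, and a `None` mask returns `None` for both derived values), not the actual `compute_text_seq_len_from_mask` helper:

```python
import torch

def text_seq_len_from_mask(encoder_hidden_states, mask):
    """Sketch: derive a RoPE text length, per-sample lengths, and a bool mask."""
    seq_len = encoder_hidden_states.shape[1]
    if mask is None:
        # No mask: attend to the full sequence; nothing per-sample to report.
        return seq_len, None, None
    normalized = mask.bool()
    positions = torch.arange(seq_len, device=mask.device)
    # Per batch element: last attended position + 1 (0 if nothing is attended).
    per_sample_len = torch.where(normalized, positions + 1, torch.zeros_like(positions)).amax(dim=1)
    # RoPE must cover at least the full padded text sequence.
    rope_text_seq_len = max(seq_len, int(per_sample_len.max().item()))
    return rope_text_seq_len, per_sample_len, normalized
```

Note how a non-contiguous mask such as `[1, 0, 1, 1, 0, 1, 0, 1]` yields a per-sample length of 8 (last attended index 7, plus one) even though only 5 positions are attended, matching the `5 * batch_size` sum asserted above.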

View File

@@ -13,58 +13,63 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from diffusers import SD3Transformer2DModel
from diffusers.utils.torch_utils import randn_tensor
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import (
BaseModelTesterConfig,
BitsAndBytesTesterMixin,
ModelTesterMixin,
TorchAoTesterMixin,
TorchCompileTesterMixin,
TrainingTesterMixin,
)
enable_full_determinism()
# ======================== SD3 Transformer ========================
class SD3TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SD3Transformer2DModel
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-sd3-pipe"
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def model_split_percents(self) -> list:
return [0.8, 0.8, 0.9]
@property
def output_shape(self) -> tuple:
return (4, 32, 32)
@property
def input_shape(self) -> tuple:
return (4, 32, 32)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
@@ -79,67 +84,79 @@ class SD3TransformerTests(ModelTesterMixin, unittest.TestCase):
"dual_attention_layers": (),
"qk_norm": None,
}
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"pooled_projections": randn_tensor(
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestSD3Transformer(SD3TransformerTesterConfig, ModelTesterMixin):
pass
class TestSD3TransformerTraining(SD3TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestSD3TransformerCompile(SD3TransformerTesterConfig, TorchCompileTesterMixin):
pass
# ======================== SD3.5 Transformer ========================
class SD35TransformerTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return SD3Transformer2DModel
@property
def pretrained_model_name_or_path(self):
return "hf-internal-testing/tiny-sd35-pipe"
@property
def pretrained_model_kwargs(self):
return {"subfolder": "transformer"}
@property
def main_input_name(self) -> str:
return "hidden_states"
@property
def model_split_percents(self) -> list:
return [0.8, 0.8, 0.9]
@property
def output_shape(self) -> tuple:
return (4, 32, 32)
@property
def input_shape(self) -> tuple:
return (4, 32, 32)
@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)
def get_init_dict(self) -> dict:
return {
"sample_size": 32,
"patch_size": 1,
"in_channels": 4,
@@ -154,47 +171,56 @@ class SD35TransformerTests(ModelTesterMixin, unittest.TestCase):
"dual_attention_layers": (0,),
"qk_norm": "rms_norm",
}
def get_dummy_inputs(self, batch_size: int = 2) -> dict[str, torch.Tensor]:
num_channels = 4
height = width = embedding_dim = 32
pooled_embedding_dim = embedding_dim * 2
sequence_length = 154
return {
"hidden_states": randn_tensor(
(batch_size, num_channels, height, width), generator=self.generator, device=torch_device
),
"encoder_hidden_states": randn_tensor(
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
),
"pooled_projections": randn_tensor(
(batch_size, pooled_embedding_dim), generator=self.generator, device=torch_device
),
"timestep": torch.randint(0, 1000, size=(batch_size,), generator=self.generator).to(torch_device),
}
class TestSD35Transformer(SD35TransformerTesterConfig, ModelTesterMixin):
def test_skip_layers(self):
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()
model = self.model_class(**init_dict).to(torch_device)
# Forward pass without skipping layers
output_full = model(**inputs_dict).sample
# Forward pass with skipping layers 0 (since there's only one layer in this test setup)
inputs_dict_with_skip = inputs_dict.copy()
inputs_dict_with_skip["skip_layers"] = [0]
output_skip = model(**inputs_dict_with_skip).sample
# Check that the outputs are different
assert not torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
# Check that the outputs have the same shape
assert output_full.shape == output_skip.shape, "Outputs should have the same shape"
class TestSD35TransformerTraining(SD35TransformerTesterConfig, TrainingTesterMixin):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
class TestSD35TransformerCompile(SD35TransformerTesterConfig, TorchCompileTesterMixin):
pass
class TestSD35TransformerBitsAndBytes(SD35TransformerTesterConfig, BitsAndBytesTesterMixin):
"""BitsAndBytes quantization tests for SD3.5 Transformer."""
class TestSD35TransformerTorchAo(SD35TransformerTesterConfig, TorchAoTesterMixin):
"""TorchAO quantization tests for SD3.5 Transformer."""

View File

@@ -5,6 +5,7 @@ from typing import Callable
import pytest
import torch
from huggingface_hub import hf_hub_download
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
@@ -32,6 +33,33 @@ from ..testing_utils import (
)
def _get_specified_components(path_or_repo_id, cache_dir=None):
if os.path.isdir(path_or_repo_id):
config_path = os.path.join(path_or_repo_id, "modular_model_index.json")
else:
try:
config_path = hf_hub_download(
repo_id=path_or_repo_id,
filename="modular_model_index.json",
local_dir=cache_dir,
)
except Exception:
return None
with open(config_path) as f:
config = json.load(f)
components = set()
for k, v in config.items():
if isinstance(v, (str, int, float, bool)):
continue
for entry in v:
if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
components.add(k)
break
return components
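For reference, here is a minimal, hypothetical `modular_model_index.json` payload that the scan above would accept: scalar entries are skipped, and a key qualifies as soon as any entry dict carries `repo` or `pretrained_model_name_or_path` (the component names and repo ids below are invented for illustration):

```python
# Hypothetical index: "_class_name" is a scalar and is skipped; "unet" qualifies
# via its loading-spec dict; "scheduler" has no repo info, so it is not collected.
config = {
    "_class_name": "ModularPipeline",
    "unet": [{"repo": "some/repo", "subfolder": "unet"}],
    "scheduler": [{"type_hint": "EulerDiscreteScheduler"}],
}

components = set()
for k, v in config.items():
    if isinstance(v, (str, int, float, bool)):
        continue
    for entry in v:
        if isinstance(entry, dict) and (entry.get("repo") or entry.get("pretrained_model_name_or_path")):
            components.add(k)
            break

print(components)  # → {'unet'}
```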
class ModularPipelineTesterMixin:
"""
It provides a set of common tests for each modular pipeline,
@@ -360,6 +388,39 @@ class ModularPipelineTesterMixin:
assert torch.abs(image_slices[0] - image_slices[1]).max() < 1e-3
def test_load_expected_components_from_pretrained(self, tmp_path):
pipe = self.get_pipeline()
expected = _get_specified_components(self.pretrained_model_name_or_path, cache_dir=tmp_path)
if not expected:
pytest.skip("Skipping test as we couldn't fetch the expected components.")
actual = {
name
for name in pipe.components
if getattr(pipe, name, None) is not None
and getattr(getattr(pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, f"Component mismatch: missing={expected - actual}, unexpected={actual - expected}"
def test_load_expected_components_from_save_pretrained(self, tmp_path):
pipe = self.get_pipeline()
save_dir = str(tmp_path / "saved-pipeline")
pipe.save_pretrained(save_dir)
expected = _get_specified_components(save_dir)
loaded_pipe = ModularPipeline.from_pretrained(save_dir)
loaded_pipe.load_components(torch_dtype=torch.float32)
actual = {
name
for name in loaded_pipe.components
if getattr(loaded_pipe, name, None) is not None
and getattr(getattr(loaded_pipe, name), "_diffusers_load_id", None) not in (None, "null")
}
assert expected == actual, (
f"Component mismatch after save/load: missing={expected - actual}, unexpected={actual - expected}"
)
def test_modular_index_consistency(self, tmp_path):
pipe = self.get_pipeline()
components_spec = pipe._component_specs

View File

@@ -18,7 +18,7 @@ import unittest
import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.training_utils import compute_confidence_aware_loss, set_seed
from ..testing_utils import slow
@@ -85,3 +85,47 @@ class TrainingTests(unittest.TestCase):
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
def test_confidence_aware_loss(self):
logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]])
labels = torch.tensor([[0, 0]])
weights = torch.tensor([[1.0, 2.0]])
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=0.0, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss, loss_sft))
self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf)))
lambda_conf = 0.25
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, per_token_weights=weights
)
# Manual expected values for the small 2-class case.
per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view(
1, 2
)
expected_sft = (per_token_nll * weights).sum() / weights.sum()
pred = logits.argmax(dim=-1)
correct = pred.eq(labels)
log_probs = torch.log_softmax(logits.float(), dim=-1)
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / (
weights * correct.to(weights.dtype)
).sum().clamp_min(1)
expected = expected_sft + lambda_conf * expected_conf
self.assertTrue(torch.allclose(loss_sft, expected_sft))
self.assertTrue(torch.allclose(loss_conf, expected_conf))
self.assertTrue(torch.allclose(loss, expected))
# Temperature affects only the confidence term.
loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss_sft_t, expected_sft))
self.assertFalse(torch.allclose(loss_conf_t, expected_conf))
self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t))
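The manual expected values in this test amount to the following sketch of a CAP-style loss: a weighted token NLL plus an entropy penalty over correctly predicted tokens, scaled by `lambda_conf`. This is consistent with the test's arithmetic, but the actual `compute_confidence_aware_loss` in `diffusers.training_utils` may differ in details:

```python
import torch
import torch.nn.functional as F

def confidence_aware_loss(logits, labels, lambda_conf=0.0, temperature=1.0, per_token_weights=None):
    """Sketch: weighted SFT loss + entropy term on correct tokens (CAP-style)."""
    vocab = logits.shape[-1]
    # Per-token negative log-likelihood, reshaped back to (batch, seq).
    nll = F.cross_entropy(logits.view(-1, vocab), labels.view(-1), reduction="none").view(labels.shape)
    w = per_token_weights if per_token_weights is not None else torch.ones_like(nll)
    loss_sft = (nll * w).sum() / w.sum()
    if lambda_conf == 0.0:
        # Confidence term disabled: return a zero confidence loss, as the test expects.
        return loss_sft, loss_sft, torch.zeros_like(loss_sft)
    correct = logits.argmax(dim=-1).eq(labels)
    # Temperature enters only through the confidence (entropy) term.
    log_probs = torch.log_softmax(logits.float() / temperature, dim=-1)
    probs = log_probs.exp()
    entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
    denom = (w * correct.to(w.dtype)).sum().clamp_min(1)
    loss_conf = (entropy * w * correct.to(entropy.dtype)).sum() / denom
    return loss_sft + lambda_conf * loss_conf, loss_sft, loss_conf
```

Because the entropy term only averages over tokens the model already predicts correctly, it rewards (or penalizes, depending on sign conventions) confident commitments without disturbing the supervised term.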

View File

@@ -13,8 +13,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
import unittest
import warnings
import pytest
@@ -182,6 +184,25 @@ class DeprecateTester(unittest.TestCase):
assert str(warning.warning) == "This message is better!!!"
assert "diffusers/tests/others/test_utils.py" in warning.filename
def test_deprecate_testing_utils_module(self):
import diffusers.utils.testing_utils
with warnings.catch_warnings(record=True) as caught_warnings:
warnings.simplefilter("always")
importlib.reload(diffusers.utils.testing_utils)
deprecation_warnings = [w for w in caught_warnings if issubclass(w.category, FutureWarning)]
assert len(deprecation_warnings) >= 1, "Expected at least one FutureWarning from diffusers.utils.testing_utils"
messages = [str(w.message) for w in deprecation_warnings]
assert any("diffusers.utils.testing_utils" in msg for msg in messages), (
f"Expected a deprecation warning mentioning 'diffusers.utils.testing_utils', got: {messages}"
)
assert any(
"diffusers.utils.testing_utils is deprecated and will be removed in a future version." in msg
for msg in messages
), f"Expected deprecation message substring not found, got: {messages}"
# Copied from https://github.com/huggingface/transformers/blob/main/tests/utils/test_expectations.py
class ExpectationsTester(unittest.TestCase):

View File

@@ -0,0 +1,245 @@
import unittest
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
class _DummyModelOutput:
def __init__(self, logits):
self.logits = logits
class _DummyCausalLM(torch.nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
self.vocab_size = int(vocab_size)
self.register_buffer("_device_anchor", torch.empty(0))
@property
def dtype(self):
return torch.float32
@property
def device(self):
return self._device_anchor.device
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
batch_size, seq_len = input_ids.shape
logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32)
# Make confidence vary with token position so top-k commits are deterministic.
positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1)
token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1)
logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1)
return _DummyModelOutput(logits=logits)
def _make_pipeline(tokenizer=None):
model = _DummyCausalLM(vocab_size=32)
scheduler = BlockRefinementScheduler()
return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
class LLaDA2PipelineTest(unittest.TestCase):
def test_pipeline_runs(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
out = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=24,
block_length=8,
num_inference_steps=8,
temperature=0.0,
threshold=2.0, # force top-k commits
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
eos_token_id=None,
output_type="seq",
)
self.assertEqual(out.sequences.shape, (2, 24))
self.assertFalse((out.sequences == 31).any().item())
def test_pipeline_return_tuple(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
sequences, texts = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
return_dict=False,
)
self.assertEqual(sequences.shape, (1, 16))
self.assertIsNone(texts)
def test_output_type_seq(self):
"""output_type='seq' should return sequences but no texts."""
pipe = _make_pipeline().to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
)
self.assertIsNotNone(out.sequences)
self.assertEqual(out.sequences.shape, (1, 16))
self.assertIsNone(out.texts)
def test_output_type_text_without_tokenizer(self):
"""output_type='text' without a tokenizer should return texts=None."""
pipe = _make_pipeline(tokenizer=None).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNone(out.texts)
def test_output_type_text_with_tokenizer(self):
"""output_type='text' with a tokenizer should return decoded texts."""
tok = type(
"Tok",
(),
{
"eos_token_id": None,
"mask_token_id": 31,
"batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
},
)()
pipe = _make_pipeline(tokenizer=tok).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNotNone(out.texts)
self.assertEqual(len(out.texts), 1)
self.assertTrue(out.texts[0].startswith("decoded_"))
def test_output_type_invalid_raises(self):
"""Invalid output_type should raise ValueError."""
pipe = _make_pipeline().to("cpu")
with self.assertRaises(ValueError):
pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
mask_token_id=31,
output_type="invalid",
)
def test_prepare_input_ids_from_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertTrue(torch.equal(result, ids))
def test_prepare_input_ids_from_1d_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([1, 2, 3], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertEqual(result.shape, (1, 3))
def test_prepare_input_ids_no_tokenizer_raises(self):
pipe = _make_pipeline(tokenizer=None)
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_both_prompt_and_messages_raises(self):
pipe = _make_pipeline()
# Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=[{"role": "user", "content": "hi"}],
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_neither_raises(self):
pipe = _make_pipeline()
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1534,14 +1534,18 @@ class PipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to(torch_device)
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == torch_device for device in model_devices))
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1552,11 +1556,11 @@ class PipelineTesterMixin:
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

View File

@@ -0,0 +1,470 @@
import tempfile
import unittest
import torch
from diffusers import BlockRefinementScheduler
class BlockRefinementSchedulerTest(unittest.TestCase):
def get_scheduler(self, **kwargs):
config = {
"block_length": 32,
"num_inference_steps": 8,
"threshold": 0.95,
"editing_threshold": None,
"minimal_topk": 1,
}
config.update(kwargs)
return BlockRefinementScheduler(**config)
def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor:
"""Create logits where softmax of the target token has approximately the given probability."""
batch_size, block_length = target_probs.shape
logits = torch.zeros(batch_size, block_length, vocab_size)
# Give each position's target token (t % (vocab_size - 1)) a logit proportional to the desired probability
for b in range(batch_size):
for t in range(block_length):
p = target_probs[b, t].item()
if p > 0:
logits[b, t, t % (vocab_size - 1)] = 10.0 * p
return logits
def test_set_timesteps(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
self.assertEqual(scheduler.num_inference_steps, 8)
self.assertEqual(len(scheduler.timesteps), 8)
self.assertEqual(scheduler.timesteps[0].item(), 7)
self.assertEqual(scheduler.timesteps[-1].item(), 0)
def test_set_timesteps_invalid(self):
scheduler = self.get_scheduler()
with self.assertRaises(ValueError):
scheduler.set_timesteps(0)
def test_get_num_transfer_tokens_even(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
self.assertEqual(schedule.sum().item(), 32)
self.assertEqual(len(schedule), 8)
self.assertTrue((schedule == 4).all().item())
def test_get_num_transfer_tokens_remainder(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
self.assertEqual(schedule.sum().item(), 10)
self.assertEqual(len(schedule), 3)
self.assertEqual(schedule[0].item(), 4)
self.assertEqual(schedule[1].item(), 3)
self.assertEqual(schedule[2].item(), 3)
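The even/remainder split these two tests pin down can be sketched as follows. This is a minimal reimplementation for illustration, not the actual diffusers code; only the function name mirrors the scheduler method:

```python
import torch

def get_num_transfer_tokens(block_length: int, num_inference_steps: int) -> torch.Tensor:
    # Split block_length evenly across the steps; the first
    # (block_length % num_inference_steps) steps each commit one extra
    # token, so the schedule always sums to block_length.
    base = block_length // num_inference_steps
    remainder = block_length % num_inference_steps
    schedule = torch.full((num_inference_steps,), base, dtype=torch.long)
    schedule[:remainder] += 1
    return schedule

print(get_num_transfer_tokens(32, 8).tolist())  # [4, 4, 4, 4, 4, 4, 4, 4]
print(get_num_transfer_tokens(10, 3).tolist())  # [4, 3, 3]
```

The front-loaded remainder is what makes `schedule[0] == 4` and `schedule[1] == schedule[2] == 3` in the 10/3 case.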
def test_transfer_schedule_created_on_set_timesteps(self):
scheduler = self.get_scheduler(block_length=16)
scheduler.set_timesteps(4)
self.assertIsNotNone(scheduler._transfer_schedule)
self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)
def test_save_load_config_round_trip(self):
scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
with tempfile.TemporaryDirectory() as tmpdir:
scheduler.save_config(tmpdir)
loaded = BlockRefinementScheduler.from_pretrained(tmpdir)
self.assertEqual(loaded.config.block_length, 64)
self.assertEqual(loaded.config.threshold, 0.8)
self.assertEqual(loaded.config.editing_threshold, 0.5)
self.assertEqual(loaded.config.minimal_topk, 2)
def test_from_config(self):
scheduler = self.get_scheduler(block_length=16, threshold=0.7)
new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
self.assertEqual(new_scheduler.config.block_length, 16)
self.assertEqual(new_scheduler.config.threshold, 0.7)
def test_step_commits_tokens(self):
"""Verify that step() commits mask tokens based on confidence."""
scheduler = self.get_scheduler(block_length=8)
scheduler.set_timesteps(2)
batch_size, block_length, vocab_size = 1, 8, 32
mask_id = 31
sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
# Create logits where confidence decreases with position
logits = torch.zeros(batch_size, block_length, vocab_size)
for i in range(block_length):
logits[0, i, i] = 10.0 - i # decreasing confidence
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
threshold=0.95,
return_dict=True,
)
# With 8 tokens and 2 steps, first step should commit 4 tokens
committed = out.transfer_index[0].sum().item()
self.assertEqual(committed, 4)
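A minimal sketch of the commit rule this test exercises (the name `commit_by_confidence` and the exact tie-breaking are illustrative assumptions, not the scheduler's real code): each step commits the step's budget of most-confident masked positions, plus any masked position whose confidence already clears the threshold.

```python
import torch

def commit_by_confidence(probs: torch.Tensor, is_mask: torch.Tensor,
                         num_transfer: int, threshold: float) -> torch.Tensor:
    # probs: (batch, block) confidence of the sampled token per position.
    # Only masked positions compete; committed positions are pushed to -inf.
    conf = probs.masked_fill(~is_mask, -float("inf"))
    transfer = torch.zeros_like(is_mask)
    # Commit this step's budget of most-confident masked positions...
    topk = torch.topk(conf, num_transfer, dim=-1).indices
    transfer.scatter_(-1, topk, True)
    # ...plus any masked position already above the confidence threshold.
    transfer |= (conf >= threshold) & is_mask
    return transfer
```

With 8 masked tokens and a per-step budget of 4, this commits exactly the 4 highest-confidence positions when nothing exceeds the threshold, matching the assertion above.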
def test_step_no_editing_by_default(self):
"""Without editing_threshold, no non-mask tokens should be changed."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 10.0 # predict token 15 for all positions
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=None,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index.any().item())
self.assertFalse(out.transfer_index[0, 0].item())
self.assertFalse(out.transfer_index[0, 1].item())
def test_step_editing_replaces_tokens(self):
"""With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
# Token 0: predict 15 (different from current token 10) with very high logit
logits[0, 0, 15] = 20.0
# Token 1: predict 20 (same as current)
logits[0, 1, 20] = 20.0
# Mask tokens
logits[0, 2, 5] = 5.0
logits[0, 3, 6] = 5.0
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
return_dict=True,
)
# Token 0 should be edited (different prediction, high confidence)
self.assertTrue(out.editing_transfer_index[0, 0].item())
# Token 1 should NOT be edited (same prediction)
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_prompt_mask_prevents_editing(self):
"""Prompt positions should never be edited even with editing enabled."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 20.0
prompt_mask = torch.tensor([True, True, False, False])
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
prompt_mask=prompt_mask,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index[0, 0].item())
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_return_tuple(self):
"""Verify tuple output when return_dict=False."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.full((1, 4), 31, dtype=torch.long)
logits = torch.randn(1, 4, vocab_size)
result = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
return_dict=False,
)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 5)
def test_step_batched(self):
"""Verify step works with batch_size > 1."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
batch_size, vocab_size = 3, 32
mask_id = 31
sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
logits = torch.randn(batch_size, 4, vocab_size)
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
return_dict=True,
)
self.assertEqual(out.prev_sample.shape, (batch_size, 4))
self.assertEqual(out.transfer_index.shape, (batch_size, 4))
def test_check_block_should_continue_finished(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([True, True])
result = scheduler.check_block_should_continue(
step_idx=0,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_no_masks_no_edits(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=5,
masks_remaining=False,
editing_enabled=True,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=1,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_steps_exhausted(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=8,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_eos_finished_marks_batch(self):
"""When EOS is committed and all tokens before it are unmasked, mark batch as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, token, eos, mask, mask]
cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long)
final_transfer = torch.tensor([[False, False, False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertTrue(finished[0].item())
def test_check_eos_finished_ignores_when_masks_before_eos(self):
"""If there are still mask tokens between prompt and EOS, don't mark as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, mask, eos] — mask before EOS
cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertFalse(finished[0].item())
def test_check_eos_finished_already_finished(self):
"""Already-finished batches should stay finished."""
mask_id, eos_id = 99, 2
cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, False]])
finished = torch.tensor([True])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=2,
)
self.assertTrue(finished[0].item())
def test_add_noise(self):
scheduler = self.get_scheduler(block_length=4)
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
mask_token_id = 99
gen = torch.Generator().manual_seed(42)
noisy, noisy_rev, masked, masked_rev = scheduler.add_noise(
input_ids,
attention_mask,
prompt_length=2,
block_length=4,
mask_token_id=mask_token_id,
generator=gen,
)
# Prompt positions should never be masked
self.assertFalse(masked[0, 0].item())
self.assertFalse(masked[0, 1].item())
self.assertFalse(masked_rev[0, 0].item())
self.assertFalse(masked_rev[0, 1].item())
# Noisy should have mask_token_id where masked is True
self.assertTrue((noisy[masked] == mask_token_id).all().item())
self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item())
# masked and masked_rev should be complementary within valid non-prompt positions
non_prompt = torch.zeros_like(masked)
non_prompt[0, 2:] = True
combined = masked | masked_rev
self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item())
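`test_add_noise` pins down three properties: prompt positions stay unmasked, masked positions carry `mask_token_id`, and the two views are complementary over the completion. A sketch that satisfies them (the complementary-pair scheme here is inferred from the assertions, not taken from the real implementation):

```python
import torch

def add_complementary_noise(input_ids, prompt_length, mask_token_id, generator=None):
    # Randomly assign each non-prompt token to one of two complementary
    # views, so every completion token is masked in exactly one of the pair.
    masked = torch.zeros_like(input_ids, dtype=torch.bool)
    masked[:, prompt_length:] = torch.rand(
        input_ids.shape[0], input_ids.shape[1] - prompt_length, generator=generator
    ) < 0.5
    masked_rev = ~masked
    masked_rev[:, :prompt_length] = False  # prompt is never masked in either view
    noisy = input_ids.masked_fill(masked, mask_token_id)
    noisy_rev = input_ids.masked_fill(masked_rev, mask_token_id)
    return noisy, noisy_rev, masked, masked_rev
```

Note `masked | masked_rev` covers every non-prompt position by construction, which is exactly the complementarity check at the end of the test.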
class TestTopPFiltering(unittest.TestCase):
def test_top_p_filtering(self):
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5)
self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any())
self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any())
def test_top_p_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None)
self.assertTrue(torch.equal(result, logits))
def test_top_p_filtering_one(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0)
self.assertTrue(torch.equal(result, logits))
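The behavior these three tests fix can be sketched with the standard nucleus-filtering recipe (an illustrative reimplementation, not the scheduler's actual `_top_p_filtering`): sort descending, keep the smallest prefix whose cumulative probability covers `top_p`, and mask the rest to the dtype minimum.

```python
import torch

def top_p_filtering(logits: torch.Tensor, top_p=None) -> torch.Tensor:
    # None or top_p >= 1.0 means no filtering, matching the tests above.
    if top_p is None or top_p >= 1.0:
        return logits
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    probs = torch.softmax(sorted_logits, dim=-1)
    cum = probs.cumsum(dim=-1)
    # Subtracting probs shifts the cutoff right: the token that first
    # crosses top_p is still kept, so at least one token always survives.
    remove_sorted = cum - probs > top_p
    mask = torch.zeros_like(remove_sorted).scatter(-1, sorted_idx, remove_sorted)
    return logits.masked_fill(mask, torch.finfo(logits.dtype).min)
```

For `[[1.0, 2.0, 3.0, 4.0]]` with `top_p=0.5`, the top token alone covers ~0.64 of the mass, so only it survives: some entries stay finite and some hit `finfo.min`, which is what `test_top_p_filtering` asserts.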
class TestTopKFiltering(unittest.TestCase):
def test_top_k_filtering(self):
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2)
self.assertAlmostEqual(filtered[0, 1].item(), 4.0)
self.assertAlmostEqual(filtered[0, 3].item(), 3.0)
self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min)
self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min)
def test_top_k_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_zero(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_large_k(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100)
self.assertTrue(torch.equal(result, logits))
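Top-k filtering as these tests specify it can be sketched in a few lines (again an illustrative reimplementation, not the real `_top_k_filtering`): keep the `top_k` largest logits per row and push everything else to the dtype minimum, with `None`, `0`, and oversized `k` all passing logits through unchanged.

```python
import torch

def top_k_filtering(logits: torch.Tensor, top_k=None) -> torch.Tensor:
    # No-op cases: disabled, zero, or k covering the whole vocabulary.
    if top_k is None or top_k <= 0 or top_k >= logits.shape[-1]:
        return logits
    # kth holds the smallest logit that survives; anything strictly
    # below it is masked out.
    kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < kth, torch.finfo(logits.dtype).min)
```

On `[[1.0, 4.0, 2.0, 3.0]]` with `top_k=2`, the survivors are 4.0 and 3.0 and the other two positions become `finfo.min`, mirroring the four assertions in `test_top_k_filtering`.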
class TestSampleFromLogits(unittest.TestCase):
def test_greedy_sampling(self):
logits = torch.tensor([[1.0, 5.0, 2.0]])
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 1)
self.assertEqual(tokens.shape, (1,))
self.assertEqual(probs.shape, (1,))
def test_multinomial_sampling(self):
logits = torch.tensor([[0.0, 100.0, -100.0]])
gen = torch.Generator().manual_seed(42)
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=1.0,
top_k=None,
top_p=None,
generator=gen,
use_multinomial=True,
)
self.assertEqual(tokens.item(), 1)
def test_temperature_scaling(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
tokens, _ = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.01,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 2)
def test_negative_temperature_raises(self):
logits = torch.tensor([[1.0, 2.0]])
with self.assertRaises(ValueError):
BlockRefinementScheduler._sample_from_logits(
logits,
temperature=-1.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
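The sampling contract these four tests describe can be sketched as follows (a simplified stand-in for `_sample_from_logits`, without the top-k/top-p hooks; the signature is an assumption): temperature 0 means greedy argmax, positive temperature scales the logits before a multinomial draw, and negative temperature raises.

```python
import torch

def sample_from_logits(logits: torch.Tensor, temperature: float = 0.0,
                       generator=None):
    if temperature < 0:
        raise ValueError("temperature must be non-negative")
    if temperature == 0.0:
        # Greedy decoding: pick the argmax token per row.
        probs = torch.softmax(logits, dim=-1)
        tokens = logits.argmax(dim=-1)
    else:
        # Scale logits by temperature, then draw from the categorical.
        probs = torch.softmax(logits / temperature, dim=-1)
        tokens = torch.multinomial(probs, 1, generator=generator).squeeze(-1)
    # Return the sampled tokens and their probabilities, shape (batch,).
    confidences = probs.gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    return tokens, confidences
```

Returning the per-token probability alongside the token is what lets the scheduler rank positions by confidence in the commit step.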
if __name__ == "__main__":
unittest.main()


@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
def fetch_pipeline_objects():
models = api.list_models(library="diffusers")
models = api.list_models(filter="diffusers")
downloads = defaultdict(int)
for model in models: