Compare commits

...

13 Commits

Author SHA1 Message Date
sayakpaul
af96109435 add a profiling worflow. 2026-03-26 17:01:41 +05:30
Sayak Paul
b757035df6 fix claude workflow to include id-token with write. (#13338) 2026-03-26 15:39:10 +05:30
kaixuanliu
41e1003316 avoid hardcode device in flux-control example (#13336)
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
2026-03-26 12:40:53 +05:30
Sayak Paul
85ffcf1db2 [tests] Tests for conditional pipeline blocks (#13247)
* implement test suite for conditional blocks.

* remove

* another fix.

* Revert "another fix."

This reverts commit ab07b603ab.
2026-03-26 08:48:16 +05:30
Steven Liu
cbf4d9a3c3 [docs] kernels (#13139)
* kernels

* feedback
2026-03-25 09:31:54 -07:00
Sayak Paul
426daabad9 [ci] claude in ci. (#13297)
* claude in ci.

* review feedback.
2026-03-25 21:30:06 +05:30
Kashif Rasul
762ae059fa [LLADA2] documentation fixes (#13333)
documentation fixes
2026-03-25 17:49:31 +05:30
Kashif Rasul
5d207e756e [Discrete Diffusion] Add LLaDA2 pipeline (#13226)
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
- BlockRefinementPipeline: block-wise iterative refinement with confidence-based
  token commitment, supporting editing threshold for LLaDA2.1 models
- LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
- DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p,
  temperature) and prompt/prefix helpers
- compute_confidence_aware_loss: CAP-style training loss
- Examples: sampling scripts for LLaDA2 and block refinement, training scripts
  with Qwen causal LM
- Docs and tests included

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

Extract the confidence-based token commit logic from BlockRefinementPipeline
into a dedicated BlockRefinementScheduler, following diffusers conventions.

The scheduler owns:
- Transfer schedule computation (get_num_transfer_tokens)
- Timestep management (set_timesteps)
- Step logic: confidence-based mask-filling and optional token editing

The pipeline now delegates scheduling to self.scheduler.step() and accepts
a scheduler parameter in __init__.

* test: add unit tests for BlockRefinementScheduler

12 tests covering set_timesteps, get_num_transfer_tokens, step logic
(confidence-based commits, threshold behavior, editing, prompt masking,
batched inputs, tuple output).

* docs: add toctree entries and standalone scheduler doc page

- Add BlockRefinement and LLaDA2 to docs sidebar navigation
- Add BlockRefinementScheduler to schedulers sidebar navigation
- Move scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py

- Add --revision argument for loading model revisions from the Hub
- Replace deprecated torch_dtype with dtype for transformers 5.x compat

* fix: use 1/0 attention mask instead of 0/-inf for LLaDA2 compat

LLaDA2 models expect a boolean-style (1/0) attention mask, not an
additive (0/-inf) mask. The model internally converts to additive,
so passing 0/-inf caused double-masking and gibberish output.

* refactor: consolidate training scripts into single train_block_refinement.py

- Remove toy train_block_refinement_cap.py (self-contained demo with tiny model)
- Rename train_block_refinement_qwen_cap.py to train_block_refinement.py
  (already works with any causal LM via AutoModelForCausalLM)
- Fix torch_dtype deprecation and update README with correct script names

* fix formatting

* docs: improve LLaDA2 and BlockRefinement documentation

- Add usage examples with real model IDs and working code
- Add recommended parameters table for LLaDA2.1 quality/speed modes
- Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
- Remove misleading config defaults section from BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to recommended model parameters

- threshold: 0.95 -> 0.7 (quality mode)
- max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
- eos_early_stop: False -> True (stop at EOS token)

block_length=32, steps=32, temperature=0.0 were already correct.
editing_threshold remains None (users enable for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

LLaDA2.1 is the current generation. Users with LLaDA2.0 models can
disable editing by passing editing_threshold=None.

* fix: align sampling utilities with official LLaDA2 implementation

- top_p filtering: add shift-right to preserve at least one token above
  threshold (matches official code line 1210)
- temperature ordering: apply scaling before top-k/top-p filtering so
  filtering operates on scaled logits (matches official code lines 1232-1235)
- greedy branch: return argmax directly when temperature=0 without
  filtering (matches official code lines 1226-1230)

* refactor: remove duplicate prompt encoding, reuse mixin's _prepare_input_ids

LLaDA2Pipeline._prepare_prompt_ids was a near-copy of
DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate
and call the mixin method directly. Also simplify _extract_input_ids
since we always pass return_dict=True.

* formatting

* fix: replace deprecated torch_dtype with dtype in examples and docstrings

- Update EXAMPLE_DOC_STRING to use dtype= and LLaDA2.1-mini model ID
- Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline

* cleanup

* fix readme

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* removed DiscreteDiffusionPipelineMixin

* add support for 2d masks for flash attn

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix issues from review

* added tests

* formatting

* add check_eos_finished to scheduler

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix renaming issues and types

* remove duplicate check

* Update docs/source/en/api/pipelines/llada2.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-25 16:17:50 +05:30
Sayak Paul
e358ddcce6 fix to device and to dtype tests. (#13323) 2026-03-25 11:47:02 +05:30
Sayak Paul
153fcbc5a8 fix klein lora loading. (#13313) 2026-03-25 07:51:35 +05:30
Beinsezii
da6718f080 ZImageTransformer2D: Only build attention mask if seqlens are not equal (#12955) 2026-03-24 06:06:50 -10:00
Alexey Kirillov
832676d35e Use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING (#13320)
refactor: use defaultdict for _SET_ADAPTER_SCALE_FN_MAPPING

Co-authored-by: Alexkkir <alexkkir@gmail.coom>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
2026-03-24 17:49:50 +05:30
Dhruv Nair
7bbd96da5d [CI] Update fetching pipelines for latest HF Hub Version (#13322)
update
2026-03-24 16:42:32 +05:30
38 changed files with 3855 additions and 159 deletions

11
.ai/review-rules.md Normal file
View File

@@ -0,0 +1,11 @@
# PR Review Rules
Review-specific rules for Claude. Focus on correctness — style is handled by ruff.
Before reviewing, read and apply the guidelines in:
- [AGENTS.md](AGENTS.md) — coding style, dependencies, copied code, model conventions
- [skills/model-integration/SKILL.md](skills/model-integration/SKILL.md) — attention pattern, pipeline rules, implementation checklist, gotchas
- [skills/parity-testing/SKILL.md](skills/parity-testing/SKILL.md) — testing rules, comparison utilities
- [skills/parity-testing/pitfalls.md](skills/parity-testing/pitfalls.md) — known pitfalls (dtype mismatches, config assumptions, etc.)
## Common mistakes (add new rules below this line)

39
.github/workflows/claude_review.yml vendored Normal file
View File

@@ -0,0 +1,39 @@
name: Claude PR Review
on:
issue_comment:
types: [created]
pull_request_review_comment:
types: [created]
permissions:
contents: write
pull-requests: write
issues: read
id-token: write
jobs:
claude-review:
if: |
(
github.event_name == 'issue_comment' &&
github.event.issue.pull_request &&
github.event.issue.state == 'open' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'MEMBER' ||
github.event.comment.author_association == 'OWNER' ||
github.event.comment.author_association == 'COLLABORATOR')
) || (
github.event_name == 'pull_request_review_comment' &&
contains(github.event.comment.body, '@claude') &&
(github.event.comment.author_association == 'MEMBER' ||
github.event.comment.author_association == 'OWNER' ||
github.event.comment.author_association == 'COLLABORATOR')
)
runs-on: ubuntu-latest
steps:
- uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
claude_args: |
--append-system-prompt "Review this PR against the rules in .ai/review-rules.md. Focus on correctness, not style (ruff handles style). Only review changes under src/diffusers/. Do NOT commit changes unless the comment explicitly asks you to using the phrase 'commit this'."

View File

@@ -670,6 +670,10 @@
- local: api/pipelines/z_image
title: Z-Image
title: Image
- sections:
- local: api/pipelines/llada2
title: LLaDA2
title: Text
- sections:
- local: api/pipelines/allegro
title: Allegro
@@ -718,6 +722,8 @@
- sections:
- local: api/schedulers/overview
title: Overview
- local: api/schedulers/block_refinement
title: BlockRefinementScheduler
- local: api/schedulers/cm_stochastic_iterative
title: CMStochasticIterativeScheduler
- local: api/schedulers/ddim_cogvideox

View File

@@ -0,0 +1,90 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
steps.
## Usage
```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
model_id = "inclusionAI/LLaDA2.1-mini"
model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
output = pipe(
prompt="Write a short poem about the ocean.",
gen_length=256,
block_length=32,
num_inference_steps=32,
threshold=0.7,
editing_threshold=0.5,
max_post_steps=16,
temperature=0.0,
)
print(output.texts[0])
```
## Callbacks
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
window.
```py
def on_step_end(pipe, step, timestep, callback_kwargs):
block_x = callback_kwargs["block_x"]
# Inspect or modify `block_x` here.
return {"block_x": block_x}
out = pipe(
prompt="Write a short poem.",
callback_on_step_end=on_step_end,
callback_on_step_end_tensor_inputs=["block_x"],
)
```
## Recommended parameters
LLaDA2.1 models support two modes:
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
|------|-------------|---------------------|------------------|
| Quality | 0.7 | 0.5 | 16 |
| Speed | 0.5 | `None` | 16 |
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
## LLaDA2Pipeline
[[autodoc]] LLaDA2Pipeline
- all
- __call__
## LLaDA2PipelineOutput
[[autodoc]] pipelines.LLaDA2PipelineOutput

View File

@@ -63,6 +63,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
| [Latte](latte) | text2image |
| [LEDITS++](ledits_pp) | image editing |
| [LLaDA2](llada2) | text2text |
| [Lumina-T2X](lumina) | text2image |
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
| [MultiDiffusion](panorama) | text2image |

View File

@@ -0,0 +1,25 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# BlockRefinementScheduler
The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
token with high confidence.
This scheduler is used by [`LLaDA2Pipeline`].
## BlockRefinementScheduler
[[autodoc]] BlockRefinementScheduler
## BlockRefinementSchedulerOutput
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput

View File

@@ -248,6 +248,24 @@ Refer to the [diffusers/benchmarks](https://huggingface.co/datasets/diffusers/be
The [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao#benchmarking-results) repository also contains benchmarking results for compiled versions of Flux and CogVideoX.
## Kernels
[Kernels](https://huggingface.co/docs/kernels/index) is a library for building, distributing, and loading optimized compute kernels on the [Hub](https://huggingface.co/kernels-community). It supports [attention](./attention_backends#set_attention_backend) kernels and custom CUDA kernels for operations like RMSNorm, GEGLU, RoPE, and AdaLN.
The [Diffusers Pipeline Integration](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/references/diffusers-integration.md) guide shows how to integrate a kernel with the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill. This skill enables an agent, like Claude or Codex, to write custom kernels targeted towards a specific model and your hardware.
> [!TIP]
> Install the [add cuda-kernels](https://github.com/huggingface/kernels/blob/main/skills/cuda-kernels/SKILL.md) skill to teach an agent how to write a kernel. The [Custom kernels for all from Codex and Claude](https://huggingface.co/blog/custom-cuda-kernels-agent-skills) blog post covers this in more detail.
For example, a custom RMSNorm kernel (generated by the `add cuda-kernels` skill) with [torch.compile](#torchcompile) speeds up LTX-Video generation 1.43x on an H100.
<iframe
src="https://huggingface.co/datasets/docs-benchmarks/kernel-ltx-video/embed/viewer/default/train"
frameborder="0"
width="100%"
height="560px"
></iframe>
## Dynamic quantization
[Dynamic quantization](https://pytorch.org/tutorials/recipes/recipes/dynamic_quantization.html) improves inference speed by reducing precision to enable faster math operations. This particular type of quantization determines how to scale the activations based on the data at runtime rather than using a fixed scaling factor. As a result, the scaling factor is more accurately aligned with the data.

View File

@@ -0,0 +1,50 @@
# Discrete Token Diffusion (Experimental)
This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.
## LLaDA2
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.
### Train
The training script uses confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--dataset_name wikitext \
--dataset_config_name wikitext-2-raw-v1 \
--text_column text \
--output_dir llada2-output \
--max_train_steps 1000 \
--prompt_length 32 \
--block_length 32 \
--lambda_conf 2.0 \
--conf_temperature 0.5
```
If you don't want to download a dataset, you can use random-token data:
```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
--model_name_or_path Qwen/Qwen2.5-0.5B \
--output_dir llada2-output \
--use_dummy_data \
--num_dummy_samples 2048
```
### Sample
```bash
python examples/discrete_diffusion/sample_llada2.py \
--model_id inclusionAI/LLaDA2.1-mini \
--prompt "Write a short poem about the ocean." \
--gen_length 256 \
--num_inference_steps 32 \
--threshold 0.7 \
--editing_threshold 0.5 \
--max_post_steps 16 \
--use_chat_template \
--add_generation_prompt
```

View File

@@ -0,0 +1,263 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample script for LLaDA2-style discrete diffusion text generation.
This script demonstrates how to use the LLaDA2Pipeline for text generation
using block-wise iterative refinement.
Example usage:
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
"""
import argparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
from diffusers.hooks import apply_group_offloading
def main():
parser = argparse.ArgumentParser(
description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
)
parser.add_argument(
"--model_id",
type=str,
default="inclusionAI/LLaDA2.0-mini",
help="HuggingFace model ID or path to local model.",
)
parser.add_argument(
"--prompt",
type=str,
default="Why does Camus think that Sisyphus is happy?",
help="Text prompt to generate from.",
)
parser.add_argument(
"--gen_length",
type=int,
default=2048,
help="Number of tokens to generate.",
)
parser.add_argument(
"--block_length",
type=int,
default=32,
help="Size of each generation block.",
)
parser.add_argument(
"--num_inference_steps",
type=int,
default=32,
help="Number of refinement steps per block.",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="Sampling temperature (0.0 for greedy).",
)
parser.add_argument(
"--top_p",
type=float,
default=None,
help="Nucleus sampling probability threshold.",
)
parser.add_argument(
"--top_k",
type=int,
default=None,
help="Top-k sampling parameter.",
)
parser.add_argument(
"--threshold",
type=float,
default=0.95,
help="Confidence threshold for committing tokens.",
)
parser.add_argument(
"--editing_threshold",
type=float,
default=None,
help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
)
parser.add_argument(
"--max_post_steps",
type=int,
default=0,
help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
)
parser.add_argument(
"--sampling_method",
type=str,
default="multinomial",
choices=["auto", "greedy", "multinomial"],
help="Sampling method for block refinement.",
)
parser.add_argument(
"--eos_early_stop",
action="store_true",
help="Stop generation early when EOS token is generated.",
)
parser.add_argument(
"--use_chat_template",
action="store_true",
help="Use the tokenizer chat template for the prompt.",
)
parser.add_argument(
"--add_generation_prompt",
action="store_true",
help="Add the generation prompt when using the chat template.",
)
parser.add_argument(
"--device",
type=str,
default="cuda" if torch.cuda.is_available() else "cpu",
help="Device to run inference on.",
)
parser.add_argument(
"--dtype",
type=str,
default="bfloat16",
choices=["float32", "float16", "bfloat16"],
help="Model dtype.",
)
parser.add_argument(
"--seed",
type=int,
default=None,
help="Random seed for reproducibility.",
)
parser.add_argument(
"--offload",
type=str,
default=None,
choices=["group", "sequential"],
help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
)
parser.add_argument(
"--revision",
type=str,
default=None,
help="Model revision (branch, tag, or commit hash) to load from the Hub.",
)
args = parser.parse_args()
# Parse dtype
dtype_map = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}
torch_dtype = dtype_map[args.dtype]
print(f"Loading model: {args.model_id}")
tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)
# Load model with appropriate memory settings based on offload strategy
if args.offload == "group":
# For group offloading, load to CPU first then apply hooks
print("Using group offloading for memory efficiency...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Apply group offloading with CUDA streams for better performance
onload_device = torch.device(args.device)
offload_device = torch.device("cpu")
apply_group_offloading(
model,
onload_device=onload_device,
offload_device=offload_device,
offload_type="leaf_level",
use_stream=True,
)
elif args.offload == "sequential":
# For sequential offloading, load to CPU first
print("Using sequential CPU offloading (slower but lower memory)...")
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
low_cpu_mem_usage=True,
revision=args.revision,
)
# Sequential offloading will be applied via pipeline
else:
# Default: use device_map="auto" for automatic memory management
model = AutoModelForCausalLM.from_pretrained(
args.model_id,
trust_remote_code=True,
dtype=torch_dtype,
device_map="auto",
low_cpu_mem_usage=True,
revision=args.revision,
)
model.eval()
# Create pipeline
scheduler = BlockRefinementScheduler()
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
# Apply sequential CPU offload if requested
if args.offload == "sequential":
pipe.enable_sequential_cpu_offload()
# Set up generator for reproducibility
generator = None
if args.seed is not None:
generator = torch.Generator(device=args.device).manual_seed(args.seed)
print(f"\nPrompt: {args.prompt}")
print(
f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
)
print("-" * 50)
# Generate
output = pipe(
prompt=args.prompt,
use_chat_template=args.use_chat_template,
add_generation_prompt=args.add_generation_prompt,
gen_length=args.gen_length,
block_length=args.block_length,
num_inference_steps=args.num_inference_steps,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
threshold=args.threshold,
editing_threshold=args.editing_threshold,
max_post_steps=args.max_post_steps,
sampling_method=args.sampling_method,
eos_early_stop=args.eos_early_stop,
generator=generator,
)
print("\nGenerated text:")
print(output.texts[0])
print(f"\nGenerated {output.sequences.shape[1]} tokens")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,321 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import math
import os
from dataclasses import asdict, dataclass
from typing import Dict, Optional
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler
from diffusers import BlockRefinementScheduler
from diffusers.training_utils import compute_confidence_aware_loss
logger = get_logger(__name__)
@dataclass
class TrainConfig:
model_name_or_path: str
dataset_name: str
dataset_config_name: Optional[str]
text_column: str
cache_dir: Optional[str]
use_dummy_data: bool
num_dummy_samples: int
output_dir: str
seed: int
max_train_steps: int
checkpointing_steps: int
logging_steps: int
per_device_train_batch_size: int
gradient_accumulation_steps: int
learning_rate: float
weight_decay: float
lr_scheduler: str
lr_warmup_steps: int
max_length: int
prompt_length: int
block_length: int
lambda_conf: float
conf_temperature: float
def parse_args() -> TrainConfig:
parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")
parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
parser.add_argument("--dataset_name", type=str, default="wikitext")
parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
parser.add_argument("--text_column", type=str, default="text")
parser.add_argument("--cache_dir", type=str, default=None)
parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
parser.add_argument("--num_dummy_samples", type=int, default=2048)
parser.add_argument("--output_dir", type=str, default="block-refinement-output")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--max_train_steps", type=int, default=1000)
parser.add_argument("--checkpointing_steps", type=int, default=500)
parser.add_argument("--logging_steps", type=int, default=50)
parser.add_argument("--per_device_train_batch_size", type=int, default=1)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=2e-5)
parser.add_argument("--weight_decay", type=float, default=0.0)
parser.add_argument(
"--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
)
parser.add_argument("--lr_warmup_steps", type=int, default=100)
parser.add_argument("--max_length", type=int, default=256)
parser.add_argument("--prompt_length", type=int, default=32)
parser.add_argument("--block_length", type=int, default=32)
parser.add_argument("--lambda_conf", type=float, default=2.0)
parser.add_argument("--conf_temperature", type=float, default=0.5)
args = parser.parse_args()
return TrainConfig(**vars(args))
def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
texts = examples[text_column]
texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
return tokenizer(texts, truncation=True, padding=False, max_length=max_length)
class RandomTokenDataset(torch.utils.data.Dataset):
def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
self.num_samples = int(num_samples)
self.seq_len = int(seq_len)
self.vocab_size = int(vocab_size)
self.pad_token_id = int(pad_token_id)
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
del idx
input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
return {"input_ids": input_ids, "attention_mask": attention_mask}
def main():
cfg = parse_args()
if cfg.prompt_length >= cfg.max_length:
raise ValueError("`prompt_length` must be < `max_length`.")
if cfg.block_length <= 0:
raise ValueError("`block_length` must be > 0.")
project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
accelerator = Accelerator(
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
project_config=project_config,
)
if accelerator.is_main_process:
os.makedirs(cfg.output_dir, exist_ok=True)
accelerator.wait_for_everyone()
set_seed(cfg.seed)
logger.info("Training configuration: %s", asdict(cfg))
tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
if tokenizer.pad_token_id is None:
tokenizer.pad_token = tokenizer.eos_token
if tokenizer.mask_token_id is None:
tokenizer.add_special_tokens({"mask_token": "[MASK]"})
load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
model.resize_token_embeddings(len(tokenizer))
if load_dtype == torch.float32:
model.to(dtype=torch.float32)
mask_token_id = int(tokenizer.mask_token_id)
if cfg.use_dummy_data:
dataset = RandomTokenDataset(
num_samples=cfg.num_dummy_samples,
seq_len=cfg.max_length,
vocab_size=len(tokenizer),
pad_token_id=int(tokenizer.pad_token_id),
)
train_dataloader = DataLoader(
dataset,
shuffle=True,
batch_size=cfg.per_device_train_batch_size,
drop_last=True,
)
else:
raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
if "train" not in raw_datasets:
raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")
with accelerator.main_process_first():
tokenized = raw_datasets["train"].map(
lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
batched=True,
remove_columns=raw_datasets["train"].column_names,
desc="Tokenizing",
)
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
train_dataloader = DataLoader(
tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)
lr_scheduler = get_scheduler(
name=cfg.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=cfg.lr_warmup_steps,
num_training_steps=cfg.max_train_steps,
)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
model, optimizer, train_dataloader, lr_scheduler
)
noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)
global_step = 0
model.train()
for _epoch in range(num_train_epochs):
for batch in train_dataloader:
with accelerator.accumulate(model):
input_ids = batch["input_ids"]
attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))
gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
input_ids,
attention_mask,
prompt_length=cfg.prompt_length,
block_length=cfg.block_length,
mask_token_id=mask_token_id,
generator=gen,
)
position_ids = (
torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
)
logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
logits_rev = model(
input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
).logits
logits = logits.clone()
logits[..., mask_token_id] = torch.finfo(logits.dtype).min
logits_rev = logits_rev.clone()
logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min
valid = attention_mask.to(dtype=torch.bool)
masked = masked & valid
masked_rev = masked_rev & valid
labels = input_ids.clone()
labels[~masked] = -100
labels_rev = input_ids.clone()
labels_rev[~masked_rev] = -100
weights = masked.to(dtype=logits.dtype)
weights_rev = masked_rev.to(dtype=logits.dtype)
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits,
labels,
lambda_conf=cfg.lambda_conf,
temperature=cfg.conf_temperature,
per_token_weights=weights,
)
loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
logits_rev,
labels_rev,
lambda_conf=cfg.lambda_conf,
temperature=cfg.conf_temperature,
per_token_weights=weights_rev,
)
total_loss = loss + loss_rev
accelerator.backward(total_loss)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=True)
if accelerator.sync_gradients:
global_step += 1
if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
logger.info(
"step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
global_step,
total_loss.item(),
(loss_sft + loss_sft_rev).item(),
(loss_conf + loss_conf_rev).item(),
lr_scheduler.get_last_lr()[0],
)
print(
f"step={global_step} loss={total_loss.item():.4f} "
f"sft={(loss_sft + loss_sft_rev).item():.4f} "
f"conf={(loss_conf + loss_conf_rev).item():.4f} "
f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
)
if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
accelerator.wait_for_everyone()
if accelerator.is_main_process:
save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
os.makedirs(save_dir, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
tokenizer.save_pretrained(save_dir)
if global_step >= cfg.max_train_steps:
break
if global_step >= cfg.max_train_steps:
break
accelerator.wait_for_everyone()
if accelerator.is_main_process:
final_dir = os.path.join(cfg.output_dir, "final")
os.makedirs(final_dir, exist_ok=True)
accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
tokenizer.save_pretrained(final_dir)
logger.info("Done.")
if __name__ == "__main__":
main()

View File

@@ -1105,7 +1105,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None

View File

@@ -1251,7 +1251,7 @@ def main(args):
# text encoding.
captions = batch["captions"]
text_encoding_pipeline = text_encoding_pipeline.to("cuda")
text_encoding_pipeline = text_encoding_pipeline.to(accelerator.device)
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds, text_ids = text_encoding_pipeline.encode_prompt(
captions, prompt_2=None

168
profiling/PROFILING_PLAN.md Normal file
View File

@@ -0,0 +1,168 @@
# Profiling Plan: Diffusers Pipeline Profiling with torch.profiler
## Context
We want to uncover CPU overhead, CPU-GPU sync points, and other bottlenecks in popular diffusers pipelines — especially issues that become non-trivial under `torch.compile`. The approach is inspired by [flux-fast's run_benchmark.py](https://github.com/huggingface/flux-fast/blob/0a1dcc91658f0df14cd7fce862a5c8842784c6da/run_benchmark.py#L66-L85) which uses `torch.profiler` with method-level annotations, and motivated by issues like [diffusers#11696](https://github.com/huggingface/diffusers/pull/11696) (DtoH sync from scheduler `.item()` call).
## Target Pipelines
| Pipeline | Type | Checkpoint | Steps |
|----------|------|-----------|-------|
| `FluxPipeline` | text-to-image | `black-forest-labs/FLUX.1-dev` | 4 |
| `Flux2Pipeline` | text-to-image | `black-forest-labs/FLUX.2-dev` | 4 |
| `WanPipeline` | text-to-video | `Wan-AI/Wan2.1-T2V-14B-Diffusers` | 4 |
| `LTX2Pipeline` | text-to-video | `Lightricks/LTX-2` | 4 |
| `QwenImagePipeline` | text-to-image | `Qwen/Qwen-Image` | 4 |
## Approach
Follow the flux-fast pattern: **annotate key pipeline methods** with `torch.profiler.record_function` wrappers, then run the pipeline under `torch.profiler.profile` and export a Chrome trace.
### New Files
```
profiling/
profiling_utils.py # Annotation helper + profiler setup
profiling_pipelines.py # CLI entry point with pipeline configs
```
### Step 1: `profiling_utils.py` — Annotation and Profiler Infrastructure
**A) `annotate(func, name)` helper** (same pattern as flux-fast):
```python
def annotate(func, name):
"""Wrap a function with torch.profiler.record_function for trace annotation."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with torch.profiler.record_function(name):
return func(*args, **kwargs)
return wrapper
```
**B) `annotate_pipeline(pipe)` function** — applies annotations to key methods on any pipeline:
- `pipe.transformer.forward``"transformer_forward"`
- `pipe.vae.decode``"vae_decode"` (if present)
- `pipe.vae.encode``"vae_encode"` (if present)
- `pipe.scheduler.step``"scheduler_step"`
- `pipe.encode_prompt``"encode_prompt"` (if present, for full-pipeline profiling)
This is non-invasive — it monkey-patches bound methods without modifying source.
**C) `PipelineProfiler` class:**
- `__init__(pipeline_config, output_dir, mode="eager"|"compile")`
- `setup_pipeline()` → loads from pretrained, optionally compiles transformer, calls `annotate_pipeline()`
- `run()`:
1. Warm up with 1 unannotated run
2. Profile 1 run with `torch.profiler.profile`:
- `activities=[CPU, CUDA]`
- `record_shapes=True`
- `profile_memory=True`
- `with_stack=True`
3. Export Chrome trace JSON
4. Print `key_averages()` summary table (sorted by CUDA time) to stdout
### Step 2: `profiling_pipelines.py` — CLI with Pipeline Configs
**Pipeline config registry** — each entry specifies:
- `pipeline_cls`, `pretrained_model_name_or_path`, `torch_dtype`
- `call_kwargs` with pipeline-specific defaults:
| Pipeline | Resolution | Frames | Steps | Extra |
|----------|-----------|--------|-------|-------|
| Flux | 1024x1024 | — | 4 | `guidance_scale=3.5` |
| Flux2 | 1024x1024 | — | 4 | `guidance_scale=3.5` |
| Wan | 480x832 | 81 | 4 | — |
| LTX2 | 768x512 | 121 | 4 | `guidance_scale=4.0` |
| QwenImage | 1024x1024 | — | 4 | `true_cfg_scale=4.0` |
All configs use `output_type="latent"` by default (skip VAE decode for cleaner denoising-loop traces).
**CLI flags:**
- `--pipeline flux|flux2|wan|ltx2|qwenimage|all`
- `--mode eager|compile|both`
- `--output_dir profiling_results/`
- `--num_steps N` (override, default 4)
- `--full_decode` (switch output_type from `"latent"` to `"pil"` to include VAE)
- `--compile_mode default|reduce-overhead|max-autotune`
- `--compile_fullgraph` flag
**Output:** `{output_dir}/{pipeline}_{mode}.json` Chrome trace + stdout summary.
### Step 3: Known Sync Issues to Validate
The profiling should surface these known/suspected issues:
1. **Scheduler DtoH sync via `nonzero().item()`** — For Flux, this was fixed by adding `scheduler.set_begin_index(0)` before the denoising loop ([diffusers#11696](https://github.com/huggingface/diffusers/pull/11696)). Profiling should reveal whether similar sync points exist in other pipelines.
2. **`modulate_index` tensor rebuilt every forward in `transformer_qwenimage.py`** (line 901-905) — Python list comprehension + `torch.tensor()` each step. Minor but visible in trace.
3. **Any other `.item()`, `.cpu()`, `.numpy()` calls** in the denoising loop hot path — the profiler's `with_stack=True` will surface these as CPU stalls with Python stack traces.
## Verification
1. Run: `python profiling/profiling_pipelines.py --pipeline flux --mode eager --num_steps 4`
2. Verify `profiling_results/flux_eager.json` is produced
3. Open trace in [Perfetto UI](https://ui.perfetto.dev/) — confirm:
- `transformer_forward` and `scheduler_step` annotations visible
- CPU and CUDA timelines present
- Stack traces visible on CPU events
4. Run with `--mode compile` and compare trace for fewer/fused CUDA kernels
## Interpreting Traces in Perfetto UI
Open the exported `.json` trace at [ui.perfetto.dev](https://ui.perfetto.dev/). The trace has two main rows: **CPU** (top) and **CUDA** (bottom).
### What to look for
**1. Gaps between CUDA kernels**
Zoom into the CUDA row during the denoising loop. Ideally, GPU kernels should be back-to-back with no gaps. Gaps mean the GPU is idle waiting for the CPU to launch the next kernel. Common causes:
- Python overhead between ops (visible as CPU slices in the CPU row during the gap)
- DtoH sync (`.item()`, `.cpu()`) forcing the GPU to drain before the CPU can proceed
**2. CPU stalls (DtoH syncs)**
Look for long CPU slices labeled `cudaStreamSynchronize` or `cudaDeviceSynchronize`. Click on them — if `with_stack=True` was enabled, the bottom panel shows the Python stack trace pointing to the exact line causing the sync (e.g., a `.item()` call in the scheduler).
**3. Annotated regions**
Our `record_function` annotations (`transformer_forward`, `scheduler_step`, etc.) appear as labeled spans on the CPU row. This lets you quickly:
- Measure how long each phase takes (click a span to see duration)
- See if `scheduler_step` is disproportionately expensive relative to `transformer_forward` (it should be negligible)
- Spot unexpected CPU work between annotated regions
**4. Eager vs compile comparison**
Open both traces side by side (two Perfetto tabs). Key differences to look for:
- **Fewer, wider CUDA kernels** in compile mode (fused ops) vs many small kernels in eager
- **Smaller CPU gaps** between kernels in compile mode (less Python dispatch overhead)
- **Graph breaks**: if compile mode still shows many small kernels in a section, that section likely has a graph break — check `TORCH_LOGS="+dynamo"` output for details
**5. Memory timeline**
In Perfetto, look for the memory counter track (if `profile_memory=True`). Spikes during the denoising loop suggest unexpected allocations per step. Steady-state memory during denoising is expected — growing memory is not.
**6. Kernel launch latency**
Each CUDA kernel is launched from the CPU. In Perfetto, you can see the CPU-side launch call (e.g., `cudaLaunchKernel`) and the corresponding GPU-side kernel execution. The time between the CPU dispatch and the GPU kernel starting should be minimal (single-digit microseconds). If you see consistent delays > 10-20us between launch and execution:
- The launch queue may be starved because of excessive Python work between ops
- There may be implicit syncs forcing serialization
- `torch.compile` should help here by batching launches — compare eager vs compile to confirm
To inspect this: zoom into a single denoising step, select a CUDA kernel on the GPU row, and look at the corresponding CPU-side launch slice directly above it. The horizontal offset between them is the launch latency. In a healthy trace, CPU launch slices should be well ahead of GPU execution (the CPU is "feeding" the GPU faster than it can consume).
### Quick checklist per pipeline
| Question | Where to look | Healthy | Unhealthy |
|----------|--------------|---------|-----------|
| GPU staying busy? | CUDA row gaps | Back-to-back kernels | Frequent gaps > 100us |
| CPU blocking on GPU? | `cudaStreamSynchronize` slices | Rare/absent during denoise | Present every step |
| Scheduler overhead? | `scheduler_step` span duration | < 1% of step time | > 5% of step time |
| Compile effective? | CUDA kernel count per step | Fewer large kernels | Same as eager |
| Kernel launch latency? | CPU launch → GPU kernel offset | < 10us, CPU ahead of GPU | > 20us or CPU trailing GPU |
| Memory stable? | Memory counter track | Flat during denoise loop | Growing per step |

View File

@@ -0,0 +1,182 @@
"""
Profile diffusers pipelines with torch.profiler.
Usage:
python profiling/profiling_pipelines.py --pipeline flux --mode eager
python profiling/profiling_pipelines.py --pipeline flux --mode compile
python profiling/profiling_pipelines.py --pipeline flux --mode both
python profiling/profiling_pipelines.py --pipeline all --mode eager
python profiling/profiling_pipelines.py --pipeline wan --mode eager --full_decode
python profiling/profiling_pipelines.py --pipeline flux --mode compile --num_steps 4
"""
import argparse
import copy
import logging
import torch
from profiling_utils import PipelineProfiler, PipelineProfilingConfig
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
PROMPT = "A cat holding a sign that says hello world"
def build_registry():
"""Build the pipeline config registry. Imports are deferred to avoid loading all pipelines upfront."""
from diffusers import FluxPipeline, Flux2Pipeline, WanPipeline, LTX2Pipeline, QwenImagePipeline
return {
"flux": PipelineProfilingConfig(
name="flux",
pipeline_cls=FluxPipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"height": 1024,
"width": 1024,
"num_inference_steps": 4,
"guidance_scale": 3.5,
"output_type": "latent",
},
),
"flux2": PipelineProfilingConfig(
name="flux2",
pipeline_cls=Flux2Pipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "black-forest-labs/FLUX.2-klein-base-9B",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"height": 1024,
"width": 1024,
"num_inference_steps": 4,
"guidance_scale": 3.5,
"output_type": "latent",
},
),
"wan": PipelineProfilingConfig(
name="wan",
pipeline_cls=WanPipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"negative_prompt": "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards",
"height": 480,
"width": 832,
"num_frames": 81,
"num_inference_steps": 4,
"output_type": "latent",
},
),
"ltx2": PipelineProfilingConfig(
name="ltx2",
pipeline_cls=LTX2Pipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "Lightricks/LTX-2",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"negative_prompt": "worst quality, inconsistent motion, blurry, jittery, distorted",
"height": 512,
"width": 768,
"num_frames": 121,
"num_inference_steps": 4,
"guidance_scale": 4.0,
"output_type": "latent",
},
),
"qwenimage": PipelineProfilingConfig(
name="qwenimage",
pipeline_cls=QwenImagePipeline,
pipeline_init_kwargs={
"pretrained_model_name_or_path": "Qwen/Qwen-Image",
"torch_dtype": torch.bfloat16,
},
pipeline_call_kwargs={
"prompt": PROMPT,
"negative_prompt": " ",
"height": 1024,
"width": 1024,
"num_inference_steps": 4,
"true_cfg_scale": 4.0,
"output_type": "latent",
},
),
}
def main():
parser = argparse.ArgumentParser(description="Profile diffusers pipelines with torch.profiler")
parser.add_argument(
"--pipeline",
choices=["flux", "flux2", "wan", "ltx2", "qwenimage", "all"],
required=True,
help="Which pipeline to profile",
)
parser.add_argument(
"--mode",
choices=["eager", "compile", "both"],
default="eager",
help="Run in eager mode, compile mode, or both",
)
parser.add_argument("--output_dir", default="profiling_results", help="Directory for trace output")
parser.add_argument("--num_steps", type=int, default=None, help="Override num_inference_steps")
parser.add_argument("--full_decode", action="store_true", help="Profile including VAE decode (output_type='pil')")
parser.add_argument(
"--compile_mode",
default="default",
choices=["default", "reduce-overhead", "max-autotune"],
help="torch.compile mode",
)
parser.add_argument("--compile_fullgraph", action="store_true", help="Use fullgraph=True for torch.compile")
parser.add_argument(
"--compile_regional",
action="store_true",
help="Use compile_repeated_blocks() instead of full model compile",
)
args = parser.parse_args()
registry = build_registry()
pipeline_names = list(registry.keys()) if args.pipeline == "all" else [args.pipeline]
modes = ["eager", "compile"] if args.mode == "both" else [args.mode]
for pipeline_name in pipeline_names:
for mode in modes:
config = copy.deepcopy(registry[pipeline_name])
# Apply overrides
if args.num_steps is not None:
config.pipeline_call_kwargs["num_inference_steps"] = args.num_steps
if args.full_decode:
config.pipeline_call_kwargs["output_type"] = "pil"
if mode == "compile":
config.compile_kwargs = {
"fullgraph": args.compile_fullgraph,
"mode": args.compile_mode,
}
config.compile_regional = args.compile_regional
logger.info(f"Profiling {pipeline_name} in {mode} mode...")
profiler = PipelineProfiler(config, args.output_dir)
try:
trace_file = profiler.run()
logger.info(f"Done: {trace_file}")
except Exception as e:
logger.error(f"Failed to profile {pipeline_name} ({mode}): {e}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,143 @@
import functools
import gc
import logging
import os
from dataclasses import dataclass, field
from typing import Any
import torch
import torch.profiler
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)
def annotate(func, name):
"""Wrap a function with torch.profiler.record_function for trace annotation."""
@functools.wraps(func)
def wrapper(*args, **kwargs):
with torch.profiler.record_function(name):
return func(*args, **kwargs)
return wrapper
def annotate_pipeline(pipe):
"""Apply profiler annotations to key pipeline methods.
Monkey-patches bound methods so they appear as named spans in the trace.
Non-invasive — no source modifications required.
"""
annotations = [
("transformer", "forward", "transformer_forward"),
("vae", "decode", "vae_decode"),
("vae", "encode", "vae_encode"),
("scheduler", "step", "scheduler_step"),
]
# Annotate sub-component methods
for component_name, method_name, label in annotations:
component = getattr(pipe, component_name, None)
if component is None:
continue
method = getattr(component, method_name, None)
if method is None:
continue
setattr(component, method_name, annotate(method, label))
# Annotate pipeline-level methods
if hasattr(pipe, "encode_prompt"):
pipe.encode_prompt = annotate(pipe.encode_prompt, "encode_prompt")
def flush():
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
@dataclass
class PipelineProfilingConfig:
name: str
pipeline_cls: Any
pipeline_init_kwargs: dict[str, Any]
pipeline_call_kwargs: dict[str, Any]
compile_kwargs: dict[str, Any] | None = field(default=None)
compile_regional: bool = False
class PipelineProfiler:
def __init__(self, config: PipelineProfilingConfig, output_dir: str = "profiling_results"):
self.config = config
self.output_dir = output_dir
os.makedirs(output_dir, exist_ok=True)
def setup_pipeline(self):
"""Load the pipeline from pretrained, optionally compile, and annotate."""
logger.info(f"Loading pipeline: {self.config.name}")
pipe = self.config.pipeline_cls.from_pretrained(**self.config.pipeline_init_kwargs)
pipe.to("cuda")
if self.config.compile_kwargs:
if self.config.compile_regional:
logger.info(f"Regional compilation (compile_repeated_blocks) with kwargs: {self.config.compile_kwargs}")
pipe.transformer.compile_repeated_blocks(**self.config.compile_kwargs)
else:
logger.info(f"Full compilation with kwargs: {self.config.compile_kwargs}")
pipe.transformer.compile(**self.config.compile_kwargs)
annotate_pipeline(pipe)
return pipe
def run(self):
"""Execute the profiling run: warmup, then profile one pipeline call."""
pipe = self.setup_pipeline()
flush()
mode = "compile" if self.config.compile_kwargs else "eager"
trace_file = os.path.join(self.output_dir, f"{self.config.name}_{mode}.json")
# Warmup (pipeline __call__ is already decorated with @torch.no_grad())
logger.info("Running warmup...")
pipe(**self.config.pipeline_call_kwargs)
flush()
# Profile
logger.info("Running profiled iteration...")
activities = [
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
]
with torch.profiler.profile(
activities=activities,
record_shapes=True,
profile_memory=True,
with_stack=True,
) as prof:
with torch.profiler.record_function("pipeline_call"):
pipe(**self.config.pipeline_call_kwargs)
# Export trace
prof.export_chrome_trace(trace_file)
logger.info(f"Chrome trace saved to: {trace_file}")
# Print summary
print("\n" + "=" * 80)
print(f"Profile summary: {self.config.name} ({mode})")
print("=" * 80)
print(
prof.key_averages().table(
sort_by="cuda_time_total",
row_limit=20,
)
)
# Cleanup
pipe.to("cpu")
del pipe
flush()
return trace_file

39
profiling/run_profiling.sh Executable file
View File

@@ -0,0 +1,39 @@
#!/bin/bash
# Run profiling across all pipelines in eager and compile (regional) modes.
#
# Usage:
# bash profiling/run_profiling.sh
# bash profiling/run_profiling.sh --output_dir my_results
set -euo pipefail
OUTPUT_DIR="${1:-profiling_results}"
NUM_STEPS=2
PIPELINES=("flux" "flux2" "wan" "ltx2" "qwenimage")
MODES=("eager" "compile")
for pipeline in "${PIPELINES[@]}"; do
for mode in "${MODES[@]}"; do
echo "============================================================"
echo "Profiling: ${pipeline} | mode: ${mode}"
echo "============================================================"
COMPILE_ARGS=""
if [ "$mode" = "compile" ]; then
COMPILE_ARGS="--compile_regional --compile_fullgraph --compile_mode default"
fi
python profiling/profiling_pipelines.py \
--pipeline "$pipeline" \
--mode "$mode" \
--output_dir "$OUTPUT_DIR" \
--num_steps "$NUM_STEPS" \
$COMPILE_ARGS
echo ""
done
done
echo "============================================================"
echo "All traces saved to: ${OUTPUT_DIR}/"
echo "============================================================"

View File

@@ -344,6 +344,8 @@ else:
_import_structure["schedulers"].extend(
[
"AmusedScheduler",
"BlockRefinementScheduler",
"BlockRefinementSchedulerOutput",
"CMStochasticIterativeScheduler",
"CogVideoXDDIMScheduler",
"CogVideoXDPMScheduler",
@@ -580,6 +582,8 @@ else:
"LDMTextToImagePipeline",
"LEditsPPPipelineStableDiffusion",
"LEditsPPPipelineStableDiffusionXL",
"LLaDA2Pipeline",
"LLaDA2PipelineOutput",
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ConditionPipeline",
@@ -1124,6 +1128,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .quantizers import DiffusersQuantizer
from .schedulers import (
AmusedScheduler,
BlockRefinementScheduler,
BlockRefinementSchedulerOutput,
CMStochasticIterativeScheduler,
CogVideoXDDIMScheduler,
CogVideoXDPMScheduler,
@@ -1339,6 +1345,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LLaDA2Pipeline,
LLaDA2PipelineOutput,
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ConditionPipeline,

View File

@@ -2443,6 +2443,191 @@ def _convert_non_diffusers_flux2_lora_to_diffusers(state_dict):
return converted_state_dict
def _convert_kohya_flux2_lora_to_diffusers(state_dict):
def _convert_to_ai_toolkit(sds_sd, ait_sd, sds_key, ait_key):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
# scale weight by alpha and dim
rank = down_weight.shape[0]
default_alpha = torch.tensor(rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha).item()
scale = alpha / rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
ait_sd[ait_key + ".lora_A.weight"] = down_weight * scale_down
ait_sd[ait_key + ".lora_B.weight"] = sds_sd.pop(sds_key + ".lora_up.weight") * scale_up
def _convert_to_ai_toolkit_cat(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
if sds_key + ".lora_down.weight" not in sds_sd:
return
down_weight = sds_sd.pop(sds_key + ".lora_down.weight")
up_weight = sds_sd.pop(sds_key + ".lora_up.weight")
sd_lora_rank = down_weight.shape[0]
default_alpha = torch.tensor(
sd_lora_rank, dtype=down_weight.dtype, device=down_weight.device, requires_grad=False
)
alpha = sds_sd.pop(sds_key + ".alpha", default_alpha)
scale = alpha / sd_lora_rank
scale_down = scale
scale_up = 1.0
while scale_down * 2 < scale_up:
scale_down *= 2
scale_up /= 2
down_weight = down_weight * scale_down
up_weight = up_weight * scale_up
num_splits = len(ait_keys)
if dims is None:
dims = [up_weight.shape[0] // num_splits] * num_splits
else:
assert sum(dims) == up_weight.shape[0]
# check if upweight is sparse
is_sparse = False
if sd_lora_rank % num_splits == 0:
ait_rank = sd_lora_rank // num_splits
is_sparse = True
i = 0
for j in range(len(dims)):
for k in range(len(dims)):
if j == k:
continue
is_sparse = is_sparse and torch.all(
up_weight[i : i + dims[j], k * ait_rank : (k + 1) * ait_rank] == 0
)
i += dims[j]
if is_sparse:
logger.info(f"weight is sparse: {sds_key}")
ait_down_keys = [k + ".lora_A.weight" for k in ait_keys]
ait_up_keys = [k + ".lora_B.weight" for k in ait_keys]
if not is_sparse:
ait_sd.update(dict.fromkeys(ait_down_keys, down_weight))
ait_sd.update({k: v for k, v in zip(ait_up_keys, torch.split(up_weight, dims, dim=0))}) # noqa: C416
else:
ait_sd.update({k: v for k, v in zip(ait_down_keys, torch.chunk(down_weight, num_splits, dim=0))}) # noqa: C416
i = 0
for j in range(len(dims)):
ait_sd[ait_up_keys[j]] = up_weight[i : i + dims[j], j * ait_rank : (j + 1) * ait_rank].contiguous()
i += dims[j]
# Detect number of blocks from keys
num_double_layers = 0
num_single_layers = 0
for key in state_dict.keys():
if key.startswith("lora_unet_double_blocks_"):
block_idx = int(key.split("_")[4])
num_double_layers = max(num_double_layers, block_idx + 1)
elif key.startswith("lora_unet_single_blocks_"):
block_idx = int(key.split("_")[4])
num_single_layers = max(num_single_layers, block_idx + 1)
ait_sd = {}
for i in range(num_double_layers):
# Attention projections
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_out.0",
)
_convert_to_ai_toolkit_cat(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.to_q",
f"transformer.transformer_blocks.{i}.attn.to_k",
f"transformer.transformer_blocks.{i}.attn.to_v",
],
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_proj",
f"transformer.transformer_blocks.{i}.attn.to_add_out",
)
_convert_to_ai_toolkit_cat(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_attn_qkv",
[
f"transformer.transformer_blocks.{i}.attn.add_q_proj",
f"transformer.transformer_blocks.{i}.attn.add_k_proj",
f"transformer.transformer_blocks.{i}.attn.add_v_proj",
],
)
# MLP layers (Flux2 uses ff.linear_in/linear_out)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_0",
f"transformer.transformer_blocks.{i}.ff.linear_in",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_img_mlp_2",
f"transformer.transformer_blocks.{i}.ff.linear_out",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_0",
f"transformer.transformer_blocks.{i}.ff_context.linear_in",
)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_double_blocks_{i}_txt_mlp_2",
f"transformer.transformer_blocks.{i}.ff_context.linear_out",
)
for i in range(num_single_layers):
# Single blocks: linear1 -> attn.to_qkv_mlp_proj (fused, no split needed)
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_single_blocks_{i}_linear1",
f"transformer.single_transformer_blocks.{i}.attn.to_qkv_mlp_proj",
)
# Single blocks: linear2 -> attn.to_out
_convert_to_ai_toolkit(
state_dict,
ait_sd,
f"lora_unet_single_blocks_{i}_linear2",
f"transformer.single_transformer_blocks.{i}.attn.to_out",
)
# Handle optional extra keys
extra_mappings = {
"lora_unet_img_in": "transformer.x_embedder",
"lora_unet_txt_in": "transformer.context_embedder",
"lora_unet_time_in_in_layer": "transformer.time_guidance_embed.timestep_embedder.linear_1",
"lora_unet_time_in_out_layer": "transformer.time_guidance_embed.timestep_embedder.linear_2",
"lora_unet_final_layer_linear": "transformer.proj_out",
}
for sds_key, ait_key in extra_mappings.items():
_convert_to_ai_toolkit(state_dict, ait_sd, sds_key, ait_key)
remaining_keys = list(state_dict.keys())
if remaining_keys:
logger.warning(f"Unsupported keys for Kohya Flux2 LoRA conversion: {remaining_keys}")
return ait_sd
def _convert_non_diffusers_z_image_lora_to_diffusers(state_dict):
"""
Convert non-diffusers ZImage LoRA state dict to diffusers format.

View File

@@ -43,6 +43,7 @@ from .lora_conversion_utils import (
_convert_bfl_flux_control_lora_to_diffusers,
_convert_fal_kontext_lora_to_diffusers,
_convert_hunyuan_video_lora_to_diffusers,
_convert_kohya_flux2_lora_to_diffusers,
_convert_kohya_flux_lora_to_diffusers,
_convert_musubi_wan_lora_to_diffusers,
_convert_non_diffusers_flux2_lora_to_diffusers,
@@ -5673,6 +5674,13 @@ class Flux2LoraLoaderMixin(LoraBaseMixin):
logger.warning(warn_msg)
state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
is_kohya = any(".lora_down.weight" in k for k in state_dict)
if is_kohya:
state_dict = _convert_kohya_flux2_lora_to_diffusers(state_dict)
# Kohya already takes care of scaling the LoRA parameters with alpha.
out = (state_dict, metadata) if return_lora_metadata else state_dict
return out
is_peft_format = any(k.startswith("base_model.model.") for k in state_dict)
if is_peft_format:
state_dict = {k.replace("base_model.model.", "diffusion_model."): v for k, v in state_dict.items()}

View File

@@ -15,6 +15,7 @@
import inspect
import json
import os
from collections import defaultdict
from functools import partial
from pathlib import Path
from typing import Literal
@@ -44,33 +45,13 @@ from .unet_loader_utils import _maybe_expand_lora_scales
logger = logging.get_logger(__name__)
_SET_ADAPTER_SCALE_FN_MAPPING = {
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
"SD3Transformer2DModel": lambda model_cls, weights: weights,
"FluxTransformer2DModel": lambda model_cls, weights: weights,
"CogVideoXTransformer3DModel": lambda model_cls, weights: weights,
"ConsisIDTransformer3DModel": lambda model_cls, weights: weights,
"HeliosTransformer3DModel": lambda model_cls, weights: weights,
"MochiTransformer3DModel": lambda model_cls, weights: weights,
"HunyuanVideoTransformer3DModel": lambda model_cls, weights: weights,
"LTXVideoTransformer3DModel": lambda model_cls, weights: weights,
"SanaTransformer2DModel": lambda model_cls, weights: weights,
"AuraFlowTransformer2DModel": lambda model_cls, weights: weights,
"Lumina2Transformer2DModel": lambda model_cls, weights: weights,
"WanTransformer3DModel": lambda model_cls, weights: weights,
"CogView4Transformer2DModel": lambda model_cls, weights: weights,
"HiDreamImageTransformer2DModel": lambda model_cls, weights: weights,
"HunyuanVideoFramepackTransformer3DModel": lambda model_cls, weights: weights,
"WanVACETransformer3DModel": lambda model_cls, weights: weights,
"ChromaTransformer2DModel": lambda model_cls, weights: weights,
"ChronoEditTransformer3DModel": lambda model_cls, weights: weights,
"QwenImageTransformer2DModel": lambda model_cls, weights: weights,
"Flux2Transformer2DModel": lambda model_cls, weights: weights,
"ZImageTransformer2DModel": lambda model_cls, weights: weights,
"LTX2VideoTransformer3DModel": lambda model_cls, weights: weights,
"LTX2TextConnectors": lambda model_cls, weights: weights,
}
_SET_ADAPTER_SCALE_FN_MAPPING = defaultdict(
lambda: (lambda model_cls, weights: weights),
{
"UNet2DConditionModel": _maybe_expand_lora_scales,
"UNetMotionModel": _maybe_expand_lora_scales,
},
)
class PeftAdapterMixin:

View File

@@ -788,9 +788,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
freqs_cis = pad_sequence(freqs_cis, batch_first=True, padding_value=0.0)[:, : feats.shape[1]]
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
if all(seq == max_seqlen for seq in item_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(item_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None
@@ -871,9 +874,12 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
unified_freqs = pad_sequence(unified_freqs, batch_first=True, padding_value=0.0)
# Attention mask
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
if all(seq == max_seqlen for seq in unified_seqlens):
attn_mask = None
else:
attn_mask = torch.zeros((bsz, max_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(unified_seqlens):
attn_mask[i, :seq_len] = 1
# Noise mask
noise_mask_tensor = None

View File

@@ -285,6 +285,7 @@ else:
]
)
_import_structure["latte"] = ["LattePipeline"]
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
_import_structure["ltx"] = [
"LTXPipeline",
"LTXImageToVideoPipeline",
@@ -728,6 +729,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
)
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
from .ltx import (
LTXConditionPipeline,

View File

@@ -0,0 +1,47 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_import_structure = {}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,491 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Callable
import torch
from tqdm.auto import tqdm
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...schedulers import BlockRefinementScheduler
from ...utils import BaseOutput, logging, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
logger = logging.get_logger(__name__)
EXAMPLE_DOC_STRING = """
Examples:
```python
>>> import torch
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
>>> model_id = "inclusionAI/LLaDA2.1-mini"
>>> model = AutoModelForCausalLM.from_pretrained(
... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
... )
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
>>> scheduler = BlockRefinementScheduler()
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
>>> print(output.texts[0])
```
"""
@dataclass
class LLaDA2PipelineOutput(BaseOutput):
sequences: torch.LongTensor
texts: list[str] | None = None
class LLaDA2Pipeline(DiffusionPipeline):
r"""
Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement.
This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each
refinement step, it samples candidate tokens for the active block and commits a subset based on confidence.
The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq,
vocab_size]`.
"""
model: Any
scheduler: BlockRefinementScheduler
tokenizer: Any
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
def __init__(
self,
model: Any,
scheduler: BlockRefinementScheduler,
tokenizer: Any | None = None,
):
super().__init__()
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None
self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None
@property
def num_timesteps(self):
return self._num_timesteps
# --- Prompt encoding ---
def _prepare_input_ids(
self,
*,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
use_chat_template: bool,
add_generation_prompt: bool,
chat_template_kwargs: dict[str, Any] | None,
) -> torch.LongTensor:
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
if input_ids is not None:
if input_ids.ndim == 1:
input_ids = input_ids.unsqueeze(0)
if input_ids.ndim != 2:
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
return input_ids
if self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and prompt is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if messages is None and prompt is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
chat_template_kwargs = chat_template_kwargs or {}
if messages is not None:
encoded = self.tokenizer.apply_chat_template(
messages,
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
if isinstance(prompt, list):
raise ValueError("`prompt` must be a string when `use_chat_template=True`.")
encoded = self.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=add_generation_prompt,
tokenize=True,
return_tensors="pt",
return_dict=True,
**chat_template_kwargs,
)
return encoded["input_ids"]
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
return encoded["input_ids"]
def check_inputs(
self,
prompt: str | list[str] | None,
messages: list[dict[str, str]] | None,
input_ids: torch.LongTensor | None,
gen_length: int,
block_length: int,
num_inference_steps: int,
minimal_topk: int,
threshold: float,
sampling_method: str,
output_type: str,
callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None,
callback_on_step_end_tensor_inputs: list[str] | None,
):
# Input source validation
if prompt is None and messages is None and input_ids is None:
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
if prompt is not None and messages is not None:
raise ValueError("Provide either `prompt` or `messages`, not both.")
if input_ids is not None:
if input_ids.ndim not in (1, 2):
raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.")
if input_ids.dtype != torch.long:
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
if prompt is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
if messages is not None and input_ids is None and self.tokenizer is None:
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
# Generation parameter validation
if gen_length <= 0:
raise ValueError(f"`gen_length` must be > 0, got {gen_length}.")
if block_length <= 0:
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
if minimal_topk <= 0:
raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.")
if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0):
raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.")
if sampling_method not in {"auto", "greedy", "multinomial"}:
raise ValueError(
f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}."
)
if output_type not in {"seq", "text"}:
raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.")
# Callback validation
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: str | list[str] | None = None,
messages: list[dict[str, str]] | None = None,
input_ids: torch.LongTensor | None = None,
use_chat_template: bool = True,
add_generation_prompt: bool = True,
gen_length: int = 2048,
block_length: int = 32,
num_inference_steps: int = 32,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "multinomial",
threshold: float = 0.7,
editing_threshold: float | None = 0.5,
max_post_steps: int = 16,
minimal_topk: int = 1,
eos_early_stop: bool = True,
eos_token_id: int | None = None,
mask_token_id: int | None = None,
generator: torch.Generator | None = None,
output_type: str = "text",
return_dict: bool = True,
callback_on_step_end: Callable[[int, int, dict], None]
| PipelineCallback
| MultiPipelineCallbacks
| None = None,
callback_on_step_end_tensor_inputs: list[str] | None = None,
) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]:
"""
Generate text with block-wise refinement.
Args:
prompt (`str` or `List[str]`, *optional*):
Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is
available, the prompt is wrapped in a chat message before tokenization.
messages (`List[Dict[str, str]]`, *optional*):
Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt`
when provided. Requires a tokenizer with `apply_chat_template`.
input_ids (`torch.LongTensor`, *optional*):
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
use_chat_template (`bool`, defaults to `True`):
Whether to wrap the prompt in a chat template.
add_generation_prompt (`bool`, defaults to `True`):
Whether to add the generation prompt when using chat templates.
gen_length (`int`):
Number of tokens to generate.
block_length (`int`):
Block size for refinement.
num_inference_steps (`int`):
Number of refinement steps per block.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`):
Confidence threshold for committing tokens.
editing_threshold (`float`, *optional*):
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
negative value to disable editing. Defaults to `0.5`.
max_post_steps (`int`):
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
used when `editing_threshold` is enabled. Defaults to `16`.
minimal_topk (`int`):
Minimum number of tokens to commit per step.
eos_early_stop (`bool`):
Whether to stop after committing EOS in a block.
eos_token_id (`int`, *optional*):
EOS token ID to use for early stopping.
mask_token_id (`int`, *optional*):
Mask token ID to use for the template.
generator (`torch.Generator`, *optional*):
RNG for sampling.
output_type (`str`, defaults to `"text"`):
Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw
token ID sequences only.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple.
callback_on_step_end (`Callable` or `PipelineCallback`, *optional*):
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
timestep: int, callback_kwargs: Dict)`.
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
`confidence`, `active_block`.
Examples:
"""
# 1. Check inputs early
if callback_on_step_end is not None and isinstance(
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
):
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
if callback_on_step_end_tensor_inputs is None:
callback_on_step_end_tensor_inputs = ["block_x"]
self.check_inputs(
prompt=prompt,
messages=messages,
input_ids=input_ids,
gen_length=gen_length,
block_length=block_length,
num_inference_steps=num_inference_steps,
minimal_topk=minimal_topk,
threshold=threshold,
sampling_method=sampling_method,
output_type=output_type,
callback_on_step_end=callback_on_step_end,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
# 2. Prepare input IDs from prompt/messages/input_ids
prompt_ids = self._prepare_input_ids(
prompt=prompt,
messages=messages,
input_ids=input_ids,
use_chat_template=use_chat_template,
add_generation_prompt=add_generation_prompt,
chat_template_kwargs=None,
)
device = self._execution_device
if prompt_ids.ndim == 1:
prompt_ids = prompt_ids.unsqueeze(0)
prompt_ids = prompt_ids.to(device=device)
batch_size, prompt_length = prompt_ids.shape
if eos_token_id is None:
eos_token_id = self.eos_token_id
if mask_token_id is None:
mask_token_id = self.mask_token_id
if mask_token_id is None:
raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).")
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
self.scheduler.set_timesteps(num_inference_steps, device=device)
# 3. Build attention mask and position IDs
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
total_length = num_blocks * block_length
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
# 4. Prepare latents (fully masked sequence)
x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long)
if prompt_length > 0:
x[:, :prompt_length] = prompt_ids
prefill_blocks = prompt_length // block_length
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
global_step = 0
# 5. Block-wise refinement loop
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
block_progress_bar_config["position"] = 0
block_progress_bar_config["desc"] = "Blocks"
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
current_window_end = (num_block + 1) * block_length
block_x = x[:, :current_window_end]
block_attn_mask = attn_mask[:, :current_window_end]
block_position_ids = position_ids[:, :current_window_end]
# Identify which positions in the block are prompt (non-editable).
block_start_pos = num_block * block_length
prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
if block_start_pos < prompt_length:
prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
prompt_mask_in_block[:prompt_end_in_block] = True
post_steps = 0
step_idx = 0
should_continue = True
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
progress_bar = self.progress_bar(total=num_inference_steps)
while should_continue:
block_tokens = block_x[:, -block_length:]
masks_remaining = (block_tokens == mask_token_id).any()
if not masks_remaining:
post_steps += 1
logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits
block_logits = logits[:, -block_length:, :]
scheduler_output = self.scheduler.step(
model_output=block_logits,
timestep=step_idx,
sample=block_tokens,
mask_token_id=mask_token_id,
temperature=temperature,
top_p=top_p,
top_k=top_k,
sampling_method=sampling_method,
threshold=threshold,
editing_threshold=editing_threshold,
minimal_topk=minimal_topk,
prompt_mask=prompt_mask_in_block,
generator=generator,
return_dict=True,
)
transfer_index = scheduler_output.transfer_index
editing_transfer_index = scheduler_output.editing_transfer_index
final_transfer = transfer_index | editing_transfer_index
if final_transfer.any():
block_x[:, -block_length:] = scheduler_output.prev_sample
if eos_early_stop and eos_token_id is not None:
finished = self.scheduler.check_eos_finished(
cur_x=block_x,
sampled_tokens=scheduler_output.sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_token_id,
mask_token_id=mask_token_id,
prompt_length=prompt_length,
)
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
block_x = callback_outputs.pop("block_x", block_x)
global_step += 1
if masks_remaining:
step_idx += 1
progress_bar.update(1)
should_continue = self.scheduler.check_block_should_continue(
step_idx=step_idx,
masks_remaining=masks_remaining,
editing_enabled=editing_enabled,
editing_transfer_index=editing_transfer_index,
post_steps=post_steps,
max_post_steps=max_post_steps,
finished=finished,
)
progress_bar.close()
x[:, :current_window_end] = block_x
if eos_early_stop and finished.all():
break
# 6. Post-process output
generated = x[:, : prompt_length + gen_length]
sequences = generated[:, prompt_length:]
if eos_token_id is not None and batch_size == 1:
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
if len(eos_positions) > 0:
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
texts = None
if output_type == "text" and self.tokenizer is not None:
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
if not return_dict:
return sequences.to(device=device), texts
return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts)
__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]

View File

@@ -40,6 +40,7 @@ except OptionalDependencyNotAvailable:
else:
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
@@ -145,6 +146,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
else:
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
from .scheduling_amused import AmusedScheduler
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
from .scheduling_consistency_models import CMStochasticIterativeScheduler
from .scheduling_ddim import DDIMScheduler

View File

@@ -0,0 +1,460 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class BlockRefinementSchedulerOutput(BaseOutput):
"""
Output class for block refinement scheduling.
Args:
prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Updated block tokens after the current refinement step.
transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were committed (mask-filling).
editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Boolean mask indicating which tokens were edited (non-mask replacement).
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Sampled token IDs from the model logits.
sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
Probabilities of the sampled tokens.
"""
prev_sample: torch.LongTensor
transfer_index: torch.BoolTensor
editing_transfer_index: torch.BoolTensor
sampled_tokens: torch.LongTensor
sampled_probs: torch.Tensor
class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
"""
Scheduler for block-wise iterative refinement (commit-by-confidence).
At each step, the scheduler samples candidate tokens from model logits and commits those with the highest
confidence. The number of tokens to commit per step is determined by evenly distributing the block length across
the number of refinement steps.
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
"""
order = 1
@register_to_config
def __init__(
self,
block_length: int = 32,
num_inference_steps: int = 32,
threshold: float = 0.95,
editing_threshold: float | None = None,
minimal_topk: int = 1,
):
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
self._transfer_schedule: torch.LongTensor | None = None
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
if num_inference_steps <= 0:
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
self.num_inference_steps = num_inference_steps
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
device=device if device is not None else "cpu"
)
def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor:
"""Evenly distribute `block_length` token commits across `num_inference_steps` steps."""
if num_inference_steps <= 0:
return torch.zeros((0,), dtype=torch.long)
base = block_length // num_inference_steps
remainder = block_length % num_inference_steps
out = torch.full((num_inference_steps,), base, dtype=torch.long)
out[:remainder] += 1
return out
# --- SAR sampling utilities ---
@staticmethod
def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor:
"""Nucleus (top-p) logit filtering."""
if top_p is None or top_p >= 1.0:
return logits
if not (0.0 < top_p <= 1.0):
raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.")
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
sorted_probs = torch.softmax(sorted_logits, dim=-1)
cumulative_probs = sorted_probs.cumsum(dim=-1)
sorted_indices_to_remove = cumulative_probs > float(top_p)
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min)
filtered = logits.scatter(-1, sorted_indices, sorted_logits)
return filtered
@staticmethod
def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
"""Top-k logit filtering."""
if top_k is None or top_k <= 0:
return logits
if top_k >= logits.shape[-1]:
return logits
values, _ = torch.topk(logits, k=top_k, dim=-1)
min_keep = values[..., -1, None]
return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min)
@staticmethod
def _sample_from_logits(
logits: torch.Tensor,
*,
temperature: float,
top_k: int | None,
top_p: float | None,
generator: torch.Generator | None,
use_multinomial: bool,
) -> tuple[torch.LongTensor, torch.Tensor]:
"""Sample tokens from logits with temperature scaling, top-k, and top-p."""
if temperature < 0:
raise ValueError(f"`temperature` must be >= 0, got {temperature}.")
vocab_size = logits.shape[-1]
flat_logits = logits.reshape(-1, vocab_size)
if temperature == 0.0 or not use_multinomial:
probs = torch.softmax(flat_logits.float(), dim=-1)
token = flat_logits.argmax(dim=-1, keepdim=True)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
scaled = flat_logits
if temperature != 1.0:
scaled = flat_logits / temperature
filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k)
filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p)
probs = torch.softmax(filtered.float(), dim=-1)
token = torch.multinomial(probs, num_samples=1, generator=generator)
token_prob = torch.gather(probs, -1, token)
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
def step(
self,
model_output: torch.Tensor,
timestep: int | torch.Tensor,
sample: torch.LongTensor,
*,
mask_token_id: int,
temperature: float = 0.0,
top_p: float | None = None,
top_k: int | None = None,
sampling_method: str = "auto",
threshold: float | None = None,
editing_threshold: float | None = None,
minimal_topk: int | None = None,
prompt_mask: torch.BoolTensor | None = None,
generator: torch.Generator | None = None,
return_dict: bool = True,
) -> (
BlockRefinementSchedulerOutput
| tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]
):
"""
Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing
ones.
Args:
model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
Raw logits from the model for the current block.
timestep (`int` or `torch.Tensor`):
Current step index within the block's refinement schedule.
sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
Current block token IDs (contains mask tokens for uncommitted positions).
mask_token_id (`int`):
Token ID used for masked positions.
temperature (`float`):
Sampling temperature.
top_p (`float`, *optional*):
Nucleus sampling cutoff.
top_k (`int`, *optional*):
Top-k sampling cutoff.
sampling_method (`str`):
Sampling method (`auto`, `greedy`, `multinomial`).
threshold (`float`, *optional*):
Confidence threshold for committing tokens. Defaults to config value.
editing_threshold (`float`, *optional*):
Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
config value.
minimal_topk (`int`, *optional*):
Minimum tokens to commit per step. Defaults to config value.
prompt_mask (`torch.BoolTensor`, *optional*):
Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions.
generator (`torch.Generator`, *optional*):
RNG for sampling.
return_dict (`bool`):
Whether to return a `BlockRefinementSchedulerOutput` or a tuple.
"""
if threshold is None:
threshold = float(self.config.threshold)
if editing_threshold is None:
editing_threshold = self.config.editing_threshold
if minimal_topk is None:
minimal_topk = self.config.minimal_topk
# Sample from logits
use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0)
sampled_tokens, sampled_probs = self._sample_from_logits(
model_output,
temperature=temperature,
top_k=top_k,
top_p=top_p,
generator=generator,
use_multinomial=use_multinomial,
)
batch_size, block_length = sample.shape
active_block = sample == mask_token_id
masks_remaining = active_block.any()
if isinstance(timestep, torch.Tensor):
step_index = int(timestep.item())
else:
step_index = int(timestep)
# --- Mask-filling transfer ---
transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if masks_remaining and self._transfer_schedule is not None:
clamped_step = min(step_index, len(self._transfer_schedule) - 1)
num_to_transfer = int(self._transfer_schedule[clamped_step].item())
confidence = torch.where(
active_block,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
for b in range(batch_size):
high_conf = confidence[b] > threshold
if high_conf.sum().item() >= num_to_transfer:
transfer_index[b] = high_conf
else:
k = min(num_to_transfer, int(active_block[b].sum().item()))
if k > 0:
_, idx = torch.topk(confidence[b], k=k)
transfer_index[b, idx] = True
# --- Editing transfer (non-mask, non-prompt positions) ---
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
if editing_enabled:
if prompt_mask is None:
prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool)
editable = (~active_block) & (~prompt_mask.unsqueeze(0))
editing_conf = torch.where(
editable,
sampled_probs.to(dtype=torch.float32),
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
)
high_conf_edit = editing_conf > float(editing_threshold)
token_changed = sampled_tokens != sample
editing_transfer_index = high_conf_edit & token_changed & editable
# Apply transfers
final_transfer = transfer_index | editing_transfer_index
prev_sample = sample.clone()
if final_transfer.any():
prev_sample[final_transfer] = sampled_tokens[final_transfer]
if not return_dict:
return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs
return BlockRefinementSchedulerOutput(
prev_sample=prev_sample,
transfer_index=transfer_index,
editing_transfer_index=editing_transfer_index,
sampled_tokens=sampled_tokens,
sampled_probs=sampled_probs,
)
@staticmethod
def check_eos_finished(
cur_x: torch.LongTensor,
sampled_tokens: torch.LongTensor,
final_transfer: torch.BoolTensor,
finished: torch.BoolTensor,
eos_token_id: int,
mask_token_id: int,
prompt_length: int,
) -> torch.BoolTensor:
"""
Update per-batch finished flags when EOS tokens are committed.
Args:
cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Current full sequence including all blocks up to the current window.
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
Tokens sampled by the scheduler in this step.
final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`):
Combined mask of committed and edited positions.
finished (`torch.BoolTensor` of shape `(batch_size,)`):
Current per-batch finished flags.
eos_token_id (`int`):
EOS token ID.
mask_token_id (`int`):
Mask token ID.
prompt_length (`int`):
Number of prompt tokens at the start of the sequence.
Returns:
`torch.BoolTensor`: Updated finished flags.
"""
batch_size = cur_x.shape[0]
for b in range(batch_size):
if finished[b]:
continue
eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item()
if not eos_in_commits:
continue
eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
if len(eos_pos[0]) == 0:
continue
eos_pos = int(eos_pos[0][0].item())
if prompt_length >= eos_pos:
continue
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
finished[b] = True
return finished
def check_block_should_continue(
self,
step_idx: int,
masks_remaining: bool,
editing_enabled: bool,
editing_transfer_index: torch.BoolTensor,
post_steps: int,
max_post_steps: int,
finished: torch.BoolTensor,
) -> bool:
"""
Determine whether the inner refinement loop should continue for the current block.
Args:
step_idx (`int`):
Current refinement step index within this block.
masks_remaining (`bool`):
Whether any mask tokens remain in the block.
editing_enabled (`bool`):
Whether editing mode is active.
editing_transfer_index (`torch.BoolTensor`):
Which tokens were edited in this step.
post_steps (`int`):
Number of post-mask editing steps taken so far.
max_post_steps (`int`):
Maximum allowed post-mask editing steps.
finished (`torch.BoolTensor`):
Per-batch finished flags (from EOS detection).
Returns:
`bool`: `True` if refinement should continue, `False` to break.
"""
if finished.all():
return False
if not masks_remaining and not editing_enabled:
return False
if not masks_remaining and not editing_transfer_index.any():
return False
if masks_remaining and step_idx >= self.num_inference_steps:
return False
if not masks_remaining and post_steps > max_post_steps:
return False
return True
def add_noise(
self,
original_samples: torch.LongTensor,
attention_mask: torch.LongTensor,
*,
prompt_length: int,
block_length: int,
mask_token_id: int,
generator: torch.Generator | None = None,
) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
"""
Apply the forward (noising) process for semi-autoregressive block masking.
For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with
`mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in
one are the unmasked positions in the other.
Args:
original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Clean token IDs.
attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
Padding mask (1 for valid, 0 for padding).
prompt_length (`int`):
Number of leading prompt tokens to keep unmasked.
block_length (`int`):
Block size for masking.
mask_token_id (`int`):
Token ID to use for masked positions.
generator (`torch.Generator`, *optional*):
RNG for reproducibility.
Returns:
`tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`:
`(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their
corresponding boolean masks.
"""
batch_size, seq_len = original_samples.shape
device = original_samples.device
noisy = original_samples.clone()
noisy_rev = original_samples.clone()
masked = torch.zeros_like(original_samples, dtype=torch.bool)
masked_rev = torch.zeros_like(original_samples, dtype=torch.bool)
valid = attention_mask.to(dtype=torch.bool)
for block_start in range(prompt_length, seq_len, block_length):
block_end = min(seq_len, block_start + block_length)
seg_len = block_end - block_start
if seg_len <= 0:
continue
p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
seg = seg & valid[:, block_start:block_end]
seg_rev = (~seg) & valid[:, block_start:block_end]
masked[:, block_start:block_end] = seg
masked_rev[:, block_start:block_end] = seg_rev
noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy)
noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev)
return noisy, noisy_rev, masked, masked_rev
__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]

View File

@@ -11,6 +11,7 @@ from typing import Any, Iterable
import numpy as np
import torch
import torch.nn.functional as F
if getattr(torch, "distributed", None) is not None:
@@ -109,6 +110,92 @@ def compute_snr(noise_scheduler, timesteps):
return snr
def compute_confidence_aware_loss(
logits: torch.Tensor,
labels: torch.Tensor,
*,
lambda_conf: float = 0.0,
temperature: float = 1.0,
per_token_weights: torch.Tensor | None = None,
ignore_index: int = -100,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Computes a confidence-aware training loss for token classification-style heads.
This loss combines:
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
Args:
logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
are excluded from both losses.
lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
sharpen the distribution and change the strength of the confidence gradients.
per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.
Returns:
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
"""
if logits.ndim < 2:
raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
if labels.shape != logits.shape[:-1]:
raise ValueError(
f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
)
if temperature <= 0:
raise ValueError(f"`temperature` must be > 0, got {temperature}.")
valid = labels.ne(ignore_index)
if per_token_weights is None:
weights = torch.ones_like(labels, dtype=logits.dtype)
else:
if per_token_weights.shape != labels.shape:
raise ValueError(
f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
)
weights = per_token_weights.to(dtype=logits.dtype)
# Supervised CE (optionally weighted).
vocab_size = logits.shape[-1]
per_token_nll = F.cross_entropy(
logits.reshape(-1, vocab_size),
labels.reshape(-1),
reduction="none",
ignore_index=ignore_index,
).reshape_as(labels)
denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft
# Confidence loss: penalize entropy only where prediction is already correct.
if lambda_conf == 0.0:
loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
return loss_sft, loss_sft, loss_conf
with torch.no_grad():
pred = logits.argmax(dim=-1)
correct = valid & pred.eq(labels)
scaled_logits = logits.float()
if temperature != 1.0:
scaled_logits = scaled_logits / float(temperature)
probs = torch.softmax(scaled_logits, dim=-1)
eps = torch.finfo(probs.dtype).tiny
log_probs = torch.log(probs.clamp_min(eps))
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf
loss = loss_sft + float(lambda_conf) * loss_conf
return loss, loss_sft, loss_conf
def resolve_interpolation_mode(interpolation_type: str):
"""
Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The

View File

@@ -2518,6 +2518,36 @@ class AmusedScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class BlockRefinementScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class BlockRefinementSchedulerOutput(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class CMStochasticIterativeScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -2222,6 +2222,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2Pipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LLaDA2PipelineOutput(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class LongCatImageEditPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,242 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from diffusers.modular_pipelines import (
AutoPipelineBlocks,
ConditionalPipelineBlocks,
InputParam,
ModularPipelineBlocks,
)
class TextToImageBlock(ModularPipelineBlocks):
model_name = "text2img"
@property
def inputs(self):
return [InputParam(name="prompt")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "text-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "text2img"
self.set_block_state(state, block_state)
return components, state
class ImageToImageBlock(ModularPipelineBlocks):
model_name = "img2img"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "image-to-image workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "img2img"
self.set_block_state(state, block_state)
return components, state
class InpaintBlock(ModularPipelineBlocks):
model_name = "inpaint"
@property
def inputs(self):
return [InputParam(name="prompt"), InputParam(name="image"), InputParam(name="mask")]
@property
def intermediate_outputs(self):
return []
@property
def description(self):
return "inpaint workflow"
def __call__(self, components, state):
block_state = self.get_block_state(state)
block_state.workflow = "inpaint"
self.set_block_state(state, block_state)
return components, state
class ConditionalImageBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = "text2img"
@property
def description(self):
return "Conditional image blocks for testing"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None # falls back to default_block_name
class OptionalConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock]
block_names = ["inpaint", "img2img"]
block_trigger_inputs = ["mask", "image"]
default_block_name = None # no default; block can be skipped
@property
def description(self):
return "Optional conditional blocks (skippable)"
def select_block(self, mask=None, image=None) -> str | None:
if mask is not None:
return "inpaint"
if image is not None:
return "img2img"
return None
class AutoImageBlocks(AutoPipelineBlocks):
block_classes = [InpaintBlock, ImageToImageBlock, TextToImageBlock]
block_names = ["inpaint", "img2img", "text2img"]
block_trigger_inputs = ["mask", "image", None]
@property
def description(self):
return "Auto image blocks for testing"
class TestConditionalPipelineBlocksSelectBlock:
def test_select_block_with_mask(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="something") == "inpaint"
def test_select_block_with_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(image="something") == "img2img"
def test_select_block_with_mask_and_image(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
def test_select_block_no_triggers_returns_none(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block() is None
def test_select_block_explicit_none_values(self):
blocks = ConditionalImageBlocks()
assert blocks.select_block(mask=None, image=None) is None
class TestConditionalPipelineBlocksWorkflowSelection:
def test_default_workflow_when_no_triggers(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks()
assert execution is not None
assert isinstance(execution, TextToImageBlock)
def test_mask_trigger_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_image_trigger_selects_img2img(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
def test_mask_and_image_selects_inpaint(self):
blocks = ConditionalImageBlocks()
execution = blocks.get_execution_blocks(mask=True, image=True)
assert isinstance(execution, InpaintBlock)
def test_skippable_block_returns_none(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks()
assert execution is None
def test_skippable_block_still_selects_when_triggered(self):
blocks = OptionalConditionalBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestAutoPipelineBlocksSelectBlock:
def test_auto_select_mask(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m") == "inpaint"
def test_auto_select_image(self):
blocks = AutoImageBlocks()
assert blocks.select_block(image="i") == "img2img"
def test_auto_select_default(self):
blocks = AutoImageBlocks()
# No trigger -> returns None -> falls back to default (text2img)
assert blocks.select_block() is None
def test_auto_select_priority_order(self):
blocks = AutoImageBlocks()
assert blocks.select_block(mask="m", image="i") == "inpaint"
class TestAutoPipelineBlocksWorkflowSelection:
def test_auto_default_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks()
assert isinstance(execution, TextToImageBlock)
def test_auto_mask_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(mask=True)
assert isinstance(execution, InpaintBlock)
def test_auto_image_workflow(self):
blocks = AutoImageBlocks()
execution = blocks.get_execution_blocks(image=True)
assert isinstance(execution, ImageToImageBlock)
class TestConditionalPipelineBlocksStructure:
def test_block_names_accessible(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert set(sub.keys()) == {"inpaint", "img2img", "text2img"}
def test_sub_block_types(self):
blocks = ConditionalImageBlocks()
sub = dict(blocks.sub_blocks)
assert isinstance(sub["inpaint"], InpaintBlock)
assert isinstance(sub["img2img"], ImageToImageBlock)
assert isinstance(sub["text2img"], TextToImageBlock)
def test_description(self):
blocks = ConditionalImageBlocks()
assert "Conditional" in blocks.description

View File

@@ -10,11 +10,6 @@ from huggingface_hub import hf_hub_download
import diffusers
from diffusers import AutoModel, ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.guiders import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
ConditionalPipelineBlocks,
LoopSequentialPipelineBlocks,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
@@ -25,7 +20,6 @@ from diffusers.modular_pipelines.modular_pipeline_utils import (
from diffusers.utils import logging
from ..testing_utils import (
CaptureLogger,
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
@@ -498,117 +492,6 @@ class ModularGuiderTesterMixin:
assert max_diff > expected_max_diff, "Output with CFG must be different from normal inference"
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
class TestModularModelCardContent:
def create_mock_block(self, name="TestBlock", description="Test block description"):
class MockBlock:

View File

@@ -24,14 +24,18 @@ import torch
from diffusers import FluxTransformer2DModel
from diffusers.modular_pipelines import (
ComponentSpec,
ConditionalPipelineBlocks,
InputParam,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
OutputParam,
PipelineState,
SequentialPipelineBlocks,
WanModularPipeline,
)
from diffusers.utils import logging
from ..testing_utils import nightly, require_torch, require_torch_accelerator, slow, torch_device
from ..testing_utils import CaptureLogger, nightly, require_torch, require_torch_accelerator, slow, torch_device
def _create_tiny_model_dir(model_dir):
@@ -463,6 +467,117 @@ class TestModularCustomBlocks:
assert output_prompt.startswith("Modular diffusers + ")
class TestCustomBlockRequirements:
def get_dummy_block_pipe(self):
class DummyBlockOne:
# keep two arbitrary deps so that we can test warnings.
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
# keep two dependencies that will be available during testing.
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
pipe = SequentialPipelineBlocks.from_blocks_dict(
{"dummy_block_one": DummyBlockOne, "dummy_block_two": DummyBlockTwo}
)
return pipe
def get_dummy_conditional_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
class DummyConditionalBlocks(ConditionalPipelineBlocks):
block_classes = [DummyBlockOne, DummyBlockTwo]
block_names = ["block_one", "block_two"]
block_trigger_inputs = []
def select_block(self, **kwargs):
return "block_one"
return DummyConditionalBlocks()
def get_dummy_loop_block_pipe(self):
class DummyBlockOne:
_requirements = {"xyz": ">=0.8.0", "abc": ">=10.0.0"}
class DummyBlockTwo:
_requirements = {"transformers": ">=4.44.0", "diffusers": ">=0.2.0"}
return LoopSequentialPipelineBlocks.from_blocks_dict({"block_one": DummyBlockOne, "block_two": DummyBlockTwo})
def test_sequential_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
requirements = config["requirements"]
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == requirements
def test_sequential_block_requirements_warnings(self, tmp_path):
pipe = self.get_dummy_block_pipe()
logger = logging.get_logger("diffusers.modular_pipelines.modular_pipeline_utils")
logger.setLevel(30)
with CaptureLogger(logger) as cap_logger:
pipe.save_pretrained(str(tmp_path))
template = "{req} was specified in the requirements but wasn't found in the current environment"
msg_xyz = template.format(req="xyz")
msg_abc = template.format(req="abc")
assert msg_xyz in str(cap_logger.out)
assert msg_abc in str(cap_logger.out)
def test_conditional_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_conditional_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
def test_loop_block_requirements_save_load(self, tmp_path):
pipe = self.get_dummy_loop_block_pipe()
pipe.save_pretrained(str(tmp_path))
config_path = tmp_path / "modular_config.json"
with open(config_path, "r") as f:
config = json.load(f)
assert "requirements" in config
expected_requirements = {
"xyz": ">=0.8.0",
"abc": ">=10.0.0",
"transformers": ">=4.44.0",
"diffusers": ">=0.2.0",
}
assert expected_requirements == config["requirements"]
@slow
@nightly
@require_torch

View File

@@ -18,7 +18,7 @@ import unittest
import torch
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
from diffusers.training_utils import set_seed
from diffusers.training_utils import compute_confidence_aware_loss, set_seed
from ..testing_utils import slow
@@ -85,3 +85,47 @@ class TrainingTests(unittest.TestCase):
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
def test_confidence_aware_loss(self):
logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]])
labels = torch.tensor([[0, 0]])
weights = torch.tensor([[1.0, 2.0]])
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=0.0, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss, loss_sft))
self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf)))
lambda_conf = 0.25
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, per_token_weights=weights
)
# Manual expected values for the small 2-class case.
per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view(
1, 2
)
expected_sft = (per_token_nll * weights).sum() / weights.sum()
pred = logits.argmax(dim=-1)
correct = pred.eq(labels)
log_probs = torch.log_softmax(logits.float(), dim=-1)
probs = log_probs.exp()
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / (
weights * correct.to(weights.dtype)
).sum().clamp_min(1)
expected = expected_sft + lambda_conf * expected_conf
self.assertTrue(torch.allclose(loss_sft, expected_sft))
self.assertTrue(torch.allclose(loss_conf, expected_conf))
self.assertTrue(torch.allclose(loss, expected))
# Temperature affects only the confidence term.
loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss(
logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights
)
self.assertTrue(torch.allclose(loss_sft_t, expected_sft))
self.assertFalse(torch.allclose(loss_conf_t, expected_conf))
self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t))

View File

View File

@@ -0,0 +1,245 @@
import unittest
import torch
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
class _DummyModelOutput:
def __init__(self, logits):
self.logits = logits
class _DummyCausalLM(torch.nn.Module):
def __init__(self, vocab_size: int):
super().__init__()
self.vocab_size = int(vocab_size)
self.register_buffer("_device_anchor", torch.empty(0))
@property
def dtype(self):
return torch.float32
@property
def device(self):
return self._device_anchor.device
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
batch_size, seq_len = input_ids.shape
logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32)
# Make confidence vary with token position so top-k commits are deterministic.
positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1)
token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1)
logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1)
return _DummyModelOutput(logits=logits)
def _make_pipeline(tokenizer=None):
model = _DummyCausalLM(vocab_size=32)
scheduler = BlockRefinementScheduler()
return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
class LLaDA2PipelineTest(unittest.TestCase):
def test_pipeline_runs(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
out = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=24,
block_length=8,
num_inference_steps=8,
temperature=0.0,
threshold=2.0, # force top-k commits
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
eos_token_id=None,
output_type="seq",
)
self.assertEqual(out.sequences.shape, (2, 24))
self.assertFalse((out.sequences == 31).any().item())
def test_pipeline_return_tuple(self):
pipe = _make_pipeline().to("cpu")
input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
sequences, texts = pipe(
input_ids=input_ids,
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
return_dict=False,
)
self.assertEqual(sequences.shape, (1, 16))
self.assertIsNone(texts)
def test_output_type_seq(self):
"""output_type='seq' should return sequences but no texts."""
pipe = _make_pipeline().to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="seq",
)
self.assertIsNotNone(out.sequences)
self.assertEqual(out.sequences.shape, (1, 16))
self.assertIsNone(out.texts)
def test_output_type_text_without_tokenizer(self):
"""output_type='text' without a tokenizer should return texts=None."""
pipe = _make_pipeline(tokenizer=None).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
mask_token_id=31,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNone(out.texts)
def test_output_type_text_with_tokenizer(self):
"""output_type='text' with a tokenizer should return decoded texts."""
tok = type(
"Tok",
(),
{
"eos_token_id": None,
"mask_token_id": 31,
"batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
},
)()
pipe = _make_pipeline(tokenizer=tok).to("cpu")
out = pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
temperature=0.0,
threshold=2.0,
minimal_topk=1,
eos_early_stop=False,
output_type="text",
)
self.assertIsNotNone(out.sequences)
self.assertIsNotNone(out.texts)
self.assertEqual(len(out.texts), 1)
self.assertTrue(out.texts[0].startswith("decoded_"))
def test_output_type_invalid_raises(self):
"""Invalid output_type should raise ValueError."""
pipe = _make_pipeline().to("cpu")
with self.assertRaises(ValueError):
pipe(
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
use_chat_template=False,
gen_length=16,
block_length=8,
num_inference_steps=4,
mask_token_id=31,
output_type="invalid",
)
def test_prepare_input_ids_from_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertTrue(torch.equal(result, ids))
def test_prepare_input_ids_from_1d_tensor(self):
pipe = _make_pipeline()
ids = torch.tensor([1, 2, 3], dtype=torch.long)
result = pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=ids,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
self.assertEqual(result.shape, (1, 3))
def test_prepare_input_ids_no_tokenizer_raises(self):
pipe = _make_pipeline(tokenizer=None)
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_both_prompt_and_messages_raises(self):
pipe = _make_pipeline()
# Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt="hello",
messages=[{"role": "user", "content": "hi"}],
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
def test_prepare_input_ids_neither_raises(self):
pipe = _make_pipeline()
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
with self.assertRaises(ValueError):
pipe._prepare_input_ids(
prompt=None,
messages=None,
input_ids=None,
use_chat_template=False,
add_generation_prompt=False,
chat_template_kwargs=None,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -1534,14 +1534,18 @@ class PipelineTesterMixin:
pipe.set_progress_bar_config(disable=None)
pipe.to("cpu")
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == "cpu" for device in model_devices))
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
self.assertTrue(np.isnan(output_cpu).sum() == 0)
pipe.to(torch_device)
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
model_devices = [
component.device.type for component in components.values() if getattr(component, "device", None)
]
self.assertTrue(all(device == torch_device for device in model_devices))
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -1552,11 +1556,11 @@ class PipelineTesterMixin:
pipe = self.pipeline_class(**components)
pipe.set_progress_bar_config(disable=None)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
pipe.to(dtype=torch.float16)
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):

View File

@@ -0,0 +1,470 @@
import tempfile
import unittest
import torch
from diffusers import BlockRefinementScheduler
class BlockRefinementSchedulerTest(unittest.TestCase):
def get_scheduler(self, **kwargs):
config = {
"block_length": 32,
"num_inference_steps": 8,
"threshold": 0.95,
"editing_threshold": None,
"minimal_topk": 1,
}
config.update(kwargs)
return BlockRefinementScheduler(**config)
def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor:
"""Create logits where softmax of the target token has approximately the given probability."""
batch_size, block_length = target_probs.shape
logits = torch.zeros(batch_size, block_length, vocab_size)
# Set token 0 as the "predicted" token with a logit proportional to desired probability
for b in range(batch_size):
for t in range(block_length):
p = target_probs[b, t].item()
if p > 0:
logits[b, t, t % (vocab_size - 1)] = 10.0 * p
return logits
def test_set_timesteps(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
self.assertEqual(scheduler.num_inference_steps, 8)
self.assertEqual(len(scheduler.timesteps), 8)
self.assertEqual(scheduler.timesteps[0].item(), 7)
self.assertEqual(scheduler.timesteps[-1].item(), 0)
def test_set_timesteps_invalid(self):
scheduler = self.get_scheduler()
with self.assertRaises(ValueError):
scheduler.set_timesteps(0)
def test_get_num_transfer_tokens_even(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
self.assertEqual(schedule.sum().item(), 32)
self.assertEqual(len(schedule), 8)
self.assertTrue((schedule == 4).all().item())
def test_get_num_transfer_tokens_remainder(self):
scheduler = self.get_scheduler()
schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
self.assertEqual(schedule.sum().item(), 10)
self.assertEqual(len(schedule), 3)
self.assertEqual(schedule[0].item(), 4)
self.assertEqual(schedule[1].item(), 3)
self.assertEqual(schedule[2].item(), 3)
def test_transfer_schedule_created_on_set_timesteps(self):
scheduler = self.get_scheduler(block_length=16)
scheduler.set_timesteps(4)
self.assertIsNotNone(scheduler._transfer_schedule)
self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)
def test_save_load_config_round_trip(self):
scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
with tempfile.TemporaryDirectory() as tmpdir:
scheduler.save_config(tmpdir)
loaded = BlockRefinementScheduler.from_pretrained(tmpdir)
self.assertEqual(loaded.config.block_length, 64)
self.assertEqual(loaded.config.threshold, 0.8)
self.assertEqual(loaded.config.editing_threshold, 0.5)
self.assertEqual(loaded.config.minimal_topk, 2)
def test_from_config(self):
scheduler = self.get_scheduler(block_length=16, threshold=0.7)
new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
self.assertEqual(new_scheduler.config.block_length, 16)
self.assertEqual(new_scheduler.config.threshold, 0.7)
def test_step_commits_tokens(self):
"""Verify that step() commits mask tokens based on confidence."""
scheduler = self.get_scheduler(block_length=8)
scheduler.set_timesteps(2)
batch_size, block_length, vocab_size = 1, 8, 32
mask_id = 31
sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
# Create logits where confidence decreases with position
logits = torch.zeros(batch_size, block_length, vocab_size)
for i in range(block_length):
logits[0, i, i] = 10.0 - i # decreasing confidence
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
threshold=0.95,
return_dict=True,
)
# With 8 tokens and 2 steps, first step should commit 4 tokens
committed = out.transfer_index[0].sum().item()
self.assertEqual(committed, 4)
def test_step_no_editing_by_default(self):
"""Without editing_threshold, no non-mask tokens should be changed."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 10.0 # predict token 15 for all positions
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=None,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index.any().item())
self.assertFalse(out.transfer_index[0, 0].item())
self.assertFalse(out.transfer_index[0, 1].item())
def test_step_editing_replaces_tokens(self):
"""With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
# Token 0: predict 50 (different from 10) with very high logit
logits[0, 0, 15] = 20.0
# Token 1: predict 20 (same as current)
logits[0, 1, 20] = 20.0
# Mask tokens
logits[0, 2, 5] = 5.0
logits[0, 3, 6] = 5.0
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
return_dict=True,
)
# Token 0 should be edited (different prediction, high confidence)
self.assertTrue(out.editing_transfer_index[0, 0].item())
# Token 1 should NOT be edited (same prediction)
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_prompt_mask_prevents_editing(self):
"""Prompt positions should never be edited even with editing enabled."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
logits = torch.zeros(1, 4, vocab_size)
logits[0, :, 15] = 20.0
prompt_mask = torch.tensor([True, True, False, False])
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
editing_threshold=0.5,
prompt_mask=prompt_mask,
return_dict=True,
)
self.assertFalse(out.editing_transfer_index[0, 0].item())
self.assertFalse(out.editing_transfer_index[0, 1].item())
def test_step_return_tuple(self):
"""Verify tuple output when return_dict=False."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
vocab_size = 32
sample = torch.full((1, 4), 31, dtype=torch.long)
logits = torch.randn(1, 4, vocab_size)
result = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=31,
temperature=0.0,
return_dict=False,
)
self.assertIsInstance(result, tuple)
self.assertEqual(len(result), 5)
def test_step_batched(self):
"""Verify step works with batch_size > 1."""
scheduler = self.get_scheduler(block_length=4)
scheduler.set_timesteps(2)
batch_size, vocab_size = 3, 32
mask_id = 31
sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
logits = torch.randn(batch_size, 4, vocab_size)
out = scheduler.step(
model_output=logits,
timestep=0,
sample=sample,
mask_token_id=mask_id,
temperature=0.0,
return_dict=True,
)
self.assertEqual(out.prev_sample.shape, (batch_size, 4))
self.assertEqual(out.transfer_index.shape, (batch_size, 4))
def test_check_block_should_continue_finished(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([True, True])
result = scheduler.check_block_should_continue(
step_idx=0,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_no_masks_no_edits(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=5,
masks_remaining=False,
editing_enabled=True,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=1,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_block_should_continue_steps_exhausted(self):
scheduler = self.get_scheduler()
scheduler.set_timesteps(8)
finished = torch.tensor([False])
result = scheduler.check_block_should_continue(
step_idx=8,
masks_remaining=True,
editing_enabled=False,
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
post_steps=0,
max_post_steps=16,
finished=finished,
)
self.assertFalse(result)
def test_check_eos_finished_marks_batch(self):
"""When EOS is committed and all tokens before it are unmasked, mark batch as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, token, eos, mask, mask]
cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long)
final_transfer = torch.tensor([[False, False, False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertTrue(finished[0].item())
def test_check_eos_finished_ignores_when_masks_before_eos(self):
"""If there are still mask tokens between prompt and EOS, don't mark as finished."""
mask_id, eos_id, prompt_length = 99, 2, 2
# cur_x: [prompt, prompt, mask, eos] — mask before EOS
cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, True]])
finished = torch.tensor([False])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=prompt_length,
)
self.assertFalse(finished[0].item())
def test_check_eos_finished_already_finished(self):
"""Already-finished batches should stay finished."""
mask_id, eos_id = 99, 2
cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long)
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
final_transfer = torch.tensor([[False, False]])
finished = torch.tensor([True])
finished = BlockRefinementScheduler.check_eos_finished(
cur_x=cur_x,
sampled_tokens=sampled_tokens,
final_transfer=final_transfer,
finished=finished,
eos_token_id=eos_id,
mask_token_id=mask_id,
prompt_length=2,
)
self.assertTrue(finished[0].item())
def test_add_noise(self):
scheduler = self.get_scheduler(block_length=4)
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
attention_mask = torch.ones_like(input_ids)
mask_token_id = 99
gen = torch.Generator().manual_seed(42)
noisy, noisy_rev, masked, masked_rev = scheduler.add_noise(
input_ids,
attention_mask,
prompt_length=2,
block_length=4,
mask_token_id=mask_token_id,
generator=gen,
)
# Prompt positions should never be masked
self.assertFalse(masked[0, 0].item())
self.assertFalse(masked[0, 1].item())
self.assertFalse(masked_rev[0, 0].item())
self.assertFalse(masked_rev[0, 1].item())
# Noisy should have mask_token_id where masked is True
self.assertTrue((noisy[masked] == mask_token_id).all().item())
self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item())
# masked and masked_rev should be complementary within valid non-prompt positions
non_prompt = torch.zeros_like(masked)
non_prompt[0, 2:] = True
combined = masked | masked_rev
self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item())
class TestTopPFiltering(unittest.TestCase):
def test_top_p_filtering(self):
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5)
self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any())
self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any())
def test_top_p_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None)
self.assertTrue(torch.equal(result, logits))
def test_top_p_filtering_one(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0)
self.assertTrue(torch.equal(result, logits))
class TestTopKFiltering(unittest.TestCase):
def test_top_k_filtering(self):
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2)
self.assertAlmostEqual(filtered[0, 1].item(), 4.0)
self.assertAlmostEqual(filtered[0, 3].item(), 3.0)
self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min)
self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min)
def test_top_k_filtering_none(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_zero(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0)
self.assertTrue(torch.equal(result, logits))
def test_top_k_filtering_large_k(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100)
self.assertTrue(torch.equal(result, logits))
class TestSampleFromLogits(unittest.TestCase):
def test_greedy_sampling(self):
logits = torch.tensor([[1.0, 5.0, 2.0]])
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 1)
self.assertEqual(tokens.shape, (1,))
self.assertEqual(probs.shape, (1,))
def test_multinomial_sampling(self):
logits = torch.tensor([[0.0, 100.0, -100.0]])
gen = torch.Generator().manual_seed(42)
tokens, probs = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=1.0,
top_k=None,
top_p=None,
generator=gen,
use_multinomial=True,
)
self.assertEqual(tokens.item(), 1)
def test_temperature_scaling(self):
logits = torch.tensor([[1.0, 2.0, 3.0]])
tokens, _ = BlockRefinementScheduler._sample_from_logits(
logits,
temperature=0.01,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
self.assertEqual(tokens.item(), 2)
def test_negative_temperature_raises(self):
logits = torch.tensor([[1.0, 2.0]])
with self.assertRaises(ValueError):
BlockRefinementScheduler._sample_from_logits(
logits,
temperature=-1.0,
top_k=None,
top_p=None,
generator=None,
use_multinomial=False,
)
if __name__ == "__main__":
unittest.main()

View File

@@ -43,7 +43,7 @@ def filter_pipelines(usage_dict, usage_cutoff=10000):
def fetch_pipeline_objects():
models = api.list_models(library="diffusers")
models = api.list_models(filter="diffusers")
downloads = defaultdict(int)
for model in models: