Kashif Rasul 5d207e756e [Discrete Diffusion] Add LLaDA2 pipeline (#13226)
2026-03-25 16:17:50 +05:30

# Discrete Token Diffusion (Experimental)

This folder contains training and sampling examples for discrete diffusion over token IDs (language-model style). The scripts follow the conventions of the diffusers training examples and are launched with Accelerate.

## LLaDA2

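LLaDA2 generates text through block-wise iterative refinement. Instead of generating tokens one at a time autoregressively, it starts from a fully masked sequence and, block by block, progressively unmasks tokens over multiple refinement steps, committing the predictions it is most confident about first. LLaDA2.1 models can additionally revise tokens that were already committed (token editing).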
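The loop below is a minimal, framework-free sketch of block-wise, confidence-based unmasking. It assumes a model callable that returns one row of vocabulary logits per position. At each step it commits every masked token whose top probability reaches `threshold`, and never fewer than the step's share of an even split of the block across `steps`. The names `MASK_ID`, `num_transfer_tokens`, and `generate` are hypothetical. This is not the diffusers `LLaDA2Pipeline` or `BlockRefinementScheduler` implementation, and it omits temperature/top-k/top-p sampling, EOS early stopping, and token editing.

```python
import math
import random

MASK_ID = -1  # hypothetical mask id; real tokenizers define their own mask token


def num_transfer_tokens(block_length, steps):
    """Split block_length commits as evenly as possible across steps."""
    base, rem = divmod(block_length, steps)
    return [base + (1 if i < rem else 0) for i in range(steps)]


def softmax(row):
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def generate(logits_fn, prompt, gen_length, block_length, steps, threshold):
    """Greedy block-wise refinement; logits_fn(seq) returns one logit row per position."""
    seq = list(prompt) + [MASK_ID] * gen_length
    for start in range(len(prompt), len(seq), block_length):
        end = min(start + block_length, len(seq))
        for k in num_transfer_tokens(end - start, steps):
            masked = [i for i in range(start, end) if seq[i] == MASK_ID]
            if not masked:
                break  # this block is already fully committed
            logits = logits_fn(seq)
            candidates = []
            for i in masked:
                probs = softmax(logits[i])
                token = max(range(len(probs)), key=probs.__getitem__)
                candidates.append((probs[token], i, token))
            candidates.sort(reverse=True)  # most confident first
            n_above = sum(1 for conf, _, _ in candidates if conf >= threshold)
            # Commit all confident tokens, but at least the scheduled k for this step.
            n_commit = max(n_above, min(k, len(candidates)))
            for _, i, token in candidates[:n_commit]:
                seq[i] = token
    return seq[len(prompt):]


if __name__ == "__main__":
    random.seed(0)

    def toy_logits(seq):
        # Stand-in for a real model: random logits over a 5-token vocabulary.
        return [[random.gauss(0.0, 1.0) for _ in range(5)] for _ in seq]

    print(generate(toy_logits, prompt=[1, 2], gen_length=10, block_length=4, steps=2, threshold=0.7))
```

Because each block is fully committed before the next one starts, later blocks are conditioned on earlier ones, which is what makes the scheme block-wise rather than fully parallel over the whole sequence.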

### Train

The training script uses a confidence-aware (CAP-style) loss and works with any causal LM on the Hub that loads through `AutoModelForCausalLM` (for example Qwen, Llama, or Mistral):

```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
  --model_name_or_path Qwen/Qwen2.5-0.5B \
  --dataset_name wikitext \
  --dataset_config_name wikitext-2-raw-v1 \
  --text_column text \
  --output_dir llada2-output \
  --max_train_steps 1000 \
  --prompt_length 32 \
  --block_length 32 \
  --lambda_conf 2.0 \
  --conf_temperature 0.5
```
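To show roughly what `--lambda_conf` and `--conf_temperature` control, here is a hypothetical, framework-free sketch of a confidence-aware objective. It assumes the loss is cross-entropy over masked positions plus `lambda_conf` times the entropy of temperature-scaled predictions at masked positions the model already predicts correctly. Both this formula and the name `confidence_aware_loss` are assumptions for illustration; the training script's actual loss may be defined differently.

```python
import math


def log_softmax(row, temperature=1.0):
    """Numerically stable log-softmax of logits scaled by 1 / temperature."""
    scaled = [x / temperature for x in row]
    m = max(scaled)
    lse = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - lse for x in scaled]


def confidence_aware_loss(logits, targets, mask, lambda_conf=2.0, conf_temperature=0.5):
    """Hypothetical CAP-style loss: masked CE + lambda_conf * entropy on correct predictions."""
    ce_terms, conf_terms = [], []
    for row, target, is_masked in zip(logits, targets, mask):
        if not is_masked:
            continue  # only masked (noised) positions are supervised
        ce_terms.append(-log_softmax(row)[target])
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == target:
            # Entropy of the temperature-scaled distribution; minimizing it
            # sharpens predictions the model already gets right.
            logp = log_softmax(row, conf_temperature)
            conf_terms.append(-sum(math.exp(lp) * lp for lp in logp))
    ce = sum(ce_terms) / len(ce_terms) if ce_terms else 0.0
    conf = sum(conf_terms) / len(conf_terms) if conf_terms else 0.0
    return ce + lambda_conf * conf
```

Under this assumed form, raising `lambda_conf` weights the sharpening term more heavily relative to plain cross-entropy, and `conf_temperature` changes how peaked the distribution is when its entropy is measured.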

If you don't want to download a dataset, you can use random-token data:

```bash
accelerate launch examples/discrete_diffusion/train_llada2.py \
  --model_name_or_path Qwen/Qwen2.5-0.5B \
  --output_dir llada2-output \
  --use_dummy_data \
  --num_dummy_samples 2048
```

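### Sample

The command below uses the recommended LLaDA2.1 quality-mode settings (`--threshold 0.7`, `--editing_threshold 0.5`, `--max_post_steps 16`). Token editing is intended for LLaDA2.1 models only, not for LLaDA2.0 models.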

```bash
python examples/discrete_diffusion/sample_llada2.py \
  --model_id inclusionAI/LLaDA2.1-mini \
  --prompt "Write a short poem about the ocean." \
  --gen_length 256 \
  --num_inference_steps 32 \
  --threshold 0.7 \
  --editing_threshold 0.5 \
  --max_post_steps 16 \
  --use_chat_template \
  --add_generation_prompt
```
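Token editing, enabled here through `--editing_threshold`, lets LLaDA2.1 overwrite tokens that were already committed. A minimal sketch of one editing pass, assuming a committed token is replaced when the model's top prediction at that position differs from it and has probability of at least `editing_threshold`, and that prompt positions are never edited. The function `edit_tokens` and its arguments are hypothetical and are not the `BlockRefinementScheduler` API; passing `None` disables editing, mirroring the pipeline's `editing_threshold=None` option for LLaDA2.0 models.

```python
def edit_tokens(tokens, probs_per_pos, editable, editing_threshold):
    """Return a copy of tokens with confident disagreements overwritten."""
    if editing_threshold is None:
        return list(tokens)  # editing disabled
    edited = list(tokens)
    for i, probs in enumerate(probs_per_pos):
        if not editable[i]:
            continue  # never rewrite prompt positions
        best = max(range(len(probs)), key=probs.__getitem__)
        if best != edited[i] and probs[best] >= editing_threshold:
            edited[i] = best
    return edited


probs = [
    [0.1, 0.8, 0.05, 0.05],  # top token 1 already committed: unchanged
    [0.6, 0.2, 0.1, 0.1],    # model prefers 0 with p=0.6 >= 0.5: edited
    [0.4, 0.3, 0.2, 0.1],    # model prefers 0 but p=0.4 < 0.5: kept
]
print(edit_tokens([1, 2, 3], probs, [True, True, True], 0.5))  # [1, 0, 3]
```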