diffusers/examples/discrete_diffusion/sample_llada2.py
Kashif Rasul 5d207e756e [Discrete Diffusion] Add LLaDA2 pipeline (#13226)
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
- BlockRefinementPipeline: block-wise iterative refinement with confidence-based
  token commitment, supporting editing threshold for LLaDA2.1 models
- LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
- DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p,
  temperature) and prompt/prefix helpers
- compute_confidence_aware_loss: CAP-style training loss
- Examples: sampling scripts for LLaDA2 and block refinement, training scripts
  with Qwen causal LM
- Docs and tests included
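Block-wise refinement splits the generated region into fixed-size blocks and refines them one after another. The plain-Python sketch below does that bookkeeping. `block_spans` and `max_forward_passes` are hypothetical helpers, not diffusers APIs. The sketch assumes the generated region starts right after the prompt and that a shorter final block is allowed when `gen_length` is not a multiple of `block_length`.

```python
import math


def block_spans(prompt_length, gen_length, block_length):
    """Hypothetical helper: (start, end) token spans of each generation block.

    Assumes generation starts right after the prompt and that the last block
    may be shorter when gen_length % block_length != 0.
    """
    num_blocks = math.ceil(gen_length / block_length)
    total = prompt_length + gen_length
    spans = []
    for b in range(num_blocks):
        start = prompt_length + b * block_length
        spans.append((start, min(start + block_length, total)))
    return spans


def max_forward_passes(gen_length, block_length, steps, max_post_steps=0):
    """Worst-case model calls if every block runs all refinement steps and all
    post-mask editing steps (no early exits)."""
    num_blocks = math.ceil(gen_length / block_length)
    return num_blocks * (steps + max_post_steps)


# Defaults of sample_llada2.py: gen_length=2048, block_length=32, num_inference_steps=32.
print(len(block_spans(prompt_length=10, gen_length=2048, block_length=32)))  # 64
print(max_forward_passes(2048, 32, 32))  # 2048
```

Treat the count as a worst case under these assumptions. A block that fills its masks early, for example because many tokens clear the confidence threshold in one step, or an EOS early stop, can only reduce it.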

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

Extract the confidence-based token commit logic from BlockRefinementPipeline
into a dedicated BlockRefinementScheduler, following diffusers conventions.

The scheduler owns:
- Transfer schedule computation (get_num_transfer_tokens)
- Timestep management (set_timesteps)
- Step logic: confidence-based mask-filling and optional token editing

The pipeline now delegates scheduling to self.scheduler.step() and accepts
a scheduler parameter in __init__.
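The scheduler's two core jobs can be illustrated with a plain-Python sketch: an even transfer schedule and a commit-by-confidence step. It follows the general idea described above. The masked tokens are spread over the steps, the most confident predictions are committed first, and anything above `threshold` is committed early. The function names and the exact tie-breaking are assumptions, not the `BlockRefinementScheduler` implementation.

```python
def num_transfer_tokens(num_masked, steps):
    """Split num_masked tokens as evenly as possible across steps.

    Earlier steps receive the remainder, e.g. 10 tokens over 4 steps -> [3, 3, 2, 2].
    """
    base, remainder = divmod(num_masked, steps)
    return [base + 1 if i < remainder else base for i in range(steps)]


def commit_by_confidence(confidences, is_masked, quota, threshold=None):
    """Return the sorted positions to unmask at this step.

    The `quota` most confident masked positions are always committed, so every
    step makes progress; any other masked position whose confidence reaches
    `threshold` is committed as well. Unmasked positions are never touched.
    """
    masked = [i for i, m in enumerate(is_masked) if m]
    ranked = sorted(masked, key=lambda i: confidences[i], reverse=True)
    chosen = set(ranked[:quota])
    if threshold is not None:
        chosen.update(i for i in ranked if confidences[i] >= threshold)
    return sorted(chosen)


print(num_transfer_tokens(10, 4))  # [3, 3, 2, 2]
print(commit_by_confidence([0.9, 0.2, 0.8, 0.99], [True, True, True, False], quota=1, threshold=0.7))  # [0, 2]
```

In the second call, position 3 has the highest confidence but is skipped because it is already committed. Position 2 is committed ahead of schedule because it clears the threshold.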

* test: add unit tests for BlockRefinementScheduler

12 tests covering set_timesteps, get_num_transfer_tokens, step logic
(confidence-based commits, threshold behavior, editing, prompt masking,
batched inputs, tuple output).

* docs: add toctree entries and standalone scheduler doc page

- Add BlockRefinement and LLaDA2 to docs sidebar navigation
- Add BlockRefinementScheduler to schedulers sidebar navigation
- Move scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py

- Add --revision argument for loading model revisions from the Hub
- Replace deprecated torch_dtype with dtype for transformers 5.x compat

* fix: use 1/0 attention mask instead of 0/-inf for LLaDA2 compat

LLaDA2 models expect a boolean-style (1/0) attention mask, not an
additive (0/-inf) mask. The model internally converts to additive,
so passing 0/-inf caused double-masking and gibberish output.
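The double-masking failure can be reproduced in a few lines. The sketch assumes the model converts a 1/0 mask into an additive bias with `(1 - mask) * MIN`, a common convention in Hugging Face modeling code. `to_additive` and `attention_weights` are hypothetical helpers, and the arithmetic uses plain Python floats rather than the model's bf16 tensors.

```python
import math
import sys

# Assumed conversion constant: the most negative finite float.
MIN = -sys.float_info.max


def to_additive(mask):
    """Hypothetical model-side conversion of a 1/0 keep-mask into a 0/MIN bias."""
    return [(1.0 - m) * MIN for m in mask]


def attention_weights(scores, bias):
    """Softmax over attention scores after adding the additive bias."""
    logits = [s + b for s, b in zip(scores, bias)]
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 3.0]  # the third position is padding

# Correct: pass a 1/0 mask and let the model convert it once.
print([round(w, 3) for w in attention_weights(scores, to_additive([1, 1, 0]))])  # [0.731, 0.269, 0.0]

# Wrong: pass an already-additive 0/-inf mask, so it gets converted a second time.
print(attention_weights(scores, to_additive([0.0, 0.0, float("-inf")])))  # [0.5, 0.5, 0.0]
```

In the second case, the positions that should be attended receive a bias of `MIN`, which swamps their scores. Every valid position ends up with the same logit, so attention no longer depends on the content.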

* refactor: consolidate training scripts into single train_block_refinement.py

- Remove toy train_block_refinement_cap.py (self-contained demo with tiny model)
- Rename train_block_refinement_qwen_cap.py to train_block_refinement.py
  (already works with any causal LM via AutoModelForCausalLM)
- Fix torch_dtype deprecation and update README with correct script names

* fix formatting

* docs: improve LLaDA2 and BlockRefinement documentation

- Add usage examples with real model IDs and working code
- Add recommended parameters table for LLaDA2.1 quality/speed modes
- Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
- Remove misleading config defaults section from BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to recommended model parameters

- threshold: 0.95 -> 0.7 (quality mode)
- max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
- eos_early_stop: False -> True (stop at EOS token)

block_length=32, steps=32, temperature=0.0 were already correct.
editing_threshold remains None (users enable for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

LLaDA2.1 is the current generation. Users with LLaDA2.0 models can
disable editing by passing editing_threshold=None.
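The editing pass controlled by `editing_threshold` and `max_post_steps` can be sketched as follows. This reading comes from the script's help text, not from the scheduler source. Once a block has no masks left, each extra pass replaces a committed, non-prompt token whenever the model predicts a different token with confidence at or above `editing_threshold`. The loop stops early once a pass changes nothing. `post_mask_edit` and the `predict` callable are hypothetical.

```python
def post_mask_edit(tokens, is_prompt, predict, editing_threshold, max_post_steps):
    """Hypothetical post-mask editing loop for one fully unmasked block.

    `predict(tokens)` is assumed to return (predicted_ids, confidences), one per position.
    Returns the edited tokens and the number of model calls made.
    """
    tokens = list(tokens)  # do not mutate the caller's sequence
    if editing_threshold is None:  # editing disabled (e.g. LLaDA2.0 models)
        return tokens, 0
    passes = 0
    for _ in range(max_post_steps):
        predicted, confidences = predict(tokens)
        passes += 1
        changed = [
            i
            for i, (pred, conf) in enumerate(zip(predicted, confidences))
            if not is_prompt[i] and pred != tokens[i] and conf >= editing_threshold
        ]
        for i in changed:
            tokens[i] = predicted[i]
        if not changed:  # converged: nothing left worth editing
            break
    return tokens, passes


def fake_predict(seq):
    # Toy model: always prefers token 7 at position 2, with confidence 0.6.
    return [9, 2, 7, 4], [0.99, 0.9, 0.6, 0.4]


print(post_mask_edit([1, 2, 3, 4], [True, False, False, False], fake_predict, 0.5, 16))  # ([1, 2, 7, 4], 2)
```

Position 0 is never edited because it belongs to the prompt. Raising `editing_threshold` to 0.7 would leave the block unchanged after one pass.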

* fix: align sampling utilities with official LLaDA2 implementation

- top_p filtering: add shift-right to preserve at least one token above
  threshold (matches official code line 1210)
- temperature ordering: apply scaling before top-k/top-p filtering so
  filtering operates on scaled logits (matches official code lines 1232-1235)
- greedy branch: return argmax directly when temperature=0 without
  filtering (matches official code lines 1226-1230)
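The three sampling fixes can be illustrated with a small plain-Python sampler. Temperature scaling happens before top-k/top-p filtering. Top-p uses the shift-right trick, so the token that first crosses the threshold is kept. Temperature 0 short-circuits to argmax. This sketch is written from the bullet points above rather than from the diffusers code, and it works on a single list of logits instead of batched tensors. All function names are hypothetical.

```python
import math
import random

NEG_INF = float("-inf")


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def filter_logits(logits, temperature=1.0, top_k=None, top_p=None):
    """Scale by temperature first, then apply top-k and top-p filtering."""
    scaled = [x / temperature for x in logits]
    if top_k is not None:
        # Keep values >= the k-th largest (ties at the k-th value survive).
        kth_largest = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [x if x >= kth_largest else NEG_INF for x in scaled]
    if top_p is not None:
        order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)
        sorted_probs = softmax([scaled[i] for i in order])
        cumulative, remove = 0.0, []
        for p in sorted_probs:
            cumulative += p
            remove.append(cumulative > top_p)
        # Shift right by one: the token that first crosses top_p is kept,
        # so at least one token always survives.
        remove = [False] + remove[:-1]
        for i, drop in zip(order, remove):
            if drop:
                scaled[i] = NEG_INF
    return scaled


def sample_token(logits, temperature=0.0, top_k=None, top_p=None, rng=random):
    """Greedy argmax when temperature == 0 (no filtering), else sample the filtered logits."""
    if temperature == 0.0:
        return max(range(len(logits)), key=lambda i: logits[i])
    probs = softmax(filter_logits(logits, temperature, top_k, top_p))
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
print(sample_token(logits))  # 0 (greedy)
print(filter_logits(logits, temperature=1.0, top_p=0.5))  # [2.0, -inf, -inf, -inf]
```

Without the shift, `top_p=0.5` would remove every token here, because the top token alone already exceeds 0.5. With `temperature=2.0` the same logits keep two candidates, which is why the order of scaling and filtering matters.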

* refactor: remove duplicate prompt encoding, reuse mixin's _prepare_input_ids

LLaDA2Pipeline._prepare_prompt_ids was a near-copy of
DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate
and call the mixin method directly. Also simplify _extract_input_ids
since we always pass return_dict=True.

* formatting

* fix: replace deprecated torch_dtype with dtype in examples and docstrings

- Update EXAMPLE_DOC_STRING to use dtype= and LLaDA2.1-mini model ID
- Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline

* cleanup

* fix readme

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: YiYi Xu <yixu310@gmail.com>

* removed DiscreteDiffusionPipelineMixin

* add support for 2d masks for flash attn

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/training_utils.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix issues from review

* added tests

* formatting

* add check_eos_finished to scheduler

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/schedulers/scheduling_block_refinement.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* fix renaming issues and types

* remove duplicate check

* Update docs/source/en/api/pipelines/llada2.md

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py

Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
2026-03-25 16:17:50 +05:30

#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Sample script for LLaDA2-style discrete diffusion text generation.
This script demonstrates how to use the LLaDA2Pipeline for text generation
using block-wise iterative refinement.
Example usage:
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
"""

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
from diffusers.hooks import apply_group_offloading


def main():
    parser = argparse.ArgumentParser(
        description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="inclusionAI/LLaDA2.0-mini",
        help="Hugging Face Hub model ID or path to a local model.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Why does Camus think that Sisyphus is happy?",
        help="Text prompt to generate from.",
    )
    parser.add_argument(
        "--gen_length",
        type=int,
        default=2048,
        help="Number of tokens to generate.",
    )
    parser.add_argument(
        "--block_length",
        type=int,
        default=32,
        help="Size of each generation block.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=32,
        help="Number of refinement steps per block.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature (0.0 for greedy).",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=None,
        help="Nucleus sampling probability threshold.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=None,
        help="Top-k sampling parameter.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.95,
        help="Confidence threshold for committing tokens.",
    )
    parser.add_argument(
        "--editing_threshold",
        type=float,
        default=None,
        help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
    )
    parser.add_argument(
        "--max_post_steps",
        type=int,
        default=0,
        help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
    )
    parser.add_argument(
        "--sampling_method",
        type=str,
        default="multinomial",
        choices=["auto", "greedy", "multinomial"],
        help="Sampling method for block refinement.",
    )
    parser.add_argument(
        "--eos_early_stop",
        action="store_true",
        help="Stop generation early when the EOS token is generated.",
    )
    parser.add_argument(
        "--use_chat_template",
        action="store_true",
        help="Use the tokenizer chat template for the prompt.",
    )
    parser.add_argument(
        "--add_generation_prompt",
        action="store_true",
        help="Add the generation prompt when using the chat template.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run inference on.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model dtype.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--offload",
        type=str,
        default=None,
        choices=["group", "sequential"],
        help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision (branch, tag, or commit hash) to load from the Hub.",
    )
    args = parser.parse_args()

    # Parse dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading model: {args.model_id}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)

    # Load model with appropriate memory settings based on offload strategy
    if args.offload == "group":
        # For group offloading, load to CPU first then apply hooks
        print("Using group offloading for memory efficiency...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
        # Apply group offloading with CUDA streams for better performance
        onload_device = torch.device(args.device)
        offload_device = torch.device("cpu")
        apply_group_offloading(
            model,
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=True,
        )
    elif args.offload == "sequential":
        # For sequential offloading, load to CPU first
        print("Using sequential CPU offloading (slower but lower memory)...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
        # Sequential offloading will be applied via pipeline
    else:
        # Default: use device_map="auto" for automatic memory management.
        # Placement is decided by device_map here; --device is not used for the model.
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
    model.eval()

    # Create pipeline
    scheduler = BlockRefinementScheduler()
    pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)

    # Apply sequential CPU offload if requested
    if args.offload == "sequential":
        pipe.enable_sequential_cpu_offload()

    # Set up generator for reproducibility
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=args.device).manual_seed(args.seed)

    print(f"\nPrompt: {args.prompt}")
    print(
        f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
    )
    print("-" * 50)

    # Generate
    output = pipe(
        prompt=args.prompt,
        use_chat_template=args.use_chat_template,
        add_generation_prompt=args.add_generation_prompt,
        gen_length=args.gen_length,
        block_length=args.block_length,
        num_inference_steps=args.num_inference_steps,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        threshold=args.threshold,
        editing_threshold=args.editing_threshold,
        max_post_steps=args.max_post_steps,
        sampling_method=args.sampling_method,
        eos_early_stop=args.eos_early_stop,
        generator=generator,
    )

    print("\nGenerated text:")
    print(output.texts[0])
    print(f"\nGenerated {output.sequences.shape[1]} tokens")


if __name__ == "__main__":
    main()