mirror of https://github.com/huggingface/diffusers.git (synced 2026-04-11 02:01:57 +08:00)
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

  Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
  - BlockRefinementPipeline: block-wise iterative refinement with confidence-based token commitment, supporting an editing threshold for LLaDA2.1 models
  - LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
  - DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p, temperature) and prompt/prefix helpers
  - compute_confidence_aware_loss: CAP-style training loss
  - Examples: sampling scripts for LLaDA2 and block refinement, training scripts with a Qwen causal LM
  - Docs and tests included

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

  Extract the confidence-based token commit logic from BlockRefinementPipeline into a dedicated BlockRefinementScheduler, following diffusers conventions. The scheduler owns:
  - Transfer schedule computation (get_num_transfer_tokens)
  - Timestep management (set_timesteps)
  - Step logic: confidence-based mask-filling and optional token editing

  The pipeline now delegates scheduling to self.scheduler.step() and accepts a scheduler parameter in __init__.

* test: add unit tests for BlockRefinementScheduler

  12 tests covering set_timesteps, get_num_transfer_tokens, and step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output).
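The commit-by-confidence idea behind the scheduler can be sketched in plain Python. This is an illustrative toy, not the diffusers implementation: `commit_by_confidence` and `min_commits` are hypothetical names, and the real scheduler computes a per-step transfer schedule via `get_num_transfer_tokens` rather than a fixed floor.

```python
def commit_by_confidence(confidences, committed, threshold, min_commits=1):
    """Commit masked positions whose confidence exceeds `threshold`.

    Always commits at least `min_commits` of the most confident masked
    positions so every refinement step makes progress. Returns the
    indices committed this step, in ascending order.
    """
    # Positions not yet committed are still "masked".
    masked = [i for i, done in enumerate(committed) if not done]
    if not masked:
        return []
    # Positions above the confidence threshold commit immediately.
    chosen = [i for i in masked if confidences[i] >= threshold]
    if len(chosen) < min_commits:
        # Otherwise force-commit the most confident masked positions.
        chosen = sorted(masked, key=lambda i: confidences[i], reverse=True)[:min_commits]
    for i in chosen:
        committed[i] = True
    return sorted(chosen)
```

Run over a 3-token block with threshold 0.9: the first step commits only the 0.99-confidence position; the second step finds nothing above threshold and force-commits the best remaining one, which is how the schedule guarantees termination.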
* docs: add toctree entries and standalone scheduler doc page

  - Add BlockRefinement and LLaDA2 to the docs sidebar navigation
  - Add BlockRefinementScheduler to the schedulers sidebar navigation
  - Move the scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py

  - Add a --revision argument for loading model revisions from the Hub
  - Replace deprecated torch_dtype with dtype for transformers 5.x compat

* fix: use 1/0 attention mask instead of 0/-inf for LLaDA2 compat

  LLaDA2 models expect a boolean-style (1/0) attention mask, not an additive (0/-inf) mask. The model internally converts to additive, so passing 0/-inf caused double-masking and gibberish output.

* refactor: consolidate training scripts into single train_block_refinement.py

  - Remove the toy train_block_refinement_cap.py (self-contained demo with a tiny model)
  - Rename train_block_refinement_qwen_cap.py to train_block_refinement.py (it already works with any causal LM via AutoModelForCausalLM)
  - Fix the torch_dtype deprecation and update the README with the correct script names

* fix formatting

* docs: improve LLaDA2 and BlockRefinement documentation

  - Add usage examples with real model IDs and working code
  - Add a recommended-parameters table for LLaDA2.1 quality/speed modes
  - Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
  - Remove the misleading config-defaults section from the BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to recommended model parameters

  - threshold: 0.95 -> 0.7 (quality mode)
  - max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
  - eos_early_stop: False -> True (stop at EOS token)

  block_length=32, steps=32, temperature=0.0 were already correct. editing_threshold remains None (users enable it for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

  LLaDA2.1 is the current generation. Users with LLaDA2.0 models can disable editing by passing editing_threshold=None.
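The double-masking failure described in the attention-mask fix can be reproduced with a toy version of the model-internal conversion (a sketch, not diffusers code; `to_additive` is a hypothetical helper standing in for the model's own mask handling):

```python
import math


def to_additive(mask):
    """Convert a 1/0 attention mask to additive form, mimicking what the
    model does internally: 1 -> 0.0 (attend), anything else -> -inf (ignore)."""
    return [0.0 if m == 1 else -math.inf for m in mask]


bool_mask = [1, 1, 0]                        # what LLaDA2 expects from the caller
once = to_additive(bool_mask)                # [0.0, 0.0, -inf]: correct
twice = to_additive(to_additive(bool_mask))  # all -inf: every position masked
```

Passing an already-additive mask means every `0.0` fails the `== 1` check and becomes `-inf`, so the model attends to nothing, which is why the output degraded to gibberish.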
* fix: align sampling utilities with the official LLaDA2 implementation

  - top_p filtering: add a shift-right to preserve at least one token above the threshold (matches official code line 1210)
  - temperature ordering: apply scaling before top-k/top-p filtering so filtering operates on scaled logits (matches official code lines 1232-1235)
  - greedy branch: return argmax directly when temperature=0, without filtering (matches official code lines 1226-1230)

* refactor: remove duplicate prompt encoding, reuse the mixin's _prepare_input_ids

  LLaDA2Pipeline._prepare_prompt_ids was a near-copy of DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate and call the mixin method directly. Also simplify _extract_input_ids since we always pass return_dict=True.

* formatting

* fix: replace deprecated torch_dtype with dtype in examples and docstrings

  - Update EXAMPLE_DOC_STRING to use dtype= and the LLaDA2.1-mini model ID
  - Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline

* cleanup

* fix readme

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py (Co-authored-by: YiYi Xu <yixu310@gmail.com>)

* removed DiscreteDiffusionPipelineMixin

* add support for 2d masks for flash attn

* Update src/diffusers/training_utils.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

* fix issues from review

* added tests

* formatting

* add check_eos_finished to scheduler

* Update src/diffusers/pipelines/llada2/pipeline_llada2.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

* Update src/diffusers/schedulers/scheduling_block_refinement.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

* fix renaming issues and types

* remove duplicate check

* Update docs/source/en/api/pipelines/llada2.md (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
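The three sampling rules from the alignment fix above (greedy at temperature 0, temperature scaling before top-k/top-p, and shift-right top-p so at least one token survives filtering) can be sketched in plain Python. This is an illustrative reimplementation, not the diffusers code; `sample_token` is a hypothetical helper operating on a plain list of logits.

```python
import math
import random


def sample_token(logits, temperature=1.0, top_k=None, top_p=None):
    """Sample one token index, following the LLaDA2-style ordering."""
    if temperature == 0.0:
        # Greedy branch: argmax directly, with no filtering at all.
        return max(range(len(logits)), key=lambda i: logits[i])

    # Temperature scaling happens BEFORE top-k/top-p filtering.
    scaled = [v / temperature for v in logits]
    order = sorted(range(len(scaled)), key=lambda i: scaled[i], reverse=True)

    keep = set(order[:top_k]) if top_k else set(order)

    if top_p is not None:
        # Shift-right nucleus filtering: add each candidate to the kept set
        # BEFORE testing the cumulative mass, so the most likely token is
        # always retained even when top_p is tiny.
        exps = [math.exp(scaled[i]) for i in order]
        total = sum(exps)
        cum, nucleus = 0.0, set()
        for i, e in zip(order, exps):
            nucleus.add(i)
            cum += e / total
            if cum > top_p:
                break
        keep &= nucleus

    # Multinomial draw over the surviving tokens.
    probs = [math.exp(scaled[i]) if i in keep else 0.0 for i in range(len(scaled))]
    r = random.random() * sum(probs)
    acc = 0.0
    for i, p in enumerate(probs):
        acc += p
        if r < acc:
            return i
    return len(logits) - 1
```

With `top_p` near zero or `top_k=1` the draw collapses to the single most likely token, which makes the shift-right behavior easy to verify: without it, an aggressive `top_p` could filter out every candidate.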
264 lines
8.2 KiB
Python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample script for LLaDA2-style discrete diffusion text generation.

This script demonstrates how to use the LLaDA2Pipeline for text generation
using block-wise iterative refinement.

Example usage:
    python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
    python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
"""

import argparse

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
from diffusers.hooks import apply_group_offloading


def main():
    parser = argparse.ArgumentParser(
        description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="inclusionAI/LLaDA2.0-mini",
        help="HuggingFace model ID or path to local model.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Why does Camus think that Sisyphus is happy?",
        help="Text prompt to generate from.",
    )
    parser.add_argument(
        "--gen_length",
        type=int,
        default=2048,
        help="Number of tokens to generate.",
    )
    parser.add_argument(
        "--block_length",
        type=int,
        default=32,
        help="Size of each generation block.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=32,
        help="Number of refinement steps per block.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature (0.0 for greedy).",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=None,
        help="Nucleus sampling probability threshold.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=None,
        help="Top-k sampling parameter.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.95,
        help="Confidence threshold for committing tokens.",
    )
    parser.add_argument(
        "--editing_threshold",
        type=float,
        default=None,
        help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
    )
    parser.add_argument(
        "--max_post_steps",
        type=int,
        default=0,
        help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
    )
    parser.add_argument(
        "--sampling_method",
        type=str,
        default="multinomial",
        choices=["auto", "greedy", "multinomial"],
        help="Sampling method for block refinement.",
    )
    parser.add_argument(
        "--eos_early_stop",
        action="store_true",
        help="Stop generation early when the EOS token is generated.",
    )
    parser.add_argument(
        "--use_chat_template",
        action="store_true",
        help="Use the tokenizer chat template for the prompt.",
    )
    parser.add_argument(
        "--add_generation_prompt",
        action="store_true",
        help="Add the generation prompt when using the chat template.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run inference on.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model dtype.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--offload",
        type=str,
        default=None,
        choices=["group", "sequential"],
        help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision (branch, tag, or commit hash) to load from the Hub.",
    )

    args = parser.parse_args()

    # Map the dtype string to the corresponding torch dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading model: {args.model_id}")
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)

    # Load the model with memory settings appropriate to the offload strategy
    if args.offload == "group":
        # For group offloading, load to CPU first, then apply hooks
        print("Using group offloading for memory efficiency...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
        # Apply group offloading with CUDA streams for better performance
        onload_device = torch.device(args.device)
        offload_device = torch.device("cpu")
        apply_group_offloading(
            model,
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=True,
        )
    elif args.offload == "sequential":
        # For sequential offloading, load to CPU first
        print("Using sequential CPU offloading (slower but lower memory)...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
        # Sequential offloading is applied via the pipeline below
    else:
        # Default: use device_map="auto" for automatic memory management
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
    model.eval()

    # Create the pipeline
    scheduler = BlockRefinementScheduler()
    pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)

    # Apply sequential CPU offload if requested
    if args.offload == "sequential":
        pipe.enable_sequential_cpu_offload()

    # Set up a generator for reproducibility
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=args.device).manual_seed(args.seed)

    print(f"\nPrompt: {args.prompt}")
    print(
        f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
    )
    print("-" * 50)

    # Generate
    output = pipe(
        prompt=args.prompt,
        use_chat_template=args.use_chat_template,
        add_generation_prompt=args.add_generation_prompt,
        gen_length=args.gen_length,
        block_length=args.block_length,
        num_inference_steps=args.num_inference_steps,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        threshold=args.threshold,
        editing_threshold=args.editing_threshold,
        max_post_steps=args.max_post_steps,
        sampling_method=args.sampling_method,
        eos_early_stop=args.eos_early_stop,
        generator=generator,
    )

    print("\nGenerated text:")
    print(output.texts[0])

    print(f"\nGenerated {output.sequences.shape[1]} tokens")


if __name__ == "__main__":
    main()