Mirror of https://github.com/huggingface/diffusers.git (synced 2026-04-22 23:51:29 +08:00)
* feat: add LLaDA2 and BlockRefinement pipelines for discrete text diffusion

  Add support for LLaDA2/LLaDA2.1 discrete diffusion text generation:
  - BlockRefinementPipeline: block-wise iterative refinement with confidence-based token commitment, supporting an editing threshold for LLaDA2.1 models
  - LLaDA2Pipeline: convenience wrapper with LLaDA2-specific defaults
  - DiscreteDiffusionPipelineMixin: shared SAR sampling utilities (top-k, top-p, temperature) and prompt/prefix helpers
  - compute_confidence_aware_loss: CAP-style training loss
  - Examples: sampling scripts for LLaDA2 and block refinement, training scripts with a Qwen causal LM
  - Docs and tests included

* feat: add BlockRefinementScheduler for commit-by-confidence scheduling

  Extract the confidence-based token-commit logic from BlockRefinementPipeline into a dedicated BlockRefinementScheduler, following diffusers conventions. The scheduler owns:
  - transfer-schedule computation (get_num_transfer_tokens)
  - timestep management (set_timesteps)
  - step logic: confidence-based mask filling and optional token editing

  The pipeline now delegates scheduling to self.scheduler.step() and accepts a scheduler parameter in __init__.

* test: add unit tests for BlockRefinementScheduler

  12 tests covering set_timesteps, get_num_transfer_tokens, and step logic (confidence-based commits, threshold behavior, editing, prompt masking, batched inputs, tuple output).

* docs: add toctree entries and a standalone scheduler doc page
  - Add BlockRefinement and LLaDA2 to the docs sidebar navigation
  - Add BlockRefinementScheduler to the schedulers sidebar navigation
  - Move the scheduler autodoc to its own page under api/schedulers/

* feat: add --revision flag and fix dtype deprecation in sample_llada2.py
  - Add a --revision argument for loading model revisions from the Hub
  - Replace the deprecated torch_dtype with dtype for transformers 5.x compatibility

* fix: use a 1/0 attention mask instead of 0/-inf for LLaDA2 compatibility

  LLaDA2 models expect a boolean-style (1/0) attention mask, not an additive (0/-inf) mask. The model converts the mask to additive form internally, so passing 0/-inf caused double masking and gibberish output.

* refactor: consolidate training scripts into a single train_block_refinement.py
  - Remove the toy train_block_refinement_cap.py (a self-contained demo with a tiny model)
  - Rename train_block_refinement_qwen_cap.py to train_block_refinement.py (it already works with any causal LM via AutoModelForCausalLM)
  - Fix the torch_dtype deprecation and update the README with the correct script names

* fix formatting

* docs: improve LLaDA2 and BlockRefinement documentation
  - Add usage examples with real model IDs and working code
  - Add a recommended-parameters table for LLaDA2.1 quality/speed modes
  - Note that editing is LLaDA2.1-only (not for LLaDA2.0 models)
  - Remove the misleading config-defaults section from the BlockRefinement docs

* feat: set LLaDA2Pipeline defaults to the recommended model parameters
  - threshold: 0.95 -> 0.7 (quality mode)
  - max_post_steps: 0 -> 16 (recommended for LLaDA2.1, harmless for 2.0)
  - eos_early_stop: False -> True (stop at the EOS token)

  block_length=32, steps=32, and temperature=0.0 were already correct; editing_threshold remains None (users enable it for LLaDA2.1 models).

* feat: default editing_threshold=0.5 for LLaDA2.1 quality mode

  LLaDA2.1 is the current generation. Users with LLaDA2.0 models can disable editing by passing editing_threshold=None.

* fix: align sampling utilities with the official LLaDA2 implementation
  - top_p filtering: add a shift-right so at least one token above the threshold is preserved (matches official code line 1210)
  - temperature ordering: apply scaling before top-k/top-p filtering so filtering operates on scaled logits (matches official code lines 1232-1235)
  - greedy branch: return argmax directly when temperature=0, without filtering (matches official code lines 1226-1230)

* refactor: remove duplicate prompt encoding, reuse the mixin's _prepare_input_ids

  LLaDA2Pipeline._prepare_prompt_ids was a near-copy of DiscreteDiffusionPipelineMixin._prepare_input_ids. Remove the duplicate and call the mixin method directly. Also simplify _extract_input_ids, since we always pass return_dict=True.

* formatting

* fix: replace deprecated torch_dtype with dtype in examples and docstrings
  - Update EXAMPLE_DOC_STRING to use dtype= and the LLaDA2.1-mini model ID
  - Fix sample_block_refinement.py to use dtype=

* remove BlockRefinementPipeline
* cleanup
* fix readme
* Apply review suggestions to src/diffusers/pipelines/llada2/pipeline_llada2.py (series of commits, Co-authored-by: YiYi Xu <yixu310@gmail.com>)
* removed DiscreteDiffusionPipelineMixin
* add support for 2d masks for flash attn
* Apply review suggestions to src/diffusers/training_utils.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* fix issues from review
* added tests
* formatting
* add check_eos_finished to scheduler
* Apply review suggestions to src/diffusers/pipelines/llada2/pipeline_llada2.py and src/diffusers/schedulers/scheduling_block_refinement.py (series of commits, Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)
* fix renaming issues and types
* remove duplicate check
* Apply review suggestions to docs/source/en/api/pipelines/llada2.md and src/diffusers/pipelines/llada2/pipeline_llada2.py (Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>)

---------
Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: dg845 <58458699+dg845@users.noreply.github.com>
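For orientation, a minimal inference sketch of the LLaDA2Pipeline added by this PR is shown below. It is an illustration under assumptions, not code from the PR: the model ID, the from_pretrained/__call__ signature, and the keyword names are inferred from the parameter names in the commit log and should be checked against the merged API and docs.

# Hypothetical usage sketch; the names below are assumptions inferred from the commit log above.
import torch

from diffusers import LLaDA2Pipeline

pipe = LLaDA2Pipeline.from_pretrained("inclusionAI/LLaDA2.1-mini", dtype=torch.bfloat16)  # illustrative model ID
pipe.to("cuda")  # if a GPU is available

# Defaults described in the commit log: block_length=32, steps=32, temperature=0.0,
# threshold=0.7, max_post_steps=16, eos_early_stop=True, editing_threshold=0.5 (LLaDA2.1 only).
output = pipe(
    "Explain block-wise discrete diffusion in two sentences.",
    block_length=32,
    steps=32,
    threshold=0.7,
    editing_threshold=0.5,  # pass None when using a LLaDA2.0 checkpoint
)
print(output)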
322 lines
12 KiB
Python
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import math
import os
from dataclasses import asdict, dataclass
from typing import Dict, Optional

import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler

from diffusers import BlockRefinementScheduler
from diffusers.training_utils import compute_confidence_aware_loss


logger = get_logger(__name__)


@dataclass
class TrainConfig:
    # Model and data
    model_name_or_path: str
    dataset_name: str
    dataset_config_name: Optional[str]
    text_column: str
    cache_dir: Optional[str]
    use_dummy_data: bool
    num_dummy_samples: int

    # Run management
    output_dir: str
    seed: int
    max_train_steps: int
    checkpointing_steps: int
    logging_steps: int

    # Optimization
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    weight_decay: float
    lr_scheduler: str
    lr_warmup_steps: int

    # Sequence layout for block refinement
    max_length: int
    prompt_length: int
    block_length: int

    # Confidence-aware (CAP) loss
    lambda_conf: float
    conf_temperature: float


def parse_args() -> TrainConfig:
    parser = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")

    parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
    parser.add_argument("--dataset_name", type=str, default="wikitext")
    parser.add_argument("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
    parser.add_argument("--text_column", type=str, default="text")
    parser.add_argument("--cache_dir", type=str, default=None)
    parser.add_argument("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
    parser.add_argument("--num_dummy_samples", type=int, default=2048)

    parser.add_argument("--output_dir", type=str, default="block-refinement-output")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--max_train_steps", type=int, default=1000)
    parser.add_argument("--checkpointing_steps", type=int, default=500)
    parser.add_argument("--logging_steps", type=int, default=50)

    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--weight_decay", type=float, default=0.0)
    parser.add_argument(
        "--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"]
    )
    parser.add_argument("--lr_warmup_steps", type=int, default=100)

    parser.add_argument("--max_length", type=int, default=256)
    parser.add_argument("--prompt_length", type=int, default=32)
    parser.add_argument("--block_length", type=int, default=32)

    parser.add_argument("--lambda_conf", type=float, default=2.0)
    parser.add_argument("--conf_temperature", type=float, default=0.5)

    args = parser.parse_args()
    return TrainConfig(**vars(args))


def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
    texts = examples[text_column]
    # Drop empty or whitespace-only rows before tokenizing.
    texts = [t for t in texts if isinstance(t, str) and len(t.strip()) > 0]
    return tokenizer(texts, truncation=True, padding=False, max_length=max_length)


class RandomTokenDataset(torch.utils.data.Dataset):
    """Synthetic random-token dataset used with --use_dummy_data to smoke-test the training loop offline."""

    def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
        self.num_samples = int(num_samples)
        self.seq_len = int(seq_len)
        self.vocab_size = int(vocab_size)
        self.pad_token_id = int(pad_token_id)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        del idx
        input_ids = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
        attention_mask = torch.ones_like(input_ids)
        return {"input_ids": input_ids, "attention_mask": attention_mask}


def main():
    cfg = parse_args()
    if cfg.prompt_length >= cfg.max_length:
        raise ValueError("`prompt_length` must be < `max_length`.")
    if cfg.block_length <= 0:
        raise ValueError("`block_length` must be > 0.")

    project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        project_config=project_config,
    )
    if accelerator.is_main_process:
        os.makedirs(cfg.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    set_seed(cfg.seed)
    logger.info("Training configuration: %s", asdict(cfg))

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Block refinement replaces response tokens with a [MASK] id; add one if the tokenizer lacks it.
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
    # Resize embeddings in case a mask (or pad) token was just added to the tokenizer.
    model.resize_token_embeddings(len(tokenizer))
    if load_dtype == torch.float32:
        model.to(dtype=torch.float32)

    mask_token_id = int(tokenizer.mask_token_id)

    if cfg.use_dummy_data:
        dataset = RandomTokenDataset(
            num_samples=cfg.num_dummy_samples,
            seq_len=cfg.max_length,
            vocab_size=len(tokenizer),
            pad_token_id=int(tokenizer.pad_token_id),
        )
        train_dataloader = DataLoader(
            dataset,
            shuffle=True,
            batch_size=cfg.per_device_train_batch_size,
            drop_last=True,
        )
    else:
        raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
        if "train" not in raw_datasets:
            raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")

        with accelerator.main_process_first():
            tokenized = raw_datasets["train"].map(
                lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
                batched=True,
                remove_columns=raw_datasets["train"].column_names,
                desc="Tokenizing",
            )

        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
        train_dataloader = DataLoader(
            tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
    num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_warmup_steps,
        num_training_steps=cfg.max_train_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )
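
    # At train time the BlockRefinementScheduler is reused as the noising op: its `add_noise` call below
    # masks tokens after the prompt (block-wise, per `block_length`) that the model must reconstruct.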
    noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)

    global_step = 0
    model.train()

    for _epoch in range(num_train_epochs):
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))
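
                # Deterministic, step-dependent noising: the scheduler returns two noised copies of the
                # batch (`noisy`, `noisy_rev`) plus boolean masks marking which positions were masked out
                # and therefore contribute to the loss.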
                gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
                noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
                    input_ids,
                    attention_mask,
                    prompt_length=cfg.prompt_length,
                    block_length=cfg.block_length,
                    mask_token_id=mask_token_id,
                    generator=gen,
                )

                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
                )
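
                # One forward pass per noised view; both share the same (1/0) attention mask and explicit
                # position ids.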
                logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
                logits_rev = model(
                    input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
                ).logits
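
                # Never let the model predict the mask token itself: push its logit to the minimum
                # representable value before computing the loss.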
                logits = logits.clone()
                logits[..., mask_token_id] = torch.finfo(logits.dtype).min
                logits_rev = logits_rev.clone()
                logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min
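
                # Supervise only masked, non-padding positions: labels are -100 elsewhere (ignored by the
                # loss) and per-token weights zero out every other position.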
                valid = attention_mask.to(dtype=torch.bool)
                masked = masked & valid
                masked_rev = masked_rev & valid

                labels = input_ids.clone()
                labels[~masked] = -100
                labels_rev = input_ids.clone()
                labels_rev[~masked_rev] = -100

                weights = masked.to(dtype=logits.dtype)
                weights_rev = masked_rev.to(dtype=logits.dtype)
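
                # compute_confidence_aware_loss returns the total loss plus its SFT (cross-entropy) and
                # confidence components (CAP-style loss from diffusers.training_utils); lambda_conf scales
                # the confidence term.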
                loss, loss_sft, loss_conf = compute_confidence_aware_loss(
                    logits,
                    labels,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights,
                )
                loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
                    logits_rev,
                    labels_rev,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights_rev,
                )
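
                # The two views are optimized jointly. Under accelerator.accumulate, gradient sync and the
                # optimizer step only take effect on accumulation boundaries.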
                total_loss = loss + loss_rev
                accelerator.backward(total_loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            if accelerator.sync_gradients:
                global_step += 1

                if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
                    logger.info(
                        "step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
                        global_step,
                        total_loss.item(),
                        (loss_sft + loss_sft_rev).item(),
                        (loss_conf + loss_conf_rev).item(),
                        lr_scheduler.get_last_lr()[0],
                    )
                    print(
                        f"step={global_step} loss={total_loss.item():.4f} "
                        f"sft={(loss_sft + loss_sft_rev).item():.4f} "
                        f"conf={(loss_conf + loss_conf_rev).item():.4f} "
                        f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
                    )

                if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
                        os.makedirs(save_dir, exist_ok=True)
                        accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
                        tokenizer.save_pretrained(save_dir)

            if global_step >= cfg.max_train_steps:
                break

        if global_step >= cfg.max_train_steps:
            break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final_dir = os.path.join(cfg.output_dir, "final")
        os.makedirs(final_dir, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
        tokenizer.save_pretrained(final_dir)

    logger.info("Done.")


if __name__ == "__main__":
    main()