mirror of
https://github.com/huggingface/diffusers.git
synced 2026-03-27 02:47:41 +08:00
Compare commits
2 Commits
bria-test-
...
fix-torcha
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4e01e02395 | ||
|
|
5e5b575fb3 |
2
.github/workflows/nightly_tests.yml
vendored
2
.github/workflows/nightly_tests.yml
vendored
@@ -341,7 +341,7 @@ jobs:
|
||||
additional_deps: ["peft", "kernels"]
|
||||
- backend: "torchao"
|
||||
test_location: "torchao"
|
||||
additional_deps: []
|
||||
additional_deps: [mslk-cuda]
|
||||
- backend: "optimum_quanto"
|
||||
test_location: "quanto"
|
||||
additional_deps: []
|
||||
|
||||
@@ -670,10 +670,6 @@
|
||||
- local: api/pipelines/z_image
|
||||
title: Z-Image
|
||||
title: Image
|
||||
- sections:
|
||||
- local: api/pipelines/llada2
|
||||
title: LLaDA2
|
||||
title: Text
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
@@ -722,8 +718,6 @@
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
title: Overview
|
||||
- local: api/schedulers/block_refinement
|
||||
title: BlockRefinementScheduler
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
title: CMStochasticIterativeScheduler
|
||||
- local: api/schedulers/ddim_cogvideox
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# LLaDA2
|
||||
|
||||
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) is a family of discrete diffusion language models
|
||||
that generate text through block-wise iterative refinement. Instead of autoregressive token-by-token generation,
|
||||
LLaDA2 starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement
|
||||
steps.
|
||||
|
||||
## Usage
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
model_id = "inclusionAI/LLaDA2.1-mini"
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
scheduler = BlockRefinementScheduler()
|
||||
|
||||
pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
output = pipe(
|
||||
prompt="Write a short poem about the ocean.",
|
||||
gen_length=256,
|
||||
block_length=32,
|
||||
num_inference_steps=32,
|
||||
threshold=0.7,
|
||||
editing_threshold=0.5,
|
||||
max_post_steps=16,
|
||||
temperature=0.0,
|
||||
)
|
||||
print(output.texts[0])
|
||||
```
|
||||
|
||||
## Callbacks
|
||||
|
||||
Callbacks run after each refinement step. Pass `callback_on_step_end_tensor_inputs` to select which tensors are
|
||||
included in `callback_kwargs`. In the current implementation, `block_x` (the sequence window being refined) and
|
||||
`transfer_index` (mask-filling commit mask) are provided; return `{"block_x": ...}` from the callback to replace the
|
||||
window.
|
||||
|
||||
```py
|
||||
def on_step_end(pipe, step, timestep, callback_kwargs):
|
||||
block_x = callback_kwargs["block_x"]
|
||||
# Inspect or modify `block_x` here.
|
||||
return {"block_x": block_x}
|
||||
|
||||
out = pipe(
|
||||
prompt="Write a short poem.",
|
||||
callback_on_step_end=on_step_end,
|
||||
callback_on_step_end_tensor_inputs=["block_x"],
|
||||
)
|
||||
```
|
||||
|
||||
## Recommended parameters
|
||||
|
||||
LLaDA2.1 models support two modes:
|
||||
|
||||
| Mode | `threshold` | `editing_threshold` | `max_post_steps` |
|
||||
|------|-------------|---------------------|------------------|
|
||||
| Quality | 0.7 | 0.5 | 16 |
|
||||
| Speed | 0.5 | `None` | 16 |
|
||||
|
||||
Pass `editing_threshold=None`, `0.0`, or a negative value to turn off post-mask editing.
|
||||
|
||||
For LLaDA2.0 models, disable editing by passing `editing_threshold=None` or `0.0`.
|
||||
|
||||
For all models: `block_length=32`, `temperature=0.0`, `num_inference_steps=32`.
|
||||
|
||||
## LLaDA2Pipeline
|
||||
[[autodoc]] LLaDA2Pipeline
|
||||
- all
|
||||
- __call__
|
||||
|
||||
## LLaDA2PipelineOutput
|
||||
[[autodoc]] pipelines.LLaDA2PipelineOutput
|
||||
@@ -63,7 +63,6 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
|
||||
| [Latent Diffusion](latent_diffusion) | text2image, super-resolution |
|
||||
| [Latte](latte) | text2image |
|
||||
| [LEDITS++](ledits_pp) | image editing |
|
||||
| [LLaDA2](llada2) | text2text |
|
||||
| [Lumina-T2X](lumina) | text2image |
|
||||
| [Marigold](marigold) | depth-estimation, normals-estimation, intrinsic-decomposition |
|
||||
| [MultiDiffusion](panorama) | text2image |
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# BlockRefinementScheduler
|
||||
|
||||
The `BlockRefinementScheduler` manages block-wise iterative refinement for discrete token diffusion. At each step it
|
||||
commits the most confident tokens and optionally edits already-committed tokens when the model predicts a different
|
||||
token with high confidence.
|
||||
|
||||
This scheduler is used by [`LLaDA2Pipeline`].
|
||||
|
||||
## BlockRefinementScheduler
|
||||
[[autodoc]] BlockRefinementScheduler
|
||||
|
||||
## BlockRefinementSchedulerOutput
|
||||
[[autodoc]] schedulers.scheduling_block_refinement.BlockRefinementSchedulerOutput
|
||||
@@ -1,50 +0,0 @@
|
||||
# Discrete Token Diffusion (Experimental)
|
||||
|
||||
This folder contains **training and sampling examples** for *discrete diffusion over token IDs* (language-model style), built to follow the `diffusers` + `accelerate` training conventions.
|
||||
|
||||
## LLaDA2
|
||||
|
||||
[LLaDA2](https://huggingface.co/collections/inclusionAI/llada21) generates text through block-wise iterative refinement. Instead of autoregressive token-by-token generation, it starts with a fully masked sequence and progressively unmasks tokens by confidence over multiple refinement steps.
|
||||
|
||||
### Train
|
||||
|
||||
The training script uses confidence-aware loss and works with any causal LM from the Hub (e.g. Qwen, Llama, Mistral):
|
||||
|
||||
```bash
|
||||
accelerate launch examples/discrete_diffusion/train_llada2.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--dataset_name wikitext \
|
||||
--dataset_config_name wikitext-2-raw-v1 \
|
||||
--text_column text \
|
||||
--output_dir llada2-output \
|
||||
--max_train_steps 1000 \
|
||||
--prompt_length 32 \
|
||||
--block_length 32 \
|
||||
--lambda_conf 2.0 \
|
||||
--conf_temperature 0.5
|
||||
```
|
||||
|
||||
If you don't want to download a dataset, you can use random-token data:
|
||||
|
||||
```bash
|
||||
accelerate launch examples/discrete_diffusion/train_llada2.py \
|
||||
--model_name_or_path Qwen/Qwen2.5-0.5B \
|
||||
--output_dir llada2-output \
|
||||
--use_dummy_data \
|
||||
--num_dummy_samples 2048
|
||||
```
|
||||
|
||||
### Sample
|
||||
|
||||
```bash
|
||||
python examples/discrete_diffusion/sample_llada2.py \
|
||||
--model_id inclusionAI/LLaDA2.1-mini \
|
||||
--prompt "Write a short poem about the ocean." \
|
||||
--gen_length 256 \
|
||||
--num_inference_steps 32 \
|
||||
--threshold 0.7 \
|
||||
--editing_threshold 0.5 \
|
||||
--max_post_steps 16 \
|
||||
--use_chat_template \
|
||||
--add_generation_prompt
|
||||
```
|
||||
@@ -1,263 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Sample script for LLaDA2-style discrete diffusion text generation.
|
||||
|
||||
This script demonstrates how to use the LLaDA2Pipeline for text generation
|
||||
using block-wise iterative refinement.
|
||||
|
||||
Example usage:
|
||||
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-mini --prompt "What is the capital of France?"
|
||||
python sample_llada2.py --model_id inclusionAI/LLaDA2.0-flash-CAP --prompt "Explain quantum computing." --temperature 0.7
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
from diffusers.hooks import apply_group_offloading
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: load a LLaDA2-style model and generate text with
    block-wise discrete diffusion via ``LLaDA2Pipeline``.

    All settings come from command-line flags. Optionally applies one of two
    memory-offloading strategies before generation, then prints the decoded
    text and the number of generated tokens.
    """
    parser = argparse.ArgumentParser(
        description="Generate text using LLaDA2Pipeline with block-wise discrete diffusion."
    )
    parser.add_argument(
        "--model_id",
        type=str,
        default="inclusionAI/LLaDA2.0-mini",
        help="HuggingFace model ID or path to local model.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="Why does Camus think that Sisyphus is happy?",
        help="Text prompt to generate from.",
    )
    parser.add_argument(
        "--gen_length",
        type=int,
        default=2048,
        help="Number of tokens to generate.",
    )
    parser.add_argument(
        "--block_length",
        type=int,
        default=32,
        help="Size of each generation block.",
    )
    parser.add_argument(
        "--num_inference_steps",
        type=int,
        default=32,
        help="Number of refinement steps per block.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Sampling temperature (0.0 for greedy).",
    )
    parser.add_argument(
        "--top_p",
        type=float,
        default=None,
        help="Nucleus sampling probability threshold.",
    )
    parser.add_argument(
        "--top_k",
        type=int,
        default=None,
        help="Top-k sampling parameter.",
    )
    parser.add_argument(
        "--threshold",
        type=float,
        default=0.95,
        help="Confidence threshold for committing tokens.",
    )
    parser.add_argument(
        "--editing_threshold",
        type=float,
        default=None,
        help="Confidence threshold for editing already-committed tokens. Set to enable post-mask editing (e.g. 0.5).",
    )
    parser.add_argument(
        "--max_post_steps",
        type=int,
        default=0,
        help="Maximum post-mask editing iterations per block (e.g. 16). Only used when --editing_threshold is set.",
    )
    parser.add_argument(
        "--sampling_method",
        type=str,
        default="multinomial",
        choices=["auto", "greedy", "multinomial"],
        help="Sampling method for block refinement.",
    )
    parser.add_argument(
        "--eos_early_stop",
        action="store_true",
        help="Stop generation early when EOS token is generated.",
    )
    parser.add_argument(
        "--use_chat_template",
        action="store_true",
        help="Use the tokenizer chat template for the prompt.",
    )
    parser.add_argument(
        "--add_generation_prompt",
        action="store_true",
        help="Add the generation prompt when using the chat template.",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run inference on.",
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
        help="Model dtype.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility.",
    )
    parser.add_argument(
        "--offload",
        type=str,
        default=None,
        choices=["group", "sequential"],
        help="Memory offloading strategy: 'group' for group offloading (faster), 'sequential' for sequential CPU offload (slower but lower memory).",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        help="Model revision (branch, tag, or commit hash) to load from the Hub.",
    )

    args = parser.parse_args()

    # Parse dtype
    dtype_map = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    torch_dtype = dtype_map[args.dtype]

    print(f"Loading model: {args.model_id}")
    # trust_remote_code is required because LLaDA2 checkpoints ship custom model code.
    tokenizer = AutoTokenizer.from_pretrained(args.model_id, trust_remote_code=True, revision=args.revision)

    # Load model with appropriate memory settings based on offload strategy
    if args.offload == "group":
        # For group offloading, load to CPU first then apply hooks
        print("Using group offloading for memory efficiency...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
        # Apply group offloading with CUDA streams for better performance
        onload_device = torch.device(args.device)
        offload_device = torch.device("cpu")
        apply_group_offloading(
            model,
            onload_device=onload_device,
            offload_device=offload_device,
            offload_type="leaf_level",
            use_stream=True,
        )
    elif args.offload == "sequential":
        # For sequential offloading, load to CPU first
        print("Using sequential CPU offloading (slower but lower memory)...")
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
        # Sequential offloading will be applied via pipeline
    else:
        # Default: use device_map="auto" for automatic memory management
        model = AutoModelForCausalLM.from_pretrained(
            args.model_id,
            trust_remote_code=True,
            dtype=torch_dtype,
            device_map="auto",
            low_cpu_mem_usage=True,
            revision=args.revision,
        )
    model.eval()

    # Create pipeline
    scheduler = BlockRefinementScheduler()
    pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)

    # Apply sequential CPU offload if requested
    if args.offload == "sequential":
        pipe.enable_sequential_cpu_offload()

    # Set up generator for reproducibility
    # NOTE(review): the generator is created on args.device; presumably the
    # pipeline samples on that device — confirm if combining --seed with --offload.
    generator = None
    if args.seed is not None:
        generator = torch.Generator(device=args.device).manual_seed(args.seed)

    print(f"\nPrompt: {args.prompt}")
    print(
        f"Generating {args.gen_length} tokens with block_length={args.block_length}, steps={args.num_inference_steps}"
    )
    print("-" * 50)

    # Generate
    output = pipe(
        prompt=args.prompt,
        use_chat_template=args.use_chat_template,
        add_generation_prompt=args.add_generation_prompt,
        gen_length=args.gen_length,
        block_length=args.block_length,
        num_inference_steps=args.num_inference_steps,
        temperature=args.temperature,
        top_p=args.top_p,
        top_k=args.top_k,
        threshold=args.threshold,
        editing_threshold=args.editing_threshold,
        max_post_steps=args.max_post_steps,
        sampling_method=args.sampling_method,
        eos_early_stop=args.eos_early_stop,
        generator=generator,
    )

    print("\nGenerated text:")
    print(output.texts[0])

    # output.sequences is indexed [batch, tokens] here; dim 1 is the token count.
    print(f"\nGenerated {output.sequences.shape[1]} tokens")


if __name__ == "__main__":
    main()
|
||||
@@ -1,321 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import math
|
||||
import os
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, get_scheduler
|
||||
|
||||
from diffusers import BlockRefinementScheduler
|
||||
from diffusers.training_utils import compute_confidence_aware_loss
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class TrainConfig:
    """Typed view of every command-line option accepted by the trainer."""

    # Data source
    model_name_or_path: str
    dataset_name: str
    dataset_config_name: Optional[str]
    text_column: str
    cache_dir: Optional[str]
    use_dummy_data: bool
    num_dummy_samples: int

    # Run bookkeeping
    output_dir: str
    seed: int
    max_train_steps: int
    checkpointing_steps: int
    logging_steps: int

    # Optimization
    per_device_train_batch_size: int
    gradient_accumulation_steps: int
    learning_rate: float
    weight_decay: float
    lr_scheduler: str
    lr_warmup_steps: int

    # Sequence geometry
    max_length: int
    prompt_length: int
    block_length: int

    # Confidence-aware loss
    lambda_conf: float
    conf_temperature: float


def parse_args() -> TrainConfig:
    """Parse the CLI flags and bundle them into a :class:`TrainConfig`."""
    p = argparse.ArgumentParser(description="Train block-refinement with a confidence-aware loss on a causal LM.")
    add = p.add_argument

    # Data source
    add("--model_name_or_path", type=str, default="Qwen/Qwen2.5-0.5B")
    add("--dataset_name", type=str, default="wikitext")
    add("--dataset_config_name", type=str, default="wikitext-2-raw-v1")
    add("--text_column", type=str, default="text")
    add("--cache_dir", type=str, default=None)
    add("--use_dummy_data", action="store_true", help="Use random-token data instead of downloading.")
    add("--num_dummy_samples", type=int, default=2048)

    # Run bookkeeping
    add("--output_dir", type=str, default="block-refinement-output")
    add("--seed", type=int, default=0)
    add("--max_train_steps", type=int, default=1000)
    add("--checkpointing_steps", type=int, default=500)
    add("--logging_steps", type=int, default=50)

    # Optimization
    add("--per_device_train_batch_size", type=int, default=1)
    add("--gradient_accumulation_steps", type=int, default=8)
    add("--learning_rate", type=float, default=2e-5)
    add("--weight_decay", type=float, default=0.0)
    add("--lr_scheduler", type=str, default="cosine", choices=["linear", "cosine", "cosine_with_restarts"])
    add("--lr_warmup_steps", type=int, default=100)

    # Sequence geometry
    add("--max_length", type=int, default=256)
    add("--prompt_length", type=int, default=32)
    add("--block_length", type=int, default=32)

    # Confidence-aware loss
    add("--lambda_conf", type=float, default=2.0)
    add("--conf_temperature", type=float, default=0.5)

    namespace = p.parse_args()
    return TrainConfig(**vars(namespace))
|
||||
|
||||
|
||||
def tokenize_fn(examples: Dict, tokenizer, text_column: str, max_length: int):
    """Tokenize the non-empty string entries of one batched dataset column.

    Entries that are not strings, or that are blank after stripping, are
    dropped before tokenization. No padding is applied; sequences longer than
    ``max_length`` are truncated.
    """
    kept = []
    for entry in examples[text_column]:
        if isinstance(entry, str) and entry.strip():
            kept.append(entry)
    return tokenizer(kept, truncation=True, padding=False, max_length=max_length)
|
||||
|
||||
|
||||
class RandomTokenDataset(torch.utils.data.Dataset):
    """Synthetic dataset of uniformly random token sequences.

    Useful for smoke-testing the training loop without downloading data.
    Every item is an independent random draw; the requested index is ignored.
    """

    def __init__(self, *, num_samples: int, seq_len: int, vocab_size: int, pad_token_id: int):
        self.num_samples = int(num_samples)
        self.seq_len = int(seq_len)
        self.vocab_size = int(vocab_size)
        self.pad_token_id = int(pad_token_id)

    def __len__(self) -> int:
        return self.num_samples

    def __getitem__(self, idx):
        # Samples are i.i.d. — drop the index on purpose.
        del idx
        tokens = torch.randint(0, self.vocab_size, (self.seq_len,), dtype=torch.long)
        return {"input_ids": tokens, "attention_mask": torch.ones_like(tokens)}
|
||||
|
||||
|
||||
def main() -> None:
    """Train a causal LM for block-refinement discrete diffusion.

    Pipeline: parse config -> set up Accelerate -> load tokenizer/model ->
    build the dataloader (real dataset or random tokens) -> run the training
    loop with a confidence-aware loss over two masked views of each batch ->
    checkpoint periodically and save a final model.
    """
    cfg = parse_args()
    # The prompt prefix is never masked, so it must leave room for generation.
    if cfg.prompt_length >= cfg.max_length:
        raise ValueError("`prompt_length` must be < `max_length`.")
    if cfg.block_length <= 0:
        raise ValueError("`block_length` must be > 0.")

    project_config = ProjectConfiguration(project_dir=cfg.output_dir, logging_dir=os.path.join(cfg.output_dir, "logs"))
    accelerator = Accelerator(
        gradient_accumulation_steps=cfg.gradient_accumulation_steps,
        project_config=project_config,
    )
    # Only rank 0 creates the directory; everyone else waits for it to exist.
    if accelerator.is_main_process:
        os.makedirs(cfg.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    set_seed(cfg.seed)
    logger.info("Training configuration: %s", asdict(cfg))

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True, cache_dir=cfg.cache_dir)
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Discrete diffusion needs a mask token; causal-LM tokenizers usually lack one.
    if tokenizer.mask_token_id is None:
        tokenizer.add_special_tokens({"mask_token": "[MASK]"})

    load_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path, cache_dir=cfg.cache_dir, dtype=load_dtype)
    # Embeddings must grow if [MASK] was just added to the vocabulary.
    model.resize_token_embeddings(len(tokenizer))
    if load_dtype == torch.float32:
        model.to(dtype=torch.float32)

    mask_token_id = int(tokenizer.mask_token_id)

    if cfg.use_dummy_data:
        # Random-token smoke-test path: no download, fixed-length samples.
        dataset = RandomTokenDataset(
            num_samples=cfg.num_dummy_samples,
            seq_len=cfg.max_length,
            vocab_size=len(tokenizer),
            pad_token_id=int(tokenizer.pad_token_id),
        )
        train_dataloader = DataLoader(
            dataset,
            shuffle=True,
            batch_size=cfg.per_device_train_batch_size,
            drop_last=True,
        )
    else:
        raw_datasets = load_dataset(cfg.dataset_name, cfg.dataset_config_name, cache_dir=cfg.cache_dir)
        if "train" not in raw_datasets:
            raise ValueError(f"Dataset {cfg.dataset_name} has no 'train' split.")

        # Tokenize once on rank 0 so the cache is shared by the other ranks.
        with accelerator.main_process_first():
            tokenized = raw_datasets["train"].map(
                lambda ex: tokenize_fn(ex, tokenizer, cfg.text_column, cfg.max_length),
                batched=True,
                remove_columns=raw_datasets["train"].column_names,
                desc="Tokenizing",
            )

        # mlm=False: the collator only pads/batches; masking is done by the scheduler below.
        collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False, return_tensors="pt")
        train_dataloader = DataLoader(
            tokenized, shuffle=True, collate_fn=collator, batch_size=cfg.per_device_train_batch_size, drop_last=True
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay)

    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / cfg.gradient_accumulation_steps)
    num_train_epochs = math.ceil(cfg.max_train_steps / num_update_steps_per_epoch)

    lr_scheduler = get_scheduler(
        name=cfg.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.lr_warmup_steps,
        num_training_steps=cfg.max_train_steps,
    )

    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    noise_scheduler = BlockRefinementScheduler(block_length=cfg.block_length)

    global_step = 0
    model.train()

    for _epoch in range(num_train_epochs):
        for batch in train_dataloader:
            with accelerator.accumulate(model):
                input_ids = batch["input_ids"]
                attention_mask = batch.get("attention_mask", torch.ones_like(input_ids))

                # Deterministic per-step noise: seed depends on the optimizer step.
                gen = torch.Generator(device=input_ids.device).manual_seed(cfg.seed + global_step)
                # NOTE(review): add_noise appears to return two masked views
                # (forward and reversed block order) plus their mask-position
                # booleans — confirm against BlockRefinementScheduler.add_noise.
                noisy, noisy_rev, masked, masked_rev = noise_scheduler.add_noise(
                    input_ids,
                    attention_mask,
                    prompt_length=cfg.prompt_length,
                    block_length=cfg.block_length,
                    mask_token_id=mask_token_id,
                    generator=gen,
                )

                position_ids = (
                    torch.arange(input_ids.shape[1], device=input_ids.device).unsqueeze(0).expand_as(input_ids)
                )

                logits = model(input_ids=noisy, attention_mask=attention_mask, position_ids=position_ids).logits
                logits_rev = model(
                    input_ids=noisy_rev, attention_mask=attention_mask, position_ids=position_ids
                ).logits

                # Forbid predicting the mask token itself: push its logit to -inf.
                logits = logits.clone()
                logits[..., mask_token_id] = torch.finfo(logits.dtype).min
                logits_rev = logits_rev.clone()
                logits_rev[..., mask_token_id] = torch.finfo(logits_rev.dtype).min

                # Restrict the loss to masked positions that are real (non-padding) tokens.
                valid = attention_mask.to(dtype=torch.bool)
                masked = masked & valid
                masked_rev = masked_rev & valid

                # -100 is the ignore_index convention: loss only on masked positions.
                labels = input_ids.clone()
                labels[~masked] = -100
                labels_rev = input_ids.clone()
                labels_rev[~masked_rev] = -100

                weights = masked.to(dtype=logits.dtype)
                weights_rev = masked_rev.to(dtype=logits.dtype)

                loss, loss_sft, loss_conf = compute_confidence_aware_loss(
                    logits,
                    labels,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights,
                )
                loss_rev, loss_sft_rev, loss_conf_rev = compute_confidence_aware_loss(
                    logits_rev,
                    labels_rev,
                    lambda_conf=cfg.lambda_conf,
                    temperature=cfg.conf_temperature,
                    per_token_weights=weights_rev,
                )

                # Both masked views contribute equally to the gradient.
                total_loss = loss + loss_rev
                accelerator.backward(total_loss)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Count an optimizer step only when gradients actually synced
            # (i.e. after a full accumulation cycle).
            if accelerator.sync_gradients:
                global_step += 1

                if global_step % cfg.logging_steps == 0 and accelerator.is_main_process:
                    logger.info(
                        "step=%d loss=%.4f sft=%.4f conf=%.4f lr=%.6g",
                        global_step,
                        total_loss.item(),
                        (loss_sft + loss_sft_rev).item(),
                        (loss_conf + loss_conf_rev).item(),
                        lr_scheduler.get_last_lr()[0],
                    )
                    print(
                        f"step={global_step} loss={total_loss.item():.4f} "
                        f"sft={(loss_sft + loss_sft_rev).item():.4f} "
                        f"conf={(loss_conf + loss_conf_rev).item():.4f} "
                        f"lr={lr_scheduler.get_last_lr()[0]:.6g}"
                    )

                if cfg.checkpointing_steps > 0 and global_step % cfg.checkpointing_steps == 0:
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        save_dir = os.path.join(cfg.output_dir, f"checkpoint-{global_step}")
                        os.makedirs(save_dir, exist_ok=True)
                        accelerator.unwrap_model(model).save_pretrained(save_dir, save_function=accelerator.save)
                        tokenizer.save_pretrained(save_dir)

                if global_step >= cfg.max_train_steps:
                    break

        if global_step >= cfg.max_train_steps:
            break

    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        final_dir = os.path.join(cfg.output_dir, "final")
        os.makedirs(final_dir, exist_ok=True)
        accelerator.unwrap_model(model).save_pretrained(final_dir, save_function=accelerator.save)
        tokenizer.save_pretrained(final_dir)

    logger.info("Done.")


if __name__ == "__main__":
    main()
|
||||
@@ -344,8 +344,6 @@ else:
|
||||
_import_structure["schedulers"].extend(
|
||||
[
|
||||
"AmusedScheduler",
|
||||
"BlockRefinementScheduler",
|
||||
"BlockRefinementSchedulerOutput",
|
||||
"CMStochasticIterativeScheduler",
|
||||
"CogVideoXDDIMScheduler",
|
||||
"CogVideoXDPMScheduler",
|
||||
@@ -582,8 +580,6 @@ else:
|
||||
"LDMTextToImagePipeline",
|
||||
"LEditsPPPipelineStableDiffusion",
|
||||
"LEditsPPPipelineStableDiffusionXL",
|
||||
"LLaDA2Pipeline",
|
||||
"LLaDA2PipelineOutput",
|
||||
"LongCatImageEditPipeline",
|
||||
"LongCatImagePipeline",
|
||||
"LTX2ConditionPipeline",
|
||||
@@ -1128,8 +1124,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .quantizers import DiffusersQuantizer
|
||||
from .schedulers import (
|
||||
AmusedScheduler,
|
||||
BlockRefinementScheduler,
|
||||
BlockRefinementSchedulerOutput,
|
||||
CMStochasticIterativeScheduler,
|
||||
CogVideoXDDIMScheduler,
|
||||
CogVideoXDPMScheduler,
|
||||
@@ -1345,8 +1339,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LDMTextToImagePipeline,
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
LLaDA2Pipeline,
|
||||
LLaDA2PipelineOutput,
|
||||
LongCatImageEditPipeline,
|
||||
LongCatImagePipeline,
|
||||
LTX2ConditionPipeline,
|
||||
|
||||
@@ -285,7 +285,6 @@ else:
|
||||
]
|
||||
)
|
||||
_import_structure["latte"] = ["LattePipeline"]
|
||||
_import_structure["llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
_import_structure["ltx"] = [
|
||||
"LTXPipeline",
|
||||
"LTXImageToVideoPipeline",
|
||||
@@ -729,7 +728,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
LEditsPPPipelineStableDiffusion,
|
||||
LEditsPPPipelineStableDiffusionXL,
|
||||
)
|
||||
from .llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
|
||||
from .longcat_image import LongCatImageEditPipeline, LongCatImagePipeline
|
||||
from .ltx import (
|
||||
LTXConditionPipeline,
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
# Lazy-import shim for the LLaDA2 pipeline subpackage, following the standard
# diffusers pattern: real imports only happen at type-checking time / slow-import
# mode; otherwise the module is replaced by a _LazyModule that resolves names
# on first attribute access.
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
)


# Names exposed when optional deps are missing (dummy placeholders) and the
# lazy-import table consumed by _LazyModule.
_dummy_objects = {}
_import_structure = {}


try:
    # Both torch and transformers are required for the real pipeline.
    if not (is_transformers_available() and is_torch_available()):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    # Fall back to dummy objects that raise a helpful error on use.
    from ...utils import dummy_torch_and_transformers_objects  # noqa F403

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_llada2"] = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: import the real (or dummy) symbols immediately.
    try:
        if not (is_transformers_available() and is_torch_available()):
            raise OptionalDependencyNotAvailable()

    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_llada2 import LLaDA2Pipeline, LLaDA2PipelineOutput
else:
    # Lazy path: swap this module object for a _LazyModule proxy.
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )

    # Attach dummy placeholders (if any) so attribute access still works.
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
||||
@@ -1,491 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
|
||||
from ...schedulers import BlockRefinementScheduler
|
||||
from ...utils import BaseOutput, logging, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```python
|
||||
>>> import torch
|
||||
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||
>>> from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
>>> model_id = "inclusionAI/LLaDA2.1-mini"
|
||||
>>> model = AutoModelForCausalLM.from_pretrained(
|
||||
... model_id, trust_remote_code=True, dtype=torch.bfloat16, device_map="auto"
|
||||
... )
|
||||
>>> tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
|
||||
>>> scheduler = BlockRefinementScheduler()
|
||||
|
||||
>>> pipe = LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
>>> output = pipe(prompt="What is the meaning of life?", gen_length=256)
|
||||
>>> print(output.texts[0])
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLaDA2PipelineOutput(BaseOutput):
|
||||
sequences: torch.LongTensor
|
||||
texts: list[str] | None = None
|
||||
|
||||
|
||||
class LLaDA2Pipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for LLaDA2-style discrete diffusion text generation via block-wise iterative refinement.
|
||||
|
||||
This pipeline maintains a template sequence filled with a `mask_token_id` and refines it in blocks. In each
|
||||
refinement step, it samples candidate tokens for the active block and commits a subset based on confidence.
|
||||
|
||||
The model is expected to accept an attention mask and `position_ids`, and to return logits of shape `[batch, seq,
|
||||
vocab_size]`.
|
||||
"""
|
||||
|
||||
model: Any
|
||||
scheduler: BlockRefinementScheduler
|
||||
tokenizer: Any
|
||||
|
||||
_callback_tensor_inputs = ["block_x", "x0", "x0_p", "transfer_index", "confidence", "active_block"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Any,
|
||||
scheduler: BlockRefinementScheduler,
|
||||
tokenizer: Any | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.register_modules(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
self.eos_token_id = getattr(self.tokenizer, "eos_token_id", None) if self.tokenizer is not None else None
|
||||
self.mask_token_id = getattr(self.tokenizer, "mask_token_id", None) if self.tokenizer is not None else None
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
|
||||
# --- Prompt encoding ---
|
||||
|
||||
def _prepare_input_ids(
|
||||
self,
|
||||
*,
|
||||
prompt: str | list[str] | None,
|
||||
messages: list[dict[str, str]] | None,
|
||||
input_ids: torch.LongTensor | None,
|
||||
use_chat_template: bool,
|
||||
add_generation_prompt: bool,
|
||||
chat_template_kwargs: dict[str, Any] | None,
|
||||
) -> torch.LongTensor:
|
||||
"""Convert prompt/messages/input_ids to a [batch, seq] LongTensor."""
|
||||
if input_ids is not None:
|
||||
if input_ids.ndim == 1:
|
||||
input_ids = input_ids.unsqueeze(0)
|
||||
if input_ids.ndim != 2:
|
||||
raise ValueError(f"`input_ids` must be 2D, got shape {tuple(input_ids.shape)}.")
|
||||
if input_ids.dtype != torch.long:
|
||||
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
|
||||
return input_ids
|
||||
|
||||
if self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
|
||||
if messages is not None and prompt is not None:
|
||||
raise ValueError("Provide either `prompt` or `messages`, not both.")
|
||||
if messages is None and prompt is None:
|
||||
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
|
||||
|
||||
chat_template_kwargs = chat_template_kwargs or {}
|
||||
|
||||
if messages is not None:
|
||||
encoded = self.tokenizer.apply_chat_template(
|
||||
messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
return encoded["input_ids"]
|
||||
|
||||
if use_chat_template and getattr(self.tokenizer, "chat_template", None):
|
||||
if isinstance(prompt, list):
|
||||
raise ValueError("`prompt` must be a string when `use_chat_template=True`.")
|
||||
encoded = self.tokenizer.apply_chat_template(
|
||||
[{"role": "user", "content": prompt}],
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
tokenize=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
**chat_template_kwargs,
|
||||
)
|
||||
return encoded["input_ids"]
|
||||
|
||||
encoded = self.tokenizer(prompt, return_tensors="pt", padding=isinstance(prompt, list))
|
||||
return encoded["input_ids"]
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: str | list[str] | None,
|
||||
messages: list[dict[str, str]] | None,
|
||||
input_ids: torch.LongTensor | None,
|
||||
gen_length: int,
|
||||
block_length: int,
|
||||
num_inference_steps: int,
|
||||
minimal_topk: int,
|
||||
threshold: float,
|
||||
sampling_method: str,
|
||||
output_type: str,
|
||||
callback_on_step_end: Callable | PipelineCallback | MultiPipelineCallbacks | None,
|
||||
callback_on_step_end_tensor_inputs: list[str] | None,
|
||||
):
|
||||
# Input source validation
|
||||
if prompt is None and messages is None and input_ids is None:
|
||||
raise ValueError("Provide one of `prompt`, `messages`, or `input_ids`.")
|
||||
if prompt is not None and messages is not None:
|
||||
raise ValueError("Provide either `prompt` or `messages`, not both.")
|
||||
if input_ids is not None:
|
||||
if input_ids.ndim not in (1, 2):
|
||||
raise ValueError(f"`input_ids` must be 1D or 2D, got shape {tuple(input_ids.shape)}.")
|
||||
if input_ids.dtype != torch.long:
|
||||
raise ValueError(f"`input_ids` must be int64 token IDs, got dtype={input_ids.dtype}.")
|
||||
if prompt is not None and input_ids is None and self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
if messages is not None and input_ids is None and self.tokenizer is None:
|
||||
raise ValueError("Tokenizer is required when `input_ids` is not provided.")
|
||||
|
||||
# Generation parameter validation
|
||||
if gen_length <= 0:
|
||||
raise ValueError(f"`gen_length` must be > 0, got {gen_length}.")
|
||||
if block_length <= 0:
|
||||
raise ValueError(f"`block_length` must be > 0, got {block_length}.")
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
|
||||
if minimal_topk <= 0:
|
||||
raise ValueError(f"`minimal_topk` must be > 0, got {minimal_topk}.")
|
||||
if not (0.0 <= threshold <= 1.0) and not (threshold > 1.0):
|
||||
raise ValueError(f"`threshold` must be in [0, 1] (or > 1 to force top-k commits), got {threshold}.")
|
||||
if sampling_method not in {"auto", "greedy", "multinomial"}:
|
||||
raise ValueError(
|
||||
f"`sampling_method` must be one of {{'auto','greedy','multinomial'}}, got {sampling_method!r}."
|
||||
)
|
||||
if output_type not in {"seq", "text"}:
|
||||
raise ValueError(f"`output_type` must be 'seq' or 'text', got {output_type!r}.")
|
||||
|
||||
# Callback validation
|
||||
if callback_on_step_end is not None and isinstance(
|
||||
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
|
||||
):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found "
|
||||
f"{[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: str | list[str] | None = None,
|
||||
messages: list[dict[str, str]] | None = None,
|
||||
input_ids: torch.LongTensor | None = None,
|
||||
use_chat_template: bool = True,
|
||||
add_generation_prompt: bool = True,
|
||||
gen_length: int = 2048,
|
||||
block_length: int = 32,
|
||||
num_inference_steps: int = 32,
|
||||
temperature: float = 0.0,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
sampling_method: str = "multinomial",
|
||||
threshold: float = 0.7,
|
||||
editing_threshold: float | None = 0.5,
|
||||
max_post_steps: int = 16,
|
||||
minimal_topk: int = 1,
|
||||
eos_early_stop: bool = True,
|
||||
eos_token_id: int | None = None,
|
||||
mask_token_id: int | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
output_type: str = "text",
|
||||
return_dict: bool = True,
|
||||
callback_on_step_end: Callable[[int, int, dict], None]
|
||||
| PipelineCallback
|
||||
| MultiPipelineCallbacks
|
||||
| None = None,
|
||||
callback_on_step_end_tensor_inputs: list[str] | None = None,
|
||||
) -> LLaDA2PipelineOutput | tuple[torch.LongTensor, list[str] | None]:
|
||||
"""
|
||||
Generate text with block-wise refinement.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
Prompt text. When `use_chat_template` is `True` (default) and a tokenizer with a chat template is
|
||||
available, the prompt is wrapped in a chat message before tokenization.
|
||||
messages (`List[Dict[str, str]]`, *optional*):
|
||||
Chat messages to encode (e.g. `[{"role": "user", "content": "Hello"}]`). Takes precedence over `prompt`
|
||||
when provided. Requires a tokenizer with `apply_chat_template`.
|
||||
input_ids (`torch.LongTensor`, *optional*):
|
||||
Pre-tokenized input IDs. Takes precedence over `prompt` and `messages`.
|
||||
use_chat_template (`bool`, defaults to `True`):
|
||||
Whether to wrap the prompt in a chat template.
|
||||
add_generation_prompt (`bool`, defaults to `True`):
|
||||
Whether to add the generation prompt when using chat templates.
|
||||
gen_length (`int`):
|
||||
Number of tokens to generate.
|
||||
block_length (`int`):
|
||||
Block size for refinement.
|
||||
num_inference_steps (`int`):
|
||||
Number of refinement steps per block.
|
||||
temperature (`float`):
|
||||
Sampling temperature.
|
||||
top_p (`float`, *optional*):
|
||||
Nucleus sampling cutoff.
|
||||
top_k (`int`, *optional*):
|
||||
Top-k sampling cutoff.
|
||||
sampling_method (`str`):
|
||||
Sampling method (`auto`, `greedy`, `multinomial`).
|
||||
threshold (`float`):
|
||||
Confidence threshold for committing tokens.
|
||||
editing_threshold (`float`, *optional*):
|
||||
Confidence threshold for editing already-committed (non-mask) tokens. When positive, after all mask
|
||||
tokens in a block are resolved, the pipeline continues refining: if the model predicts a different
|
||||
token with confidence above this threshold, the existing token is replaced. Set to `None`, `0.0`, or a
|
||||
negative value to disable editing. Defaults to `0.5`.
|
||||
max_post_steps (`int`):
|
||||
Maximum number of additional refinement iterations after all mask tokens in a block are resolved. Only
|
||||
used when `editing_threshold` is enabled. Defaults to `16`.
|
||||
minimal_topk (`int`):
|
||||
Minimum number of tokens to commit per step.
|
||||
eos_early_stop (`bool`):
|
||||
Whether to stop after committing EOS in a block.
|
||||
eos_token_id (`int`, *optional*):
|
||||
EOS token ID to use for early stopping.
|
||||
mask_token_id (`int`, *optional*):
|
||||
Mask token ID to use for the template.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
RNG for sampling.
|
||||
output_type (`str`, defaults to `"text"`):
|
||||
Output format. `"text"` decodes sequences into strings (requires a tokenizer). `"seq"` returns raw
|
||||
token ID sequences only.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`LLaDA2PipelineOutput`] instead of a tuple.
|
||||
callback_on_step_end (`Callable` or `PipelineCallback`, *optional*):
|
||||
Callback executed after each refinement step with signature `callback_on_step_end(self, step: int,
|
||||
timestep: int, callback_kwargs: Dict)`.
|
||||
callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
|
||||
Tensor keys to pass to the callback. Allowed keys: `block_x`, `x0`, `x0_p`, `transfer_index`,
|
||||
`confidence`, `active_block`.
|
||||
|
||||
Examples:
|
||||
"""
|
||||
# 1. Check inputs early
|
||||
if callback_on_step_end is not None and isinstance(
|
||||
callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)
|
||||
):
|
||||
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
|
||||
if callback_on_step_end_tensor_inputs is None:
|
||||
callback_on_step_end_tensor_inputs = ["block_x"]
|
||||
|
||||
self.check_inputs(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
input_ids=input_ids,
|
||||
gen_length=gen_length,
|
||||
block_length=block_length,
|
||||
num_inference_steps=num_inference_steps,
|
||||
minimal_topk=minimal_topk,
|
||||
threshold=threshold,
|
||||
sampling_method=sampling_method,
|
||||
output_type=output_type,
|
||||
callback_on_step_end=callback_on_step_end,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
)
|
||||
|
||||
# 2. Prepare input IDs from prompt/messages/input_ids
|
||||
prompt_ids = self._prepare_input_ids(
|
||||
prompt=prompt,
|
||||
messages=messages,
|
||||
input_ids=input_ids,
|
||||
use_chat_template=use_chat_template,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
device = self._execution_device
|
||||
|
||||
if prompt_ids.ndim == 1:
|
||||
prompt_ids = prompt_ids.unsqueeze(0)
|
||||
prompt_ids = prompt_ids.to(device=device)
|
||||
batch_size, prompt_length = prompt_ids.shape
|
||||
|
||||
if eos_token_id is None:
|
||||
eos_token_id = self.eos_token_id
|
||||
if mask_token_id is None:
|
||||
mask_token_id = self.mask_token_id
|
||||
if mask_token_id is None:
|
||||
raise ValueError("`mask_token_id` must be provided (or available on the tokenizer).")
|
||||
|
||||
num_inference_steps = min(num_inference_steps, gen_length // minimal_topk)
|
||||
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
|
||||
# 3. Build attention mask and position IDs
|
||||
num_blocks = (prompt_length + gen_length + block_length - 1) // block_length
|
||||
total_length = num_blocks * block_length
|
||||
|
||||
# 2D attention mask (no padding) — the model handles backend-specific conversion internally.
|
||||
attn_mask = torch.ones((batch_size, total_length), device=device, dtype=torch.long)
|
||||
|
||||
position_ids = torch.arange(total_length, device=device, dtype=torch.long).unsqueeze(0).expand(batch_size, -1)
|
||||
|
||||
# 4. Prepare latents (fully masked sequence)
|
||||
x = torch.full((batch_size, total_length), mask_token_id, device=device, dtype=torch.long)
|
||||
if prompt_length > 0:
|
||||
x[:, :prompt_length] = prompt_ids
|
||||
|
||||
prefill_blocks = prompt_length // block_length
|
||||
self._num_timesteps = num_inference_steps * max(num_blocks - prefill_blocks, 0)
|
||||
|
||||
finished = torch.zeros((batch_size,), device=device, dtype=torch.bool)
|
||||
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
|
||||
global_step = 0
|
||||
|
||||
# 5. Block-wise refinement loop
|
||||
block_progress_bar_config = getattr(self, "_progress_bar_config", {}).copy()
|
||||
block_progress_bar_config["position"] = 0
|
||||
block_progress_bar_config["desc"] = "Blocks"
|
||||
for num_block in tqdm(range(prefill_blocks, num_blocks), **block_progress_bar_config):
|
||||
current_window_end = (num_block + 1) * block_length
|
||||
block_x = x[:, :current_window_end]
|
||||
block_attn_mask = attn_mask[:, :current_window_end]
|
||||
block_position_ids = position_ids[:, :current_window_end]
|
||||
|
||||
# Identify which positions in the block are prompt (non-editable).
|
||||
block_start_pos = num_block * block_length
|
||||
prompt_mask_in_block = torch.zeros(block_length, device=device, dtype=torch.bool)
|
||||
if block_start_pos < prompt_length:
|
||||
prompt_end_in_block = min(prompt_length - block_start_pos, block_length)
|
||||
prompt_mask_in_block[:prompt_end_in_block] = True
|
||||
|
||||
post_steps = 0
|
||||
step_idx = 0
|
||||
should_continue = True
|
||||
self.set_progress_bar_config(position=1, leave=False, desc=f"Block {num_block} Inference Steps")
|
||||
progress_bar = self.progress_bar(total=num_inference_steps)
|
||||
|
||||
while should_continue:
|
||||
block_tokens = block_x[:, -block_length:]
|
||||
masks_remaining = (block_tokens == mask_token_id).any()
|
||||
|
||||
if not masks_remaining:
|
||||
post_steps += 1
|
||||
|
||||
logits = self.model(block_x, attention_mask=block_attn_mask, position_ids=block_position_ids).logits
|
||||
block_logits = logits[:, -block_length:, :]
|
||||
|
||||
scheduler_output = self.scheduler.step(
|
||||
model_output=block_logits,
|
||||
timestep=step_idx,
|
||||
sample=block_tokens,
|
||||
mask_token_id=mask_token_id,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
sampling_method=sampling_method,
|
||||
threshold=threshold,
|
||||
editing_threshold=editing_threshold,
|
||||
minimal_topk=minimal_topk,
|
||||
prompt_mask=prompt_mask_in_block,
|
||||
generator=generator,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
transfer_index = scheduler_output.transfer_index
|
||||
editing_transfer_index = scheduler_output.editing_transfer_index
|
||||
final_transfer = transfer_index | editing_transfer_index
|
||||
|
||||
if final_transfer.any():
|
||||
block_x[:, -block_length:] = scheduler_output.prev_sample
|
||||
|
||||
if eos_early_stop and eos_token_id is not None:
|
||||
finished = self.scheduler.check_eos_finished(
|
||||
cur_x=block_x,
|
||||
sampled_tokens=scheduler_output.sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_token_id,
|
||||
mask_token_id=mask_token_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_outputs = callback_on_step_end(self, global_step, step_idx, callback_kwargs)
|
||||
block_x = callback_outputs.pop("block_x", block_x)
|
||||
|
||||
global_step += 1
|
||||
if masks_remaining:
|
||||
step_idx += 1
|
||||
progress_bar.update(1)
|
||||
|
||||
should_continue = self.scheduler.check_block_should_continue(
|
||||
step_idx=step_idx,
|
||||
masks_remaining=masks_remaining,
|
||||
editing_enabled=editing_enabled,
|
||||
editing_transfer_index=editing_transfer_index,
|
||||
post_steps=post_steps,
|
||||
max_post_steps=max_post_steps,
|
||||
finished=finished,
|
||||
)
|
||||
|
||||
progress_bar.close()
|
||||
x[:, :current_window_end] = block_x
|
||||
if eos_early_stop and finished.all():
|
||||
break
|
||||
|
||||
# 6. Post-process output
|
||||
generated = x[:, : prompt_length + gen_length]
|
||||
sequences = generated[:, prompt_length:]
|
||||
if eos_token_id is not None and batch_size == 1:
|
||||
eos_positions = (sequences[0] == eos_token_id).nonzero(as_tuple=True)[0]
|
||||
if len(eos_positions) > 0:
|
||||
sequences = sequences[:, : int(eos_positions[0].item()) + 1]
|
||||
|
||||
texts = None
|
||||
if output_type == "text" and self.tokenizer is not None:
|
||||
texts = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)
|
||||
|
||||
if not return_dict:
|
||||
return sequences.to(device=device), texts
|
||||
return LLaDA2PipelineOutput(sequences=sequences.to(device=device), texts=texts)
|
||||
|
||||
|
||||
__all__ = ["LLaDA2Pipeline", "LLaDA2PipelineOutput"]
|
||||
@@ -40,7 +40,6 @@ except OptionalDependencyNotAvailable:
|
||||
else:
|
||||
_import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
|
||||
_import_structure["scheduling_amused"] = ["AmusedScheduler"]
|
||||
_import_structure["scheduling_block_refinement"] = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
|
||||
_import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
|
||||
_import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
|
||||
_import_structure["scheduling_ddim"] = ["DDIMScheduler"]
|
||||
@@ -146,7 +145,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
else:
|
||||
from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
|
||||
from .scheduling_amused import AmusedScheduler
|
||||
from .scheduling_block_refinement import BlockRefinementScheduler, BlockRefinementSchedulerOutput
|
||||
from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
|
||||
from .scheduling_consistency_models import CMStochasticIterativeScheduler
|
||||
from .scheduling_ddim import DDIMScheduler
|
||||
|
||||
@@ -1,460 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class BlockRefinementSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for block refinement scheduling.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Updated block tokens after the current refinement step.
|
||||
transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Boolean mask indicating which tokens were committed (mask-filling).
|
||||
editing_transfer_index (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Boolean mask indicating which tokens were edited (non-mask replacement).
|
||||
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Sampled token IDs from the model logits.
|
||||
sampled_probs (`torch.Tensor` of shape `(batch_size, block_length)`):
|
||||
Probabilities of the sampled tokens.
|
||||
"""
|
||||
|
||||
prev_sample: torch.LongTensor
|
||||
transfer_index: torch.BoolTensor
|
||||
editing_transfer_index: torch.BoolTensor
|
||||
sampled_tokens: torch.LongTensor
|
||||
sampled_probs: torch.Tensor
|
||||
|
||||
|
||||
class BlockRefinementScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
Scheduler for block-wise iterative refinement (commit-by-confidence).
|
||||
|
||||
At each step, the scheduler samples candidate tokens from model logits and commits those with the highest
|
||||
confidence. The number of tokens to commit per step is determined by evenly distributing the block length across
|
||||
the number of refinement steps.
|
||||
|
||||
Optionally supports editing: after all mask tokens are resolved, tokens can be replaced if the model predicts a
|
||||
different token with confidence above a positive `editing_threshold` (`None`, `0.0`, or negative disables editing).
|
||||
"""
|
||||
|
||||
order = 1
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
block_length: int = 32,
|
||||
num_inference_steps: int = 32,
|
||||
threshold: float = 0.95,
|
||||
editing_threshold: float | None = None,
|
||||
minimal_topk: int = 1,
|
||||
):
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, dtype=torch.long)
|
||||
self._transfer_schedule: torch.LongTensor | None = None
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: str | torch.device | None = None) -> None:
|
||||
if num_inference_steps <= 0:
|
||||
raise ValueError(f"`num_inference_steps` must be > 0, got {num_inference_steps}.")
|
||||
self.num_inference_steps = num_inference_steps
|
||||
self.timesteps = torch.arange(self.num_inference_steps - 1, -1, -1, device=device, dtype=torch.long)
|
||||
self._transfer_schedule = self.get_num_transfer_tokens(self.config.block_length, self.num_inference_steps).to(
|
||||
device=device if device is not None else "cpu"
|
||||
)
|
||||
|
||||
def get_num_transfer_tokens(self, block_length: int, num_inference_steps: int) -> torch.LongTensor:
|
||||
"""Evenly distribute `block_length` token commits across `num_inference_steps` steps."""
|
||||
if num_inference_steps <= 0:
|
||||
return torch.zeros((0,), dtype=torch.long)
|
||||
base = block_length // num_inference_steps
|
||||
remainder = block_length % num_inference_steps
|
||||
out = torch.full((num_inference_steps,), base, dtype=torch.long)
|
||||
out[:remainder] += 1
|
||||
return out
|
||||
|
||||
# --- SAR sampling utilities ---
|
||||
|
||||
@staticmethod
|
||||
def _top_p_filtering(logits: torch.Tensor, top_p: float | None) -> torch.Tensor:
|
||||
"""Nucleus (top-p) logit filtering."""
|
||||
if top_p is None or top_p >= 1.0:
|
||||
return logits
|
||||
if not (0.0 < top_p <= 1.0):
|
||||
raise ValueError(f"`top_p` must be in (0, 1], got {top_p}.")
|
||||
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
|
||||
sorted_probs = torch.softmax(sorted_logits, dim=-1)
|
||||
cumulative_probs = sorted_probs.cumsum(dim=-1)
|
||||
|
||||
sorted_indices_to_remove = cumulative_probs > float(top_p)
|
||||
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, torch.finfo(sorted_logits.dtype).min)
|
||||
filtered = logits.scatter(-1, sorted_indices, sorted_logits)
|
||||
return filtered
|
||||
|
||||
@staticmethod
|
||||
def _top_k_filtering(logits: torch.Tensor, top_k: int | None) -> torch.Tensor:
|
||||
"""Top-k logit filtering."""
|
||||
if top_k is None or top_k <= 0:
|
||||
return logits
|
||||
if top_k >= logits.shape[-1]:
|
||||
return logits
|
||||
values, _ = torch.topk(logits, k=top_k, dim=-1)
|
||||
min_keep = values[..., -1, None]
|
||||
return logits.masked_fill(logits < min_keep, torch.finfo(logits.dtype).min)
|
||||
|
||||
@staticmethod
|
||||
def _sample_from_logits(
|
||||
logits: torch.Tensor,
|
||||
*,
|
||||
temperature: float,
|
||||
top_k: int | None,
|
||||
top_p: float | None,
|
||||
generator: torch.Generator | None,
|
||||
use_multinomial: bool,
|
||||
) -> tuple[torch.LongTensor, torch.Tensor]:
|
||||
"""Sample tokens from logits with temperature scaling, top-k, and top-p."""
|
||||
if temperature < 0:
|
||||
raise ValueError(f"`temperature` must be >= 0, got {temperature}.")
|
||||
|
||||
vocab_size = logits.shape[-1]
|
||||
flat_logits = logits.reshape(-1, vocab_size)
|
||||
|
||||
if temperature == 0.0 or not use_multinomial:
|
||||
probs = torch.softmax(flat_logits.float(), dim=-1)
|
||||
token = flat_logits.argmax(dim=-1, keepdim=True)
|
||||
token_prob = torch.gather(probs, -1, token)
|
||||
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
|
||||
|
||||
scaled = flat_logits
|
||||
if temperature != 1.0:
|
||||
scaled = flat_logits / temperature
|
||||
|
||||
filtered = BlockRefinementScheduler._top_k_filtering(scaled, top_k=top_k)
|
||||
filtered = BlockRefinementScheduler._top_p_filtering(filtered, top_p=top_p)
|
||||
|
||||
probs = torch.softmax(filtered.float(), dim=-1)
|
||||
token = torch.multinomial(probs, num_samples=1, generator=generator)
|
||||
token_prob = torch.gather(probs, -1, token)
|
||||
|
||||
return token.view(*logits.shape[:-1]), token_prob.view(*logits.shape[:-1])
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.Tensor,
|
||||
timestep: int | torch.Tensor,
|
||||
sample: torch.LongTensor,
|
||||
*,
|
||||
mask_token_id: int,
|
||||
temperature: float = 0.0,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
sampling_method: str = "auto",
|
||||
threshold: float | None = None,
|
||||
editing_threshold: float | None = None,
|
||||
minimal_topk: int | None = None,
|
||||
prompt_mask: torch.BoolTensor | None = None,
|
||||
generator: torch.Generator | None = None,
|
||||
return_dict: bool = True,
|
||||
) -> (
|
||||
BlockRefinementSchedulerOutput
|
||||
| tuple[torch.LongTensor, torch.BoolTensor, torch.BoolTensor, torch.LongTensor, torch.Tensor]
|
||||
):
|
||||
"""
|
||||
Perform a single refinement step: sample from logits, commit confident tokens, and optionally edit existing
|
||||
ones.
|
||||
|
||||
Args:
|
||||
model_output (`torch.Tensor` of shape `(batch_size, block_length, vocab_size)`):
|
||||
Raw logits from the model for the current block.
|
||||
timestep (`int` or `torch.Tensor`):
|
||||
Current step index within the block's refinement schedule.
|
||||
sample (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Current block token IDs (contains mask tokens for uncommitted positions).
|
||||
mask_token_id (`int`):
|
||||
Token ID used for masked positions.
|
||||
temperature (`float`):
|
||||
Sampling temperature.
|
||||
top_p (`float`, *optional*):
|
||||
Nucleus sampling cutoff.
|
||||
top_k (`int`, *optional*):
|
||||
Top-k sampling cutoff.
|
||||
sampling_method (`str`):
|
||||
Sampling method (`auto`, `greedy`, `multinomial`).
|
||||
threshold (`float`, *optional*):
|
||||
Confidence threshold for committing tokens. Defaults to config value.
|
||||
editing_threshold (`float`, *optional*):
|
||||
Confidence threshold for editing non-mask tokens; must be positive to enable editing. Defaults to
|
||||
config value.
|
||||
minimal_topk (`int`, *optional*):
|
||||
Minimum tokens to commit per step. Defaults to config value.
|
||||
prompt_mask (`torch.BoolTensor`, *optional*):
|
||||
Boolean mask of shape `(block_length,)` where `True` marks prompt (non-editable) positions.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
RNG for sampling.
|
||||
return_dict (`bool`):
|
||||
Whether to return a `BlockRefinementSchedulerOutput` or a tuple.
|
||||
"""
|
||||
if threshold is None:
|
||||
threshold = float(self.config.threshold)
|
||||
if editing_threshold is None:
|
||||
editing_threshold = self.config.editing_threshold
|
||||
if minimal_topk is None:
|
||||
minimal_topk = self.config.minimal_topk
|
||||
|
||||
# Sample from logits
|
||||
use_multinomial = sampling_method == "multinomial" or (sampling_method == "auto" and temperature != 0.0)
|
||||
sampled_tokens, sampled_probs = self._sample_from_logits(
|
||||
model_output,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
generator=generator,
|
||||
use_multinomial=use_multinomial,
|
||||
)
|
||||
|
||||
batch_size, block_length = sample.shape
|
||||
active_block = sample == mask_token_id
|
||||
masks_remaining = active_block.any()
|
||||
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
step_index = int(timestep.item())
|
||||
else:
|
||||
step_index = int(timestep)
|
||||
|
||||
# --- Mask-filling transfer ---
|
||||
transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
|
||||
if masks_remaining and self._transfer_schedule is not None:
|
||||
clamped_step = min(step_index, len(self._transfer_schedule) - 1)
|
||||
num_to_transfer = int(self._transfer_schedule[clamped_step].item())
|
||||
|
||||
confidence = torch.where(
|
||||
active_block,
|
||||
sampled_probs.to(dtype=torch.float32),
|
||||
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
|
||||
)
|
||||
|
||||
for b in range(batch_size):
|
||||
high_conf = confidence[b] > threshold
|
||||
if high_conf.sum().item() >= num_to_transfer:
|
||||
transfer_index[b] = high_conf
|
||||
else:
|
||||
k = min(num_to_transfer, int(active_block[b].sum().item()))
|
||||
if k > 0:
|
||||
_, idx = torch.topk(confidence[b], k=k)
|
||||
transfer_index[b, idx] = True
|
||||
|
||||
# --- Editing transfer (non-mask, non-prompt positions) ---
|
||||
editing_enabled = editing_threshold is not None and editing_threshold > 0.0
|
||||
editing_transfer_index = torch.zeros_like(sampled_tokens, dtype=torch.bool)
|
||||
if editing_enabled:
|
||||
if prompt_mask is None:
|
||||
prompt_mask = torch.zeros(block_length, device=sample.device, dtype=torch.bool)
|
||||
editable = (~active_block) & (~prompt_mask.unsqueeze(0))
|
||||
editing_conf = torch.where(
|
||||
editable,
|
||||
sampled_probs.to(dtype=torch.float32),
|
||||
torch.full_like(sampled_probs, -torch.inf, dtype=torch.float32),
|
||||
)
|
||||
high_conf_edit = editing_conf > float(editing_threshold)
|
||||
token_changed = sampled_tokens != sample
|
||||
editing_transfer_index = high_conf_edit & token_changed & editable
|
||||
|
||||
# Apply transfers
|
||||
final_transfer = transfer_index | editing_transfer_index
|
||||
prev_sample = sample.clone()
|
||||
if final_transfer.any():
|
||||
prev_sample[final_transfer] = sampled_tokens[final_transfer]
|
||||
|
||||
if not return_dict:
|
||||
return prev_sample, transfer_index, editing_transfer_index, sampled_tokens, sampled_probs
|
||||
return BlockRefinementSchedulerOutput(
|
||||
prev_sample=prev_sample,
|
||||
transfer_index=transfer_index,
|
||||
editing_transfer_index=editing_transfer_index,
|
||||
sampled_tokens=sampled_tokens,
|
||||
sampled_probs=sampled_probs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def check_eos_finished(
|
||||
cur_x: torch.LongTensor,
|
||||
sampled_tokens: torch.LongTensor,
|
||||
final_transfer: torch.BoolTensor,
|
||||
finished: torch.BoolTensor,
|
||||
eos_token_id: int,
|
||||
mask_token_id: int,
|
||||
prompt_length: int,
|
||||
) -> torch.BoolTensor:
|
||||
"""
|
||||
Update per-batch finished flags when EOS tokens are committed.
|
||||
|
||||
Args:
|
||||
cur_x (`torch.LongTensor` of shape `(batch_size, seq_len)`):
|
||||
Current full sequence including all blocks up to the current window.
|
||||
sampled_tokens (`torch.LongTensor` of shape `(batch_size, block_length)`):
|
||||
Tokens sampled by the scheduler in this step.
|
||||
final_transfer (`torch.BoolTensor` of shape `(batch_size, block_length)`):
|
||||
Combined mask of committed and edited positions.
|
||||
finished (`torch.BoolTensor` of shape `(batch_size,)`):
|
||||
Current per-batch finished flags.
|
||||
eos_token_id (`int`):
|
||||
EOS token ID.
|
||||
mask_token_id (`int`):
|
||||
Mask token ID.
|
||||
prompt_length (`int`):
|
||||
Number of prompt tokens at the start of the sequence.
|
||||
|
||||
Returns:
|
||||
`torch.BoolTensor`: Updated finished flags.
|
||||
"""
|
||||
batch_size = cur_x.shape[0]
|
||||
for b in range(batch_size):
|
||||
if finished[b]:
|
||||
continue
|
||||
eos_in_commits = (sampled_tokens[b][final_transfer[b]] == eos_token_id).any().item()
|
||||
if not eos_in_commits:
|
||||
continue
|
||||
eos_pos = (cur_x[b] == eos_token_id).nonzero(as_tuple=True)
|
||||
if len(eos_pos[0]) == 0:
|
||||
continue
|
||||
eos_pos = int(eos_pos[0][0].item())
|
||||
if prompt_length >= eos_pos:
|
||||
continue
|
||||
if (cur_x[b, prompt_length:eos_pos] != mask_token_id).all().item():
|
||||
finished[b] = True
|
||||
return finished
|
||||
|
||||
def check_block_should_continue(
|
||||
self,
|
||||
step_idx: int,
|
||||
masks_remaining: bool,
|
||||
editing_enabled: bool,
|
||||
editing_transfer_index: torch.BoolTensor,
|
||||
post_steps: int,
|
||||
max_post_steps: int,
|
||||
finished: torch.BoolTensor,
|
||||
) -> bool:
|
||||
"""
|
||||
Determine whether the inner refinement loop should continue for the current block.
|
||||
|
||||
Args:
|
||||
step_idx (`int`):
|
||||
Current refinement step index within this block.
|
||||
masks_remaining (`bool`):
|
||||
Whether any mask tokens remain in the block.
|
||||
editing_enabled (`bool`):
|
||||
Whether editing mode is active.
|
||||
editing_transfer_index (`torch.BoolTensor`):
|
||||
Which tokens were edited in this step.
|
||||
post_steps (`int`):
|
||||
Number of post-mask editing steps taken so far.
|
||||
max_post_steps (`int`):
|
||||
Maximum allowed post-mask editing steps.
|
||||
finished (`torch.BoolTensor`):
|
||||
Per-batch finished flags (from EOS detection).
|
||||
|
||||
Returns:
|
||||
`bool`: `True` if refinement should continue, `False` to break.
|
||||
"""
|
||||
if finished.all():
|
||||
return False
|
||||
if not masks_remaining and not editing_enabled:
|
||||
return False
|
||||
if not masks_remaining and not editing_transfer_index.any():
|
||||
return False
|
||||
if masks_remaining and step_idx >= self.num_inference_steps:
|
||||
return False
|
||||
if not masks_remaining and post_steps > max_post_steps:
|
||||
return False
|
||||
return True
|
||||
|
||||
def add_noise(
|
||||
self,
|
||||
original_samples: torch.LongTensor,
|
||||
attention_mask: torch.LongTensor,
|
||||
*,
|
||||
prompt_length: int,
|
||||
block_length: int,
|
||||
mask_token_id: int,
|
||||
generator: torch.Generator | None = None,
|
||||
) -> tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]:
|
||||
"""
|
||||
Apply the forward (noising) process for semi-autoregressive block masking.
|
||||
|
||||
For each block after the prompt, a random fraction of valid (non-padding) tokens are replaced with
|
||||
`mask_token_id`. Two complementary views are returned: `noisy` and `noisy_rev`, where the masked positions in
|
||||
one are the unmasked positions in the other.
|
||||
|
||||
Args:
|
||||
original_samples (`torch.LongTensor` of shape `(batch_size, seq_len)`):
|
||||
Clean token IDs.
|
||||
attention_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
|
||||
Padding mask (1 for valid, 0 for padding).
|
||||
prompt_length (`int`):
|
||||
Number of leading prompt tokens to keep unmasked.
|
||||
block_length (`int`):
|
||||
Block size for masking.
|
||||
mask_token_id (`int`):
|
||||
Token ID to use for masked positions.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
RNG for reproducibility.
|
||||
|
||||
Returns:
|
||||
`tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor, torch.BoolTensor]`:
|
||||
`(noisy, noisy_rev, masked, masked_rev)` — the two complementary noisy sequences and their
|
||||
corresponding boolean masks.
|
||||
"""
|
||||
batch_size, seq_len = original_samples.shape
|
||||
device = original_samples.device
|
||||
|
||||
noisy = original_samples.clone()
|
||||
noisy_rev = original_samples.clone()
|
||||
masked = torch.zeros_like(original_samples, dtype=torch.bool)
|
||||
masked_rev = torch.zeros_like(original_samples, dtype=torch.bool)
|
||||
|
||||
valid = attention_mask.to(dtype=torch.bool)
|
||||
for block_start in range(prompt_length, seq_len, block_length):
|
||||
block_end = min(seq_len, block_start + block_length)
|
||||
seg_len = block_end - block_start
|
||||
if seg_len <= 0:
|
||||
continue
|
||||
|
||||
p_mask = torch.rand((batch_size, 1), device=device, generator=generator)
|
||||
seg = torch.rand((batch_size, seg_len), device=device, generator=generator) < p_mask
|
||||
seg = seg & valid[:, block_start:block_end]
|
||||
seg_rev = (~seg) & valid[:, block_start:block_end]
|
||||
|
||||
masked[:, block_start:block_end] = seg
|
||||
masked_rev[:, block_start:block_end] = seg_rev
|
||||
|
||||
noisy = torch.where(masked, torch.full_like(noisy, mask_token_id), noisy)
|
||||
noisy_rev = torch.where(masked_rev, torch.full_like(noisy_rev, mask_token_id), noisy_rev)
|
||||
return noisy, noisy_rev, masked, masked_rev
|
||||
|
||||
|
||||
__all__ = ["BlockRefinementScheduler", "BlockRefinementSchedulerOutput"]
|
||||
@@ -11,7 +11,6 @@ from typing import Any, Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
if getattr(torch, "distributed", None) is not None:
|
||||
@@ -110,92 +109,6 @@ def compute_snr(noise_scheduler, timesteps):
|
||||
return snr
|
||||
|
||||
|
||||
def compute_confidence_aware_loss(
|
||||
logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
*,
|
||||
lambda_conf: float = 0.0,
|
||||
temperature: float = 1.0,
|
||||
per_token_weights: torch.Tensor | None = None,
|
||||
ignore_index: int = -100,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Computes a confidence-aware training loss for token classification-style heads.
|
||||
|
||||
This loss combines:
|
||||
- `loss_sft`: standard supervised cross-entropy on all non-ignored labels.
|
||||
- `loss_conf`: an entropy penalty applied only on tokens that are already predicted correctly.
|
||||
|
||||
Args:
|
||||
logits (`torch.Tensor`): Logits of shape `(..., vocab_size)`.
|
||||
labels (`torch.Tensor`): Labels of shape `(...)`, matching `logits.shape[:-1]`. Values set to `ignore_index`
|
||||
are excluded from both losses.
|
||||
lambda_conf (`float`, *optional*, defaults to `0.0`): Weight for the confidence term.
|
||||
temperature (`float`, *optional*, defaults to `1.0`): Temperature used for the entropy term only. Lower values
|
||||
sharpen the distribution and change the strength of the confidence gradients.
|
||||
per_token_weights (`torch.Tensor`, *optional*): Optional weights of shape `(...)` to reweight both losses per
|
||||
token (e.g. schedule-aware weights). Tokens with weight `0` contribute nothing.
|
||||
ignore_index (`int`, *optional*, defaults to `-100`): Ignore index for labels.
|
||||
|
||||
Returns:
|
||||
`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`: `(loss, loss_sft, loss_conf)`.
|
||||
"""
|
||||
if logits.ndim < 2:
|
||||
raise ValueError(f"`logits` must have at least 2 dims, got shape {tuple(logits.shape)}.")
|
||||
if labels.shape != logits.shape[:-1]:
|
||||
raise ValueError(
|
||||
f"`labels` shape must match `logits.shape[:-1]`, got labels={tuple(labels.shape)} logits={tuple(logits.shape)}."
|
||||
)
|
||||
if temperature <= 0:
|
||||
raise ValueError(f"`temperature` must be > 0, got {temperature}.")
|
||||
|
||||
valid = labels.ne(ignore_index)
|
||||
if per_token_weights is None:
|
||||
weights = torch.ones_like(labels, dtype=logits.dtype)
|
||||
else:
|
||||
if per_token_weights.shape != labels.shape:
|
||||
raise ValueError(
|
||||
f"`per_token_weights` shape must match `labels` shape, got {tuple(per_token_weights.shape)} != {tuple(labels.shape)}."
|
||||
)
|
||||
weights = per_token_weights.to(dtype=logits.dtype)
|
||||
|
||||
# Supervised CE (optionally weighted).
|
||||
vocab_size = logits.shape[-1]
|
||||
per_token_nll = F.cross_entropy(
|
||||
logits.reshape(-1, vocab_size),
|
||||
labels.reshape(-1),
|
||||
reduction="none",
|
||||
ignore_index=ignore_index,
|
||||
).reshape_as(labels)
|
||||
|
||||
denom_sft = (weights * valid.to(weights.dtype)).sum().clamp_min(1)
|
||||
loss_sft = (per_token_nll * weights * valid.to(per_token_nll.dtype)).sum() / denom_sft
|
||||
|
||||
# Confidence loss: penalize entropy only where prediction is already correct.
|
||||
if lambda_conf == 0.0:
|
||||
loss_conf = torch.zeros((), device=logits.device, dtype=loss_sft.dtype)
|
||||
return loss_sft, loss_sft, loss_conf
|
||||
|
||||
with torch.no_grad():
|
||||
pred = logits.argmax(dim=-1)
|
||||
correct = valid & pred.eq(labels)
|
||||
|
||||
scaled_logits = logits.float()
|
||||
if temperature != 1.0:
|
||||
scaled_logits = scaled_logits / float(temperature)
|
||||
|
||||
probs = torch.softmax(scaled_logits, dim=-1)
|
||||
eps = torch.finfo(probs.dtype).tiny
|
||||
log_probs = torch.log(probs.clamp_min(eps))
|
||||
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
|
||||
|
||||
denom_conf = (weights * correct.to(weights.dtype)).sum().clamp_min(1)
|
||||
loss_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / denom_conf
|
||||
|
||||
loss = loss_sft + float(lambda_conf) * loss_conf
|
||||
return loss, loss_sft, loss_conf
|
||||
|
||||
|
||||
def resolve_interpolation_mode(interpolation_type: str):
|
||||
"""
|
||||
Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
|
||||
|
||||
@@ -2518,36 +2518,6 @@ class AmusedScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlockRefinementScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class BlockRefinementSchedulerOutput(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class CMStochasticIterativeScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
@@ -2222,36 +2222,6 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LLaDA2Pipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LLaDA2PipelineOutput(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class LongCatImageEditPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
@@ -177,6 +177,11 @@ class QuantizationTesterMixin:
|
||||
model_quantized.to(torch_device)
|
||||
|
||||
inputs = self.get_dummy_inputs()
|
||||
model_dtype = next(model_quantized.parameters()).dtype
|
||||
inputs = {
|
||||
k: v.to(dtype=model_dtype) if torch.is_tensor(v) and torch.is_floating_point(v) else v
|
||||
for k, v in inputs.items()
|
||||
}
|
||||
output = model_quantized(**inputs, return_dict=False)[0]
|
||||
|
||||
assert output is not None, "Model output is None"
|
||||
@@ -930,6 +935,7 @@ class TorchAoTesterMixin(TorchAoConfigMixin, QuantizationTesterMixin):
|
||||
"""Test that device_map='auto' works correctly with quantization."""
|
||||
self._test_quantization_device_map(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
|
||||
|
||||
@pytest.mark.xfail(reason="dequantize is not implemented in torchao")
|
||||
def test_torchao_dequantize(self):
|
||||
"""Test that dequantize() works correctly."""
|
||||
self._test_dequantize(TorchAoConfigMixin.TORCHAO_QUANT_TYPES["int8wo"])
|
||||
|
||||
@@ -13,31 +13,23 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BriaTransformer2DModel
|
||||
from diffusers.models.attention_processor import FluxIPAdapterJointAttnProcessor2_0
|
||||
from diffusers.models.embeddings import ImageProjection
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
IPAdapterTesterMixin,
|
||||
LoraHotSwappingForModelTesterMixin,
|
||||
LoraTesterMixin,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
def create_bria_ip_adapter_state_dict(model):
|
||||
# "ip_adapter" (cross-attention weights)
|
||||
ip_cross_attn_state_dict = {}
|
||||
key_id = 0
|
||||
|
||||
@@ -58,8 +50,11 @@ def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
f"{key_id}.to_v_ip.bias": sd["to_v_ip.0.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
key_id += 1
|
||||
|
||||
# "image_proj" (ImageProjection layer weights)
|
||||
|
||||
image_projection = ImageProjection(
|
||||
cross_attention_dim=model.config["joint_attention_dim"],
|
||||
image_embed_dim=model.config["pooled_projection_dim"],
|
||||
@@ -78,36 +73,53 @@ def create_bria_ip_adapter_state_dict(model) -> dict[str, dict[str, Any]]:
|
||||
)
|
||||
|
||||
del sd
|
||||
return {"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict}
|
||||
ip_state_dict = {}
|
||||
ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
|
||||
return ip_state_dict
|
||||
|
||||
|
||||
class BriaTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return BriaTransformer2DModel
|
||||
class BriaTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = BriaTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.8, 0.7, 0.7]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.7, 0.7]
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 4)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 4,
|
||||
"num_layers": 1,
|
||||
@@ -119,35 +131,11 @@ class BriaTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [0, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
height = width = 4
|
||||
sequence_length = 48
|
||||
embedding_dim = 32
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
),
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
|
||||
|
||||
class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
|
||||
def test_deprecated_inputs_img_txt_ids_3d(self):
|
||||
init_dict = self.get_init_dict()
|
||||
inputs_dict = self.get_dummy_inputs()
|
||||
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
@@ -155,6 +143,7 @@ class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_1 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
# update inputs_dict with txt_ids and img_ids as 3d tensors (deprecated)
|
||||
text_ids_3d = inputs_dict["txt_ids"].unsqueeze(0)
|
||||
image_ids_3d = inputs_dict["img_ids"].unsqueeze(0)
|
||||
|
||||
@@ -167,59 +156,26 @@ class TestBriaTransformer(BriaTransformerTesterConfig, ModelTesterMixin):
|
||||
with torch.no_grad():
|
||||
output_2 = model(**inputs_dict).to_tuple()[0]
|
||||
|
||||
assert output_1.shape == output_2.shape
|
||||
assert torch.allclose(output_1, output_2, atol=1e-5), (
|
||||
"output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) "
|
||||
"are not equal as them as 2d inputs"
|
||||
self.assertEqual(output_1.shape, output_2.shape)
|
||||
self.assertTrue(
|
||||
torch.allclose(output_1, output_2, atol=1e-5),
|
||||
msg="output with deprecated inputs (img_ids and txt_ids as 3d torch tensors) are not equal as them as 2d inputs",
|
||||
)
|
||||
|
||||
|
||||
class TestBriaTransformerTraining(BriaTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"BriaTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestBriaTransformerCompile(BriaTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
class BriaTransformerCompileTests(TorchCompileTesterMixin, unittest.TestCase):
|
||||
model_class = BriaTransformer2DModel
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
|
||||
class TestBriaTransformerIPAdapter(BriaTransformerTesterConfig, IPAdapterTesterMixin):
|
||||
@property
|
||||
def ip_adapter_processor_cls(self):
|
||||
return FluxIPAdapterJointAttnProcessor2_0
|
||||
class BriaTransformerLoRAHotSwapTests(LoraHotSwappingForModelTesterMixin, unittest.TestCase):
|
||||
model_class = BriaTransformer2DModel
|
||||
|
||||
def modify_inputs_for_ip_adapter(self, model, inputs_dict):
|
||||
torch.manual_seed(0)
|
||||
cross_attention_dim = getattr(model.config, "joint_attention_dim", 32)
|
||||
image_embeds = torch.randn(1, 1, cross_attention_dim).to(torch_device)
|
||||
inputs_dict.update({"joint_attention_kwargs": {"ip_adapter_image_embeds": image_embeds}})
|
||||
return inputs_dict
|
||||
|
||||
def create_ip_adapter_state_dict(self, model: Any) -> dict[str, dict[str, Any]]:
|
||||
return create_bria_ip_adapter_state_dict(model)
|
||||
|
||||
|
||||
class TestBriaTransformerLoRA(BriaTransformerTesterConfig, LoraTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestBriaTransformerLoRAHotSwap(BriaTransformerTesterConfig, LoraHotSwappingForModelTesterMixin):
|
||||
@property
|
||||
def different_shapes_for_compilation(self):
|
||||
return [(4, 4), (4, 8), (8, 8)]
|
||||
|
||||
def get_dummy_inputs(self, height: int = 4, width: int = 4) -> dict[str, torch.Tensor]:
|
||||
batch_size = 1
|
||||
num_latent_channels = 4
|
||||
num_image_channels = 3
|
||||
sequence_length = 24
|
||||
embedding_dim = 32
|
||||
|
||||
return {
|
||||
"hidden_states": randn_tensor((batch_size, height * width, num_latent_channels), device=torch_device),
|
||||
"encoder_hidden_states": randn_tensor((batch_size, sequence_length, embedding_dim), device=torch_device),
|
||||
"img_ids": randn_tensor((height * width, num_image_channels), device=torch_device),
|
||||
"txt_ids": randn_tensor((sequence_length, num_image_channels), device=torch_device),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
}
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
return BriaTransformerTests().prepare_init_args_and_inputs_for_common()
|
||||
|
||||
@@ -13,50 +13,62 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BriaFiboTransformer2DModel
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..testing_utils import (
|
||||
BaseModelTesterConfig,
|
||||
ModelTesterMixin,
|
||||
TorchCompileTesterMixin,
|
||||
TrainingTesterMixin,
|
||||
)
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
|
||||
@property
|
||||
def model_class(self):
|
||||
return BriaFiboTransformer2DModel
|
||||
class BriaFiboTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = BriaFiboTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
# We override the items here because the transformer under consideration is small.
|
||||
model_split_percents = [0.8, 0.7, 0.7]
|
||||
|
||||
# Skip setting testing with default: AttnProcessor
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def main_input_name(self) -> str:
|
||||
return "hidden_states"
|
||||
def dummy_input(self):
|
||||
batch_size = 1
|
||||
num_latent_channels = 48
|
||||
num_image_channels = 3
|
||||
height = width = 16
|
||||
sequence_length = 32
|
||||
embedding_dim = 64
|
||||
|
||||
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
|
||||
image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": image_ids,
|
||||
"txt_ids": text_ids,
|
||||
"timestep": timestep,
|
||||
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
|
||||
}
|
||||
|
||||
@property
|
||||
def model_split_percents(self) -> list:
|
||||
return [0.8, 0.7, 0.7]
|
||||
|
||||
@property
|
||||
def output_shape(self) -> tuple:
|
||||
return (256, 48)
|
||||
|
||||
@property
|
||||
def input_shape(self) -> tuple:
|
||||
def input_shape(self):
|
||||
return (16, 16)
|
||||
|
||||
@property
|
||||
def generator(self):
|
||||
return torch.Generator("cpu").manual_seed(0)
|
||||
def output_shape(self):
|
||||
return (256, 48)
|
||||
|
||||
def get_init_dict(self) -> dict:
|
||||
return {
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"patch_size": 1,
|
||||
"in_channels": 48,
|
||||
"num_layers": 1,
|
||||
@@ -69,41 +81,9 @@ class BriaFiboTransformerTesterConfig(BaseModelTesterConfig):
|
||||
"axes_dims_rope": [0, 4, 4],
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, batch_size: int = 1) -> dict[str, torch.Tensor]:
|
||||
num_latent_channels = 48
|
||||
num_image_channels = 3
|
||||
height = width = 16
|
||||
sequence_length = 32
|
||||
embedding_dim = 64
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
encoder_hidden_states = randn_tensor(
|
||||
(batch_size, sequence_length, embedding_dim), generator=self.generator, device=torch_device
|
||||
)
|
||||
return {
|
||||
"hidden_states": randn_tensor(
|
||||
(batch_size, height * width, num_latent_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
"img_ids": randn_tensor(
|
||||
(height * width, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"txt_ids": randn_tensor(
|
||||
(sequence_length, num_image_channels), generator=self.generator, device=torch_device
|
||||
),
|
||||
"timestep": torch.tensor([1.0]).to(torch_device).expand(batch_size),
|
||||
"text_encoder_layers": [encoder_hidden_states[:, :, :32], encoder_hidden_states[:, :, :32]],
|
||||
}
|
||||
|
||||
|
||||
class TestBriaFiboTransformer(BriaFiboTransformerTesterConfig, ModelTesterMixin):
|
||||
pass
|
||||
|
||||
|
||||
class TestBriaFiboTransformerTraining(BriaFiboTransformerTesterConfig, TrainingTesterMixin):
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"BriaFiboTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
class TestBriaFiboTransformerCompile(BriaFiboTransformerTesterConfig, TorchCompileTesterMixin):
|
||||
pass
|
||||
|
||||
@@ -18,7 +18,7 @@ import unittest
|
||||
import torch
|
||||
|
||||
from diffusers import DDIMScheduler, DDPMScheduler, UNet2DModel
|
||||
from diffusers.training_utils import compute_confidence_aware_loss, set_seed
|
||||
from diffusers.training_utils import set_seed
|
||||
|
||||
from ..testing_utils import slow
|
||||
|
||||
@@ -85,47 +85,3 @@ class TrainingTests(unittest.TestCase):
|
||||
|
||||
self.assertTrue(torch.allclose(ddpm_noisy_images, ddim_noisy_images, atol=1e-5))
|
||||
self.assertTrue(torch.allclose(ddpm_noise_pred, ddim_noise_pred, atol=1e-5))
|
||||
|
||||
def test_confidence_aware_loss(self):
|
||||
logits = torch.tensor([[[5.0, 0.0], [0.0, 5.0]]])
|
||||
labels = torch.tensor([[0, 0]])
|
||||
weights = torch.tensor([[1.0, 2.0]])
|
||||
|
||||
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=0.0, per_token_weights=weights
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss, loss_sft))
|
||||
self.assertTrue(torch.allclose(loss_conf, torch.zeros_like(loss_conf)))
|
||||
|
||||
lambda_conf = 0.25
|
||||
loss, loss_sft, loss_conf = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=lambda_conf, per_token_weights=weights
|
||||
)
|
||||
|
||||
# Manual expected values for the small 2-class case.
|
||||
per_token_nll = torch.nn.functional.cross_entropy(logits.view(-1, 2), labels.view(-1), reduction="none").view(
|
||||
1, 2
|
||||
)
|
||||
expected_sft = (per_token_nll * weights).sum() / weights.sum()
|
||||
|
||||
pred = logits.argmax(dim=-1)
|
||||
correct = pred.eq(labels)
|
||||
log_probs = torch.log_softmax(logits.float(), dim=-1)
|
||||
probs = log_probs.exp()
|
||||
entropy = -(probs * log_probs).sum(dim=-1).to(dtype=logits.dtype)
|
||||
expected_conf = (entropy * weights * correct.to(entropy.dtype)).sum() / (
|
||||
weights * correct.to(weights.dtype)
|
||||
).sum().clamp_min(1)
|
||||
|
||||
expected = expected_sft + lambda_conf * expected_conf
|
||||
self.assertTrue(torch.allclose(loss_sft, expected_sft))
|
||||
self.assertTrue(torch.allclose(loss_conf, expected_conf))
|
||||
self.assertTrue(torch.allclose(loss, expected))
|
||||
|
||||
# Temperature affects only the confidence term.
|
||||
loss_t, loss_sft_t, loss_conf_t = compute_confidence_aware_loss(
|
||||
logits, labels, lambda_conf=lambda_conf, temperature=0.5, per_token_weights=weights
|
||||
)
|
||||
self.assertTrue(torch.allclose(loss_sft_t, expected_sft))
|
||||
self.assertFalse(torch.allclose(loss_conf_t, expected_conf))
|
||||
self.assertTrue(torch.allclose(loss_t, loss_sft_t + lambda_conf * loss_conf_t))
|
||||
|
||||
@@ -1,245 +0,0 @@
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BlockRefinementScheduler, LLaDA2Pipeline
|
||||
|
||||
|
||||
class _DummyModelOutput:
|
||||
def __init__(self, logits):
|
||||
self.logits = logits
|
||||
|
||||
|
||||
class _DummyCausalLM(torch.nn.Module):
|
||||
def __init__(self, vocab_size: int):
|
||||
super().__init__()
|
||||
self.vocab_size = int(vocab_size)
|
||||
self.register_buffer("_device_anchor", torch.empty(0))
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return torch.float32
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._device_anchor.device
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, position_ids=None, **kwargs):
|
||||
batch_size, seq_len = input_ids.shape
|
||||
logits = torch.zeros((batch_size, seq_len, self.vocab_size), device=input_ids.device, dtype=torch.float32)
|
||||
|
||||
# Make confidence vary with token position so top-k commits are deterministic.
|
||||
positions = torch.arange(seq_len, device=input_ids.device, dtype=torch.float32).view(1, seq_len, 1)
|
||||
token_ids = (torch.arange(seq_len, device=input_ids.device) % (self.vocab_size - 2)).view(1, seq_len, 1)
|
||||
logits.scatter_(2, token_ids.expand(batch_size, -1, -1), 1.0 + positions.expand(batch_size, -1, -1) * 0.1)
|
||||
return _DummyModelOutput(logits=logits)
|
||||
|
||||
|
||||
def _make_pipeline(tokenizer=None):
|
||||
model = _DummyCausalLM(vocab_size=32)
|
||||
scheduler = BlockRefinementScheduler()
|
||||
return LLaDA2Pipeline(model=model, scheduler=scheduler, tokenizer=tokenizer)
|
||||
|
||||
|
||||
class LLaDA2PipelineTest(unittest.TestCase):
|
||||
def test_pipeline_runs(self):
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
input_ids = torch.tensor([[5, 6, 7, 8], [1, 2, 3, 4]], dtype=torch.long)
|
||||
out = pipe(
|
||||
input_ids=input_ids,
|
||||
use_chat_template=False,
|
||||
gen_length=24,
|
||||
block_length=8,
|
||||
num_inference_steps=8,
|
||||
temperature=0.0,
|
||||
threshold=2.0, # force top-k commits
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
eos_token_id=None,
|
||||
output_type="seq",
|
||||
)
|
||||
|
||||
self.assertEqual(out.sequences.shape, (2, 24))
|
||||
self.assertFalse((out.sequences == 31).any().item())
|
||||
|
||||
def test_pipeline_return_tuple(self):
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
input_ids = torch.tensor([[5, 6, 7, 8]], dtype=torch.long)
|
||||
sequences, texts = pipe(
|
||||
input_ids=input_ids,
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="seq",
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
self.assertEqual(sequences.shape, (1, 16))
|
||||
self.assertIsNone(texts)
|
||||
|
||||
def test_output_type_seq(self):
|
||||
"""output_type='seq' should return sequences but no texts."""
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="seq",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertEqual(out.sequences.shape, (1, 16))
|
||||
self.assertIsNone(out.texts)
|
||||
|
||||
def test_output_type_text_without_tokenizer(self):
|
||||
"""output_type='text' without a tokenizer should return texts=None."""
|
||||
pipe = _make_pipeline(tokenizer=None).to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
mask_token_id=31,
|
||||
output_type="text",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertIsNone(out.texts)
|
||||
|
||||
def test_output_type_text_with_tokenizer(self):
|
||||
"""output_type='text' with a tokenizer should return decoded texts."""
|
||||
tok = type(
|
||||
"Tok",
|
||||
(),
|
||||
{
|
||||
"eos_token_id": None,
|
||||
"mask_token_id": 31,
|
||||
"batch_decode": lambda self, seqs, **kw: [f"decoded_{len(s)}" for s in seqs],
|
||||
},
|
||||
)()
|
||||
pipe = _make_pipeline(tokenizer=tok).to("cpu")
|
||||
|
||||
out = pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
temperature=0.0,
|
||||
threshold=2.0,
|
||||
minimal_topk=1,
|
||||
eos_early_stop=False,
|
||||
output_type="text",
|
||||
)
|
||||
|
||||
self.assertIsNotNone(out.sequences)
|
||||
self.assertIsNotNone(out.texts)
|
||||
self.assertEqual(len(out.texts), 1)
|
||||
self.assertTrue(out.texts[0].startswith("decoded_"))
|
||||
|
||||
def test_output_type_invalid_raises(self):
|
||||
"""Invalid output_type should raise ValueError."""
|
||||
pipe = _make_pipeline().to("cpu")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
pipe(
|
||||
input_ids=torch.tensor([[5, 6, 7, 8]], dtype=torch.long),
|
||||
use_chat_template=False,
|
||||
gen_length=16,
|
||||
block_length=8,
|
||||
num_inference_steps=4,
|
||||
mask_token_id=31,
|
||||
output_type="invalid",
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_from_tensor(self):
|
||||
pipe = _make_pipeline()
|
||||
ids = torch.tensor([[1, 2, 3]], dtype=torch.long)
|
||||
result = pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=ids,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
self.assertTrue(torch.equal(result, ids))
|
||||
|
||||
def test_prepare_input_ids_from_1d_tensor(self):
|
||||
pipe = _make_pipeline()
|
||||
ids = torch.tensor([1, 2, 3], dtype=torch.long)
|
||||
result = pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=ids,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
self.assertEqual(result.shape, (1, 3))
|
||||
|
||||
def test_prepare_input_ids_no_tokenizer_raises(self):
|
||||
pipe = _make_pipeline(tokenizer=None)
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt="hello",
|
||||
messages=None,
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_both_prompt_and_messages_raises(self):
|
||||
pipe = _make_pipeline()
|
||||
# Manually set tokenizer to a simple object so _prepare_input_ids doesn't short-circuit
|
||||
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt="hello",
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
def test_prepare_input_ids_neither_raises(self):
|
||||
pipe = _make_pipeline()
|
||||
pipe.tokenizer = type("Tok", (), {"eos_token_id": None, "mask_token_id": None})()
|
||||
with self.assertRaises(ValueError):
|
||||
pipe._prepare_input_ids(
|
||||
prompt=None,
|
||||
messages=None,
|
||||
input_ids=None,
|
||||
use_chat_template=False,
|
||||
add_generation_prompt=False,
|
||||
chat_template_kwargs=None,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1534,18 +1534,14 @@ class PipelineTesterMixin:
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
pipe.to("cpu")
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
self.assertTrue(all(device == "cpu" for device in model_devices))
|
||||
|
||||
output_cpu = pipe(**self.get_dummy_inputs("cpu"))[0]
|
||||
self.assertTrue(np.isnan(output_cpu).sum() == 0)
|
||||
|
||||
pipe.to(torch_device)
|
||||
model_devices = [
|
||||
component.device.type for component in components.values() if getattr(component, "device", None)
|
||||
]
|
||||
model_devices = [component.device.type for component in components.values() if hasattr(component, "device")]
|
||||
self.assertTrue(all(device == torch_device for device in model_devices))
|
||||
|
||||
output_device = pipe(**self.get_dummy_inputs(torch_device))[0]
|
||||
@@ -1556,11 +1552,11 @@ class PipelineTesterMixin:
|
||||
pipe = self.pipeline_class(**components)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
|
||||
|
||||
pipe.to(dtype=torch.float16)
|
||||
model_dtypes = [component.dtype for component in components.values() if getattr(component, "dtype", None)]
|
||||
model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
|
||||
self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
|
||||
|
||||
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
|
||||
|
||||
@@ -1,470 +0,0 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import BlockRefinementScheduler
|
||||
|
||||
|
||||
class BlockRefinementSchedulerTest(unittest.TestCase):
|
||||
def get_scheduler(self, **kwargs):
|
||||
config = {
|
||||
"block_length": 32,
|
||||
"num_inference_steps": 8,
|
||||
"threshold": 0.95,
|
||||
"editing_threshold": None,
|
||||
"minimal_topk": 1,
|
||||
}
|
||||
config.update(kwargs)
|
||||
return BlockRefinementScheduler(**config)
|
||||
|
||||
def _make_logits_from_probs(self, target_probs: torch.Tensor, vocab_size: int = 100) -> torch.Tensor:
|
||||
"""Create logits where softmax of the target token has approximately the given probability."""
|
||||
batch_size, block_length = target_probs.shape
|
||||
logits = torch.zeros(batch_size, block_length, vocab_size)
|
||||
# Set token 0 as the "predicted" token with a logit proportional to desired probability
|
||||
for b in range(batch_size):
|
||||
for t in range(block_length):
|
||||
p = target_probs[b, t].item()
|
||||
if p > 0:
|
||||
logits[b, t, t % (vocab_size - 1)] = 10.0 * p
|
||||
return logits
|
||||
|
||||
def test_set_timesteps(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
self.assertEqual(scheduler.num_inference_steps, 8)
|
||||
self.assertEqual(len(scheduler.timesteps), 8)
|
||||
self.assertEqual(scheduler.timesteps[0].item(), 7)
|
||||
self.assertEqual(scheduler.timesteps[-1].item(), 0)
|
||||
|
||||
def test_set_timesteps_invalid(self):
|
||||
scheduler = self.get_scheduler()
|
||||
with self.assertRaises(ValueError):
|
||||
scheduler.set_timesteps(0)
|
||||
|
||||
def test_get_num_transfer_tokens_even(self):
|
||||
scheduler = self.get_scheduler()
|
||||
schedule = scheduler.get_num_transfer_tokens(block_length=32, num_inference_steps=8)
|
||||
self.assertEqual(schedule.sum().item(), 32)
|
||||
self.assertEqual(len(schedule), 8)
|
||||
self.assertTrue((schedule == 4).all().item())
|
||||
|
||||
def test_get_num_transfer_tokens_remainder(self):
|
||||
scheduler = self.get_scheduler()
|
||||
schedule = scheduler.get_num_transfer_tokens(block_length=10, num_inference_steps=3)
|
||||
self.assertEqual(schedule.sum().item(), 10)
|
||||
self.assertEqual(len(schedule), 3)
|
||||
self.assertEqual(schedule[0].item(), 4)
|
||||
self.assertEqual(schedule[1].item(), 3)
|
||||
self.assertEqual(schedule[2].item(), 3)
|
||||
|
||||
def test_transfer_schedule_created_on_set_timesteps(self):
|
||||
scheduler = self.get_scheduler(block_length=16)
|
||||
scheduler.set_timesteps(4)
|
||||
self.assertIsNotNone(scheduler._transfer_schedule)
|
||||
self.assertEqual(scheduler._transfer_schedule.sum().item(), 16)
|
||||
|
||||
def test_save_load_config_round_trip(self):
|
||||
scheduler = self.get_scheduler(block_length=64, threshold=0.8, editing_threshold=0.5, minimal_topk=2)
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
scheduler.save_config(tmpdir)
|
||||
loaded = BlockRefinementScheduler.from_pretrained(tmpdir)
|
||||
|
||||
self.assertEqual(loaded.config.block_length, 64)
|
||||
self.assertEqual(loaded.config.threshold, 0.8)
|
||||
self.assertEqual(loaded.config.editing_threshold, 0.5)
|
||||
self.assertEqual(loaded.config.minimal_topk, 2)
|
||||
|
||||
def test_from_config(self):
|
||||
scheduler = self.get_scheduler(block_length=16, threshold=0.7)
|
||||
new_scheduler = BlockRefinementScheduler.from_config(scheduler.config)
|
||||
self.assertEqual(new_scheduler.config.block_length, 16)
|
||||
self.assertEqual(new_scheduler.config.threshold, 0.7)
|
||||
|
||||
def test_step_commits_tokens(self):
|
||||
"""Verify that step() commits mask tokens based on confidence."""
|
||||
scheduler = self.get_scheduler(block_length=8)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
batch_size, block_length, vocab_size = 1, 8, 32
|
||||
mask_id = 31
|
||||
|
||||
sample = torch.full((batch_size, block_length), mask_id, dtype=torch.long)
|
||||
# Create logits where confidence decreases with position
|
||||
logits = torch.zeros(batch_size, block_length, vocab_size)
|
||||
for i in range(block_length):
|
||||
logits[0, i, i] = 10.0 - i # decreasing confidence
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=mask_id,
|
||||
temperature=0.0,
|
||||
threshold=0.95,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# With 8 tokens and 2 steps, first step should commit 4 tokens
|
||||
committed = out.transfer_index[0].sum().item()
|
||||
self.assertEqual(committed, 4)
|
||||
|
||||
def test_step_no_editing_by_default(self):
|
||||
"""Without editing_threshold, no non-mask tokens should be changed."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
logits[0, :, 15] = 10.0 # predict token 15 for all positions
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=None,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertFalse(out.editing_transfer_index.any().item())
|
||||
self.assertFalse(out.transfer_index[0, 0].item())
|
||||
self.assertFalse(out.transfer_index[0, 1].item())
|
||||
|
||||
def test_step_editing_replaces_tokens(self):
|
||||
"""With editing_threshold, non-mask tokens with high confidence and different prediction get replaced."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
# Token 0: predict 50 (different from 10) with very high logit
|
||||
logits[0, 0, 15] = 20.0
|
||||
# Token 1: predict 20 (same as current)
|
||||
logits[0, 1, 20] = 20.0
|
||||
# Mask tokens
|
||||
logits[0, 2, 5] = 5.0
|
||||
logits[0, 3, 6] = 5.0
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=0.5,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# Token 0 should be edited (different prediction, high confidence)
|
||||
self.assertTrue(out.editing_transfer_index[0, 0].item())
|
||||
# Token 1 should NOT be edited (same prediction)
|
||||
self.assertFalse(out.editing_transfer_index[0, 1].item())
|
||||
|
||||
def test_step_prompt_mask_prevents_editing(self):
|
||||
"""Prompt positions should never be edited even with editing enabled."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.tensor([[10, 20, 31, 31]], dtype=torch.long)
|
||||
logits = torch.zeros(1, 4, vocab_size)
|
||||
logits[0, :, 15] = 20.0
|
||||
prompt_mask = torch.tensor([True, True, False, False])
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
editing_threshold=0.5,
|
||||
prompt_mask=prompt_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertFalse(out.editing_transfer_index[0, 0].item())
|
||||
self.assertFalse(out.editing_transfer_index[0, 1].item())
|
||||
|
||||
def test_step_return_tuple(self):
|
||||
"""Verify tuple output when return_dict=False."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
vocab_size = 32
|
||||
sample = torch.full((1, 4), 31, dtype=torch.long)
|
||||
logits = torch.randn(1, 4, vocab_size)
|
||||
|
||||
result = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=31,
|
||||
temperature=0.0,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
self.assertIsInstance(result, tuple)
|
||||
self.assertEqual(len(result), 5)
|
||||
|
||||
def test_step_batched(self):
|
||||
"""Verify step works with batch_size > 1."""
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
scheduler.set_timesteps(2)
|
||||
|
||||
batch_size, vocab_size = 3, 32
|
||||
mask_id = 31
|
||||
sample = torch.full((batch_size, 4), mask_id, dtype=torch.long)
|
||||
logits = torch.randn(batch_size, 4, vocab_size)
|
||||
|
||||
out = scheduler.step(
|
||||
model_output=logits,
|
||||
timestep=0,
|
||||
sample=sample,
|
||||
mask_token_id=mask_id,
|
||||
temperature=0.0,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(out.prev_sample.shape, (batch_size, 4))
|
||||
self.assertEqual(out.transfer_index.shape, (batch_size, 4))
|
||||
|
||||
def test_check_block_should_continue_finished(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([True, True])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=0,
|
||||
masks_remaining=True,
|
||||
editing_enabled=False,
|
||||
editing_transfer_index=torch.zeros(2, 32, dtype=torch.bool),
|
||||
post_steps=0,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_block_should_continue_no_masks_no_edits(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([False])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=5,
|
||||
masks_remaining=False,
|
||||
editing_enabled=True,
|
||||
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
|
||||
post_steps=1,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_block_should_continue_steps_exhausted(self):
|
||||
scheduler = self.get_scheduler()
|
||||
scheduler.set_timesteps(8)
|
||||
finished = torch.tensor([False])
|
||||
result = scheduler.check_block_should_continue(
|
||||
step_idx=8,
|
||||
masks_remaining=True,
|
||||
editing_enabled=False,
|
||||
editing_transfer_index=torch.zeros(1, 32, dtype=torch.bool),
|
||||
post_steps=0,
|
||||
max_post_steps=16,
|
||||
finished=finished,
|
||||
)
|
||||
self.assertFalse(result)
|
||||
|
||||
def test_check_eos_finished_marks_batch(self):
|
||||
"""When EOS is committed and all tokens before it are unmasked, mark batch as finished."""
|
||||
mask_id, eos_id, prompt_length = 99, 2, 2
|
||||
# cur_x: [prompt, prompt, token, eos, mask, mask]
|
||||
cur_x = torch.tensor([[10, 11, 5, eos_id, mask_id, mask_id]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0, 0, eos_id]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, False, False, True]])
|
||||
finished = torch.tensor([False])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
self.assertTrue(finished[0].item())
|
||||
|
||||
def test_check_eos_finished_ignores_when_masks_before_eos(self):
|
||||
"""If there are still mask tokens between prompt and EOS, don't mark as finished."""
|
||||
mask_id, eos_id, prompt_length = 99, 2, 2
|
||||
# cur_x: [prompt, prompt, mask, eos] — mask before EOS
|
||||
cur_x = torch.tensor([[10, 11, mask_id, eos_id]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, True]])
|
||||
finished = torch.tensor([False])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=prompt_length,
|
||||
)
|
||||
self.assertFalse(finished[0].item())
|
||||
|
||||
def test_check_eos_finished_already_finished(self):
|
||||
"""Already-finished batches should stay finished."""
|
||||
mask_id, eos_id = 99, 2
|
||||
cur_x = torch.tensor([[10, 11, 5, 6]], dtype=torch.long)
|
||||
sampled_tokens = torch.tensor([[0, 0]], dtype=torch.long)
|
||||
final_transfer = torch.tensor([[False, False]])
|
||||
finished = torch.tensor([True])
|
||||
|
||||
finished = BlockRefinementScheduler.check_eos_finished(
|
||||
cur_x=cur_x,
|
||||
sampled_tokens=sampled_tokens,
|
||||
final_transfer=final_transfer,
|
||||
finished=finished,
|
||||
eos_token_id=eos_id,
|
||||
mask_token_id=mask_id,
|
||||
prompt_length=2,
|
||||
)
|
||||
self.assertTrue(finished[0].item())
|
||||
|
||||
def test_add_noise(self):
|
||||
scheduler = self.get_scheduler(block_length=4)
|
||||
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]], dtype=torch.long)
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
mask_token_id = 99
|
||||
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
noisy, noisy_rev, masked, masked_rev = scheduler.add_noise(
|
||||
input_ids,
|
||||
attention_mask,
|
||||
prompt_length=2,
|
||||
block_length=4,
|
||||
mask_token_id=mask_token_id,
|
||||
generator=gen,
|
||||
)
|
||||
|
||||
# Prompt positions should never be masked
|
||||
self.assertFalse(masked[0, 0].item())
|
||||
self.assertFalse(masked[0, 1].item())
|
||||
self.assertFalse(masked_rev[0, 0].item())
|
||||
self.assertFalse(masked_rev[0, 1].item())
|
||||
|
||||
# Noisy should have mask_token_id where masked is True
|
||||
self.assertTrue((noisy[masked] == mask_token_id).all().item())
|
||||
self.assertTrue((noisy_rev[masked_rev] == mask_token_id).all().item())
|
||||
|
||||
# masked and masked_rev should be complementary within valid non-prompt positions
|
||||
non_prompt = torch.zeros_like(masked)
|
||||
non_prompt[0, 2:] = True
|
||||
combined = masked | masked_rev
|
||||
self.assertTrue((combined[0, 2:] == non_prompt[0, 2:]).all().item())
|
||||
|
||||
|
||||
class TestTopPFiltering(unittest.TestCase):
|
||||
def test_top_p_filtering(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
|
||||
filtered = BlockRefinementScheduler._top_p_filtering(logits, top_p=0.5)
|
||||
self.assertTrue((filtered > torch.finfo(filtered.dtype).min).any())
|
||||
self.assertTrue((filtered == torch.finfo(filtered.dtype).min).any())
|
||||
|
||||
def test_top_p_filtering_none(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=None)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_p_filtering_one(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_p_filtering(logits, top_p=1.0)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
|
||||
class TestTopKFiltering(unittest.TestCase):
|
||||
def test_top_k_filtering(self):
|
||||
logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
|
||||
filtered = BlockRefinementScheduler._top_k_filtering(logits, top_k=2)
|
||||
self.assertAlmostEqual(filtered[0, 1].item(), 4.0)
|
||||
self.assertAlmostEqual(filtered[0, 3].item(), 3.0)
|
||||
self.assertEqual(filtered[0, 0].item(), torch.finfo(filtered.dtype).min)
|
||||
self.assertEqual(filtered[0, 2].item(), torch.finfo(filtered.dtype).min)
|
||||
|
||||
def test_top_k_filtering_none(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=None)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_k_filtering_zero(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=0)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
def test_top_k_filtering_large_k(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
result = BlockRefinementScheduler._top_k_filtering(logits, top_k=100)
|
||||
self.assertTrue(torch.equal(result, logits))
|
||||
|
||||
|
||||
class TestSampleFromLogits(unittest.TestCase):
|
||||
def test_greedy_sampling(self):
|
||||
logits = torch.tensor([[1.0, 5.0, 2.0]])
|
||||
tokens, probs = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=0.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 1)
|
||||
self.assertEqual(tokens.shape, (1,))
|
||||
self.assertEqual(probs.shape, (1,))
|
||||
|
||||
def test_multinomial_sampling(self):
|
||||
logits = torch.tensor([[0.0, 100.0, -100.0]])
|
||||
gen = torch.Generator().manual_seed(42)
|
||||
tokens, probs = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=gen,
|
||||
use_multinomial=True,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 1)
|
||||
|
||||
def test_temperature_scaling(self):
|
||||
logits = torch.tensor([[1.0, 2.0, 3.0]])
|
||||
tokens, _ = BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=0.01,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
self.assertEqual(tokens.item(), 2)
|
||||
|
||||
def test_negative_temperature_raises(self):
|
||||
logits = torch.tensor([[1.0, 2.0]])
|
||||
with self.assertRaises(ValueError):
|
||||
BlockRefinementScheduler._sample_from_logits(
|
||||
logits,
|
||||
temperature=-1.0,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
generator=None,
|
||||
use_multinomial=False,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user